拔草:给训练集、验证集和测试集的多类图像制作 tfrecords 文件

明天再写一篇读取 tfrecords 的。昨天做的时候遇到了问题:电脑不行,总是提示一些库函数不存在,其实库已经导入进去了,但 Python 就是这样,所以还没入坑的小伙伴去学 Caffe 吧,不要被 Python 毒害了。下面把代码贴上。其中有几个函数没有用到,是参考之前一位大神的帖子写的——他写了好多函数来测试自己的 records 有没有制作成功,就是厉害,大神就是大神。

import tensorflow as tf
import numpy as np
import os
import random
from PIL import Image

def _int64_feature(label):
    """Wrap one integer value in a ``tf.train.Feature``.

    ``label`` is placed in a one-element ``tf.train.Int64List``.
    """
    int64_list = tf.train.Int64List(value=[label])
    return tf.train.Feature(int64_list=int64_list)

def _bytes_feature(imgdir):
    """Wrap one bytes value in a ``tf.train.Feature``.

    ``imgdir`` is placed in a one-element ``tf.train.BytesList``; in this
    file the caller passes the raw image bytes.
    """
    bytes_list = tf.train.BytesList(value=[imgdir])
    return tf.train.Feature(bytes_list=bytes_list)

def float_list_feature(value):
    """Wrap a sequence of floats in a ``tf.train.Feature``.

    ``value`` is handed unchanged to ``tf.train.FloatList(value=...)``,
    so it should already be a sequence rather than a single number.
    """
    float_list = tf.train.FloatList(value=value)
    return tf.train.Feature(float_list=float_list)

def get_example_nums(tf_records_filenames):
    """Return the number of records stored in a TFRecords file.

    The file at ``tf_records_filenames`` is walked with
    ``tf.python_io.tf_record_iterator`` and every yielded record is
    counted; the record contents are not inspected.
    """
    record_iter = tf.python_io.tf_record_iterator(tf_records_filenames)
    return sum(1 for _ in record_iter)

def get_example_num(records_file_dir):
    """Return the number of records stored in a TFRecords file.

    Same counting logic as ``get_example_nums`` but reached through the
    ``tf.io.tf_record_iterator`` alias.
    """
    record_iter = tf.io.tf_record_iterator(records_file_dir)
    return sum(1 for _ in record_iter)

def load_file(imagestxtdir,shuffle=False):
    """Read an image list file and return image names with their labels.

    Every non-blank line is expected to contain an image name followed by
    at least one integer label, separated by whitespace, for example
    ``cat/001.jpg 3``. Only the first label on a line is used.

    Args:
        imagestxtdir: Path of the text file that lists images and labels.
        shuffle: When true, the lines are shuffled with ``random.shuffle``
            before they are parsed.

    Returns:
        A tuple ``(images, labels)``. ``images`` is the list of image names
        exactly as written in the file; ``labels`` is a parallel list whose
        items are one-element lists ``[int_label]``.

    Raises:
        IndexError: If a non-blank line has no label field.
        ValueError: If the label field is not an integer.
    """
    with open(imagestxtdir) as f:
        lines_list = f.readlines()
    if shuffle:
        # Shuffle whole lines so each image stays paired with its label.
        random.shuffle(lines_list)

    images = []  # image names, one per valid line
    labels = []  # [label] lists, parallel to `images`
    for line in lines_list:
        # split() with no argument tolerates tabs, repeated spaces and
        # leading/trailing whitespace, and yields [] for blank lines.
        fields = line.split()
        if not fields:
            continue  # blank line, e.g. an extra newline at the end of the file
        label = [int(fields[1])]
        images.append(fields[0])
        labels.append(label)
    return images,labels

def get_batch_images(images,labels,batch_size,labels_num,one_hot=False,shuffle=False,num_threads=1):
    """Group image and label tensors into batches.

    Args:
        images: Image tensor passed to the TensorFlow batching op.
        labels: Label tensor passed alongside ``images``.
        batch_size: Number of examples per batch.
        labels_num: Depth used for the one-hot encoding of the labels.
        one_hot: When true, the label batch is one-hot encoded with
            on-value 1 and off-value 0.
        shuffle: When true, ``tf.train.shuffle_batch`` is used; otherwise
            ``tf.train.batch``.
        num_threads: Thread count forwarded to the batching op.

    Returns:
        A tuple ``(images_batch, labels_batch)``.
    """
    min_after_dequeue = 200
    # Queue capacity: the shuffle floor plus room for three more batches.
    capacity = min_after_dequeue + 3 * batch_size
    tensors = [images, labels]

    if not shuffle:
        images_batch, labels_batch = tf.train.batch(
            tensors,
            batch_size=batch_size,
            num_threads=num_threads,
            capacity=capacity,
        )
    else:
        images_batch, labels_batch = tf.train.shuffle_batch(
            tensors,
            batch_size=batch_size,
            capacity=capacity,
            min_after_dequeue=min_after_dequeue,
            num_threads=num_threads,
        )

    if one_hot:
        labels_batch = tf.one_hot(labels_batch, depth=labels_num, on_value=1, off_value=0)
    return images_batch, labels_batch


def create_tf_records(image_base_dir,image_txt_dir,tfrecords_dir,resise_height,resize_weight,shuffle,log=5):
    """Write the images listed in a text file into a single TFRecords file.

    Each listed image is opened, resized, and stored as a ``tf.train.Example``
    with two features: ``image_raw`` (the raw bytes of the resized image)
    and ``label`` (the first integer label of the line). Listed images that
    do not exist on disk are reported and skipped.

    Args:
        image_base_dir: Directory that the image names in the list file are
            relative to.
        image_txt_dir: Path of the list file, parsed by ``load_file``.
        tfrecords_dir: Path of the TFRecords file to create.
        resise_height: Target image height in pixels (parameter name kept
            as-is for backward compatibility).
        resize_weight: Target image width in pixels (parameter name kept
            as-is for backward compatibility).
        shuffle: Forwarded to ``load_file``; shuffles the list when true.
        log: Print a progress line every ``log`` images and for the last one.
    """
    images_list,labels_list=load_file(image_txt_dir,shuffle)
    last_index=len(images_list)-1
    # The context manager closes the writer even if an image fails to load.
    with tf.io.TFRecordWriter(tfrecords_dir) as writer:
        for i,(image_name,single_label_list) in enumerate(zip(images_list,labels_list)):
            cur_image_dir=image_base_dir+'/'+image_name
            if not os.path.exists(cur_image_dir):
                print('the image path is not exists')
                continue
            # Closing the source image releases its file handle promptly.
            with Image.open(cur_image_dir) as image:
                # PIL's resize expects (width, height), not (height, width).
                # NOTE(review): the byte layout depends on the image mode
                # (RGB / L / RGBA ...); no mode conversion is done here, so
                # confirm the reader expects what the dataset contains.
                image_raw=image.resize((resize_weight,resise_height)).tobytes()
            single_label=single_label_list[0]
            if i % log == 0 or i == last_index:
                print('------------processing:%d-th------------' % (i))
            example=tf.train.Example(features=tf.train.Features(feature={
                'image_raw':_bytes_feature(image_raw),
                'label':_int64_feature(single_label)
            }))
            writer.write(example.SerializeToString())




if __name__=='__main__':
    # Build one TFRecords file per dataset split, then read each file back
    # and print how many records it holds.
    resize_height=224
    resize_width=224
    shuffle=True
    log=5

    # One row per split: (image directory, list file, output file, report message).
    splits=(
        ('D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/train',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/train.txt',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/train.tfrecords',
         'the train records number is:'),
        ('D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/validation',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/validation.txt',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/validation.tfrecords',
         'the validation records number is:'),
        ('D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/test',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/test.txt',
         'D:/软件/pycharmProject/wenyuPy/Dataset/VGG16/record/test.tfrecords',
         'the test records number is:'),
    )

    for image_dir,txt_dir,records_dir,message in splits:
        create_tf_records(image_dir,txt_dir,records_dir,resize_height,resize_width,shuffle,log)
        print(message,get_example_nums(records_dir))

这是我自己电脑上的环境,只有五类图像。如果你想做更多类也可以,道理都一样,改一下路径就行。注释懒得写了,因为代码写得比较简单,哈哈哈。想转载的话随便转,但真正想学的人还是得自己敲一遍。不过我的博客写得很一般,估计应该没有人看。

原文地址:https://www.cnblogs.com/daremosiranaihana/p/11429560.html