读取 TFRecord 的代码（2）——任意照片均可使用

代码

  1 # -*- coding: utf-8 -*-
  2 # @Time    : 2018/12/1 11:06
  3 # @Author  : MaochengHu
  4 # @Email   : wojiaohumaocheng@gmail.com
  5 # @File    : read_tfrecord.py
  6 # @Software: PyCharm
  7 import os
  8 import tensorflow as tf
  9 flags = tf.app.flags
 10 flags.DEFINE_string('tfrecord_path', '/data1/humaoc_file/classify/data/train_tfrecord/train.record',
 11                     'path to tfrecord file')
 12 flags.DEFINE_integer('resize_height', 800, 'resize height of image')
 13 flags.DEFINE_integer('resize_width', 800, 'resize width of image')
 14 FLAG = flags.FLAGS
 15 slim = tf.contrib.slim
 16 
 17 def print_data(image, resized_image, label, height, width):
 18     with tf.Session() as sess:
 19         init_op = tf.global_variables_initializer()
 20         sess.run(init_op)
 21         coord = tf.train.Coordinator()
 22         threads = tf.train.start_queue_runners(coord=coord)
 23         for i in range(20):
 24             print("______________________image({})___________________".format(i))
 25             print_image, print_resized_image, print_label, print_height, print_width = sess.run(
 26                 [image, resized_image, label, height, width])
 27             print("resized_image shape is: ", print_resized_image.shape)
 28             print("image shape is: ", print_image.shape)
 29             print("image label is: ", print_label)
 30             print("image height is: ", print_height)
 31             print("image width is: ", print_width)
 32         coord.request_stop()
 33         coord.join(threads)
 34 
 35 def reshape_same_size(image, output_height, output_width):
 36     """Resize images by fixed sides.
 37 
 38     Args:
 39         image: A 3-D image `Tensor`.
 40         output_height: The height of the image after preprocessing.
 41         output_ The width of the image after preprocessing.
 42 
 43     Returns:
 44         resized_image: A 3-D tensor containing the resized image.
 45     """
 46     output_height = tf.convert_to_tensor(output_height, dtype=tf.int32)
 47     output_width = tf.convert_to_tensor(output_width, dtype=tf.int32)
 48 
 49     image = tf.expand_dims(image, 0)
 50     resized_image = tf.image.resize_nearest_neighbor(
 51         image, [output_height, output_width], align_corners=False)
 52     resized_image = tf.squeeze(resized_image)
 53     return resized_image
 54 
 55 def read_tfrecord(tfrecord_path, num_samples=14635, num_classes=7, resize_height=800, resize_width=800):
 56     keys_to_features = {
 57         'image/encoded': tf.FixedLenFeature([], default_value='', dtype=tf.string, ),
 58         'image/format': tf.FixedLenFeature([], default_value='jpeg', dtype=tf.string),
 59         'image/class/label': tf.FixedLenFeature([], tf.int64, default_value=0),
 60         'image/height': tf.FixedLenFeature([], tf.int64, default_value=0),
 61         'image/width': tf.FixedLenFeature([], tf.int64, default_value=0)
 62     }
 63 
 64     items_to_handlers = {
 65         'image': slim.tfexample_decoder.Image(image_key='image/encoded', format_key='image/format', channels=3),
 66         'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
 67         'height': slim.tfexample_decoder.Tensor('image/height', shape=[]),
 68         'width': slim.tfexample_decoder.Tensor('image/width', shape=[])
 69     }
 70     decoder = slim.tfexample_decoder.TFExampleDecoder(keys_to_features, items_to_handlers)
 71 
 72     labels_to_names = None
 73     items_to_descriptions = {
 74         'image': 'An image with shape image_shape.',
 75         'label': 'A single integer between 0 and 9.'}
 76 
 77     dataset = slim.dataset.Dataset(
 78         data_sources=tfrecord_path,
 79         reader=tf.TFRecordReader,
 80         decoder=decoder,
 81         num_samples=num_samples,
 82         items_to_descriptions=None,
 83         num_classes=num_classes,
 84     )
 85     provider = slim.dataset_data_provider.DatasetDataProvider(dataset=dataset,
 86                                                               num_readers=3,
 87                                                               shuffle=True,
 88                                                               common_queue_capacity=256,
 89                                                               common_queue_min=128,
 90                                                               seed=None)
 91     image, label, height, width = provider.get(['image', 'label', 'height', 'width'])
 92     resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[resize_height, resize_width]))
 93     return resized_image, label, image, height, width
 94 
 95 if __name__ == '__main__':
 96     resized_image, label, image, height, width = read_tfrecord(tfrecord_path='train.record',
 97                                                                resize_height=800,
 98                                                                resize_width=800)
 99     # resized_image = reshape_same_size(image, FLAG.resize_height, FLAG.resize_width)
100     # resized_image = tf.squeeze(tf.image.resize_bilinear([image], size=[FLAG.resize_height, FLAG.resize_width]))
101     print_data(image, resized_image, label, height, width)
102 
103     init_g = tf.global_variables_initializer()
104     init_l = tf.local_variables_initializer()
105     with tf.Session() as sess:
106         sess.run(init_g)
107         sess.run(init_l)
108         tf.train.start_queue_runners(sess)
109         print("SDDFA")
110         trX = image.eval(session=sess)
111         trY = label.eval(session=sess)
112     print("AA")
113     print(trX.shape)
原文地址：https://www.cnblogs.com/smartisn/p/12438866.html