tensorflow 相关

1. 从 checkpoint 中获取全部的变量名和变量值（不依赖 contrib 的等价接口为 tf.train.list_variables / tf.train.load_variable）

tf.contrib.framework.list_variables(model_dir)
tf.contrib.framework.load_variable(model_dir, var_name)

2. 清除默认计算图（注意：这不是关闭 tf.Session；关闭会话请用 sess.close() 或 with 语句）

tf.reset_default_graph() 重置默认计算图；在某个 tf.Session / tf.InteractiveSession 仍处于活动状态时调用它会导致未定义行为。

3. 使用 TFRecord

NumPy 数组可以直接构建 Dataset：ds = tf.data.Dataset.from_tensor_slices(trg)

一般情况下，先把数据序列化写入 TFRecord 文件，再用 TFRecordDataset 读取解析（写入时的 dtype 必须与读取时 decode_raw 使用的 dtype 一致）：

 1 a = np.random.randint(0,10,(10))
 2  2 b = np.random.rand(10,20)
 3  3 a1 = a.tobytes()
 4  4 b1 = b.tobytes()
 5  5 writer= tf.python_io.TFRecordWriter("./tfr/train.tfrecords")
 6  6 example = tf.train.Example(features=tf.train.Features(feature={
 7  7             "soft_targets": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b1])),
 8  8             'src_wids': tf.train.Feature(bytes_list=tf.train.BytesList(value=[a1]))
 9  9         }))
10 10 for _ in range(10):
11 11     writer.write(example.SerializeToString())
12 12 writer.close()
View Code
 1  def _parse_function(example_proto):
 2  2     features = tf.parse_single_example(
 3  3         example_proto,
 4  4         features={
 5  5             'src_wids': tf.FixedLenFeature([], tf.string),
 6  6             'soft_targets': tf.FixedLenFeature([], tf.string)
 7  7         }
 8  8     )
 9  9     # 取出我们需要的数据(标签,图片)
10 10     label = features['soft_targets']
11 11     feature = features['src_wids']
12 12     label = tf.decode_raw(label, tf.float32)
13 13     feature = tf.decode_raw(feature, tf.int64)
14 14     return feature, label
15 15 
16 16 dataset = tf.contrib.data.TFRecordDataset("./tfr/train.tfrecords")
17 17 dataset = dataset.map(_parse_function)
18 18 dataset = dataset.batch(2)
19 19 iterator = dataset.make_initializable_iterator()
20 20 with tf.Session() as sess:
21 21     sess.run(iterator.initializer)
22 22     ids,trg = sess.run(iterator.get_next())
View Code
  1     #Author:Anthony  
  2     #导入相应的模块  
  3     import tensorflow as tf  
  4     import os  
  5     import random  
  6     import math  
  7     import sys  
  8     #划分验证集训练集  
  9     _NUM_TEST = 40  
 10     #random seed  
 11     _RANDOM_SEED = 0  
 12     #数据块  
 13     _NUM_SHARDS = 2  
 14     #数据集路径  
 15     DATASET_DIR = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE'  
 16     #标签文件  
 17     LABELS_FILENAME = '/home/anthony/文档/数据集_带标签/SHIYAN_SAMEZIZE_labels.txt'  
 18     #定义tfrecord 的路径和名称  
 19     def _get_dataset_filename(dataset_dir,split_name,shard_id):  
 20         output_filename = 'image_%s_%05d-of-%05d.tfrecord' % (split_name,shard_id,_NUM_SHARDS)  
 21         return os.path.join(dataset_dir,output_filename)  
 22     #判断tfrecord文件是否存在  
 23     def _dataset_exists(dataset_dir):  
 24         for split_name in ['train','test']:  
 25             for shard_id in range(_NUM_SHARDS):  
 26                 #定义tfrecord的路径名字  
 27                 output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
 28             if not tf.gfile.Exists(output_filename):  
 29                 return False  
 30         return True  
 31     #获取图片以及分类  
 32     def _get_filenames_and_classes(dataset_dir):  
 33         #数据目录  
 34         directories = []  
 35         #分类名称  
 36         class_names = []  
 37         for filename in os.listdir(dataset_dir):  
 38             #合并文件路径  
 39             path = os.path.join(dataset_dir,filename)  
 40             #判断路径是否是目录  
 41             if os.path.isdir(path):  
 42                 #加入数据目录  
 43                 directories.append(path)  
 44                 #加入类别名称  
 45                 class_names.append(filename)  
 46         photo_filenames = []  
 47         #循环分类的文件夹  
 48         for directory in directories:  
 49             for filename in os.listdir(directory):  
 50                 path = os.path.join(directory,filename)  
 51                 #将图片加入图片列表中  
 52                 photo_filenames.append(path)  
 53         #返回结果  
 54         return photo_filenames ,class_names  
 55     def int64_feature(values):  
 56         if not isinstance(values,(tuple,list)):  
 57             values = [values]  
 58         return tf.train.Feature(int64_list=tf.train.Int64List(value=values))  
 59     def bytes_feature(values):  
 60         return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))  
 61     #图片转换城tfexample函数  
 62     def image_to_tfexample(image_data,image_format,class_id):  
 63         return tf.train.Example(features=tf.train.Features(feature={  
 64             'image/encoded': bytes_feature(image_data),  
 65             'image/format': bytes_feature(image_format),  
 66             'image/class/label': int64_feature(class_id)  
 67         }))  
 68     def write_label_file(labels_to_class_names,dataset_dir,filename=LABELS_FILENAME):  
 69         label_filename = os.path.join(dataset_dir,filename)  
 70         with tf.gfile.Open(label_filename,'w') as f:  
 71             for label in labels_to_class_names:  
 72                 class_name = labels_to_class_names[label]  
 73                 f.write('%d:%s
' % (label, class_name))  
 74     #数据转换城tfrecorad格式  
 75     def _convert_dataset(split_name,filenames,class_names_to_ids,dataset_dir):  
 76         assert split_name in ['train','test']  
 77         #计算每个数据块的大小  
 78         num_per_shard = int(len(filenames) / _NUM_SHARDS)  
 79         with tf.Graph().as_default():  
 80             with tf.Session() as sess:  
 81                 for shard_id in range(_NUM_SHARDS):  
 82                 #定义tfrecord的路径名字  
 83                     output_filename = _get_dataset_filename(dataset_dir,split_name,shard_id)  
 84                     with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:  
 85                         #每个数据块开始的位置  
 86                         start_ndx = shard_id * num_per_shard  
 87                         #每个数据块结束的位置  
 88                         end_ndx = min((shard_id+1) * num_per_shard,len(filenames))  
 89                         for i in range(start_ndx,end_ndx):  
 90                             try:  
 91                                 sys.stdout.write('
>> Converting image %d/%d shard %d '% (i+1,len(filenames),shard_id))  
 92                                 sys.stdout.flush()  
 93                                 #读取图片  
 94                                 image_data = tf.gfile.FastGFile(filenames[i],'rb').read()  
 95                                 #获取图片的类别名称  
 96                                 #basename获取图片路径最后一个字符串  
 97                                 #dirname是除了basename之外的前面的字符串路径r  
 98                                 class_name = os.path.basename(os.path.dirname(filenames[i]))  
 99                                 #获取图片的id  
100                                 class_id = class_names_to_ids[class_name]  
101                                 #生成tfrecord文件  
102                                 example = image_to_tfexample(image_data,b'jpg',class_id)  
103                                 #写入数据  
104                                 tfrecord_writer.write(example.SerializeToString())  
105                             except IOError  as e:  
106                                 print ('could not read:',filenames[1])  
107                                 print ('error:' , e)  
108                                 print ('skip it 
')  
109         sys.stdout.write('
')  
110         sys.stdout.flush()  
111       
112     if __name__ == '__main__':  
113         #判断tfrecord文件是否存在  
114         if _dataset_exists(DATASET_DIR):  
115             print ('tfrecord exists')  
116         else:  
117             #获取图片以及分类  
118             photo_filenames,class_names = _get_filenames_and_classes(DATASET_DIR)  
119             #将分类的list转换成dictionary{‘house':3,'flowers:2'}  
120             class_names_to_ids = dict(zip(class_names,range(len(class_names))))  
121             #切分数据为测试训练集  
122             random.seed(_RANDOM_SEED)  
123             random.shuffle(photo_filenames)  
124             training_filenames = photo_filenames[_NUM_TEST:]  
125             testing_filenames = photo_filenames[:_NUM_TEST]  
126             #数据转换  
127             _convert_dataset('train',training_filenames,class_names_to_ids,DATASET_DIR)  
128             _convert_dataset('test',testing_filenames,class_names_to_ids,DATASET_DIR)  
129             #输出lables文件  
130             #与前面的 class_names_to_ids中的元素位置相反{1:'people,2:'flowers'}  
131             labels_to_class_names = dict(zip(range(len(class_names)),class_names))  
132             write_label_file(labels_to_class_names,DATASET_DIR)
View Code

 4. tf.argmax(data, axis)

沿指定的 axis 维度获取最大元素的索引

-------------------------------------------keras--------------------------------------

1. Keras text.Tokenizer：https://blog.csdn.net/lovebyz/article/details/77712003




原文地址:https://www.cnblogs.com/wb-learn/p/11596980.html