window环境pycharm中使用cityscapes数据集训练deeplabv3+经验总结

由于现在的教程大都是linux环境下deeplabv3+的实现,并且很多都是使用的voc数据集,因此本人在windows中使用cityscapes数据集训练deeplabv3+的过程中遇到了很多问题,查阅了很多前辈和大佬的博文才能够实现,在此我对整个训练过程中遇到的问题进行了整理。由于问题较多,没有分先问题出现的先后问题。

在最开始下载deeplab源码时,选择不同branch可能会导致不同问题,参见Issue #6567。本人最开始选择了master branch,最后出现 eval.py  vis.py 不出结果,及tensorflow:Waiting for new checkpoint at ...问题,切换r1.12.0 branch可用。


在此首先提醒各位读者,在实现deeplab之前一定要先阅读官方提供的文档以及文件内容,从官网下载下来的源代码中默认的大都是基于voc数据集的,并且不能像在linux中使用命令直接设置参数,因此很多参数需要我们手动去修改,否则我们会走很多弯路,遇到各式各样的问题。

官方提供的cityscapes数据集训练教程:https://github.com/tensorflow/models/blob/r1.12.0/research/deeplab/g3doc/cityscapes.md

参数设置:

1.train.py

其中 model_variant 在common文件中,将其修改为 xception_65 的同时将 decoder_output_stride 设置为4。

在训练时,batch_size和 crop_size 要根据自己的电脑显存而定,由于本人机子较为落后,2g的独显,因此将 train_batch_size 设置为1,fine_tune_batch_norm 设置为falsetrain_crop_size设置为[321,321],其中train_crop_size最小为321。如果仅为测试,training_number_of_steps 可以设置小一点,比如1000,否则会训练很长时间。

tf_initial_checkpoint 为预训练模型路径,可在 https://github.com/tensorflow/models/blob/r1.12.0/research/deeplab/g3doc/model_zoo.md 中下载,大小为439兆。

train_logdir 为检查点保存路径,使用官方提供的目录结构可保存在 cityscapes/exp/train_on_train_set/train 目录中。

使用 xception_65 将 output_stride 设置为16,atrous_rates 设置为 [6, 12, 18]。

dataset 改为 cityscapes。dataset_dir 为读取数据集的路径,及tfrecord保存路径。

2.eval.py 和 vis.py

这两个文件中的大部分参数和train.py保持一致,个别参数在下方作出说明:

checkpoint_dir 为检查点的路径,及train.py中的 cityscapes/exp/train_on_train_set/train 目录。

eval_logdir 和 vis_logdir 为写入评估事件的目录,分别保存在 cityscapes/exp/train_on_train_set/eval 和 cityscapes/exp/train_on_train_set/vis 中。

eval_crop_size 和 vis_crop_size 设置为读入图片的大小,cityscapes数据集为[1025,2049]。


其他问题及解决方法:

问题:ModuleNotFoundError: No module named 'nets'No module named 'deployment'

在运行model_test和train时会现,这两个文件在models/research文件夹下,将其添加到环境中即可。或者直接将其中用到的文件复制到外部库中。

问题:InvalidArgumentError (see above for traceback): padded_shape[0]=49 is not divisible by block_shape[0]=2

官方默认给的crop_size为[1025,2049]为测试的原图片的大小,如果将其更改可能会出现此问题。

问题:data split name train not recognized

此问题出现在master分支中,出现的原因为代码中已经没有“train”这个变量,而是train_fine,后面的eval和vis同理。此时需要把生成的tfrecord文件名修改一下

 改为  

如果使用r1.12.0分支则没有此问题。

问题:OOM when allocating tensor with shape ... and type ...

出现原因:显卡内存不够,可将batch_size或crop_size调小

问题:lhs shape= [1,1,512,256] rhs shape= [1,1,1280,256]

出现原因可能是由于export_model中atrous_rates参数没有设置

问题:tensorflow:Waiting for new checkpoint at ...

master分支下运行eval和vis出现的问题,具体原因不清楚,可使用r1.12.0分支下的源代码


使用导出的模型进行测试:

其中官方给出了deeplab_demo.ipynb,大家可以将其转换为 py 文件,或从网上直接查询其 py 源代码,将其中图片路径和模型路径修改为自己本地的存储目录。并将其中类别和颜色修改为cityscapes数据集的。具体可参考 https://blog.csdn.net/zz2230633069/article/details/84591532

修改后的文件:

  1 # -*- coding: utf-8 -*-
  2 import os
  3 
  4 from matplotlib import gridspec
  5 from matplotlib import pyplot as plt
  6 import numpy as np
  7 from PIL import Image
  8 
  9 import tensorflow as tf
 10 from tensorflow import ConfigProto
 11 from tensorflow import InteractiveSession
 12 
 13 config = ConfigProto()
 14 config.gpu_options.allow_growth = True
 15 session = InteractiveSession(config=config)
 16 
 17 
 18 #这个地方指定输出的模型路径
 19 TEST_PB_PATH    = './output_model/frozen_inference_graph.pb'
 20 
 21 #这个地方指定需要测试的图片
 22 TEST_IMAGE_PATH = "./image/1.jpg"
 23 
 24 
 25 class DeepLabModel(object):
 26   INPUT_TENSOR_NAME  = 'ImageTensor:0'
 27   OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
 28   INPUT_SIZE         = 513
 29   FROZEN_GRAPH_NAME  = 'frozen_inference_graph'
 30 
 31   def __init__(self):
 32     self.graph = tf.Graph()
 33 
 34     graph_def = None
 35 
 36     with open(TEST_PB_PATH, 'rb') as fhandle:
 37         graph_def = tf.GraphDef.FromString(fhandle.read())
 38 
 39     if graph_def is None:
 40       raise RuntimeError('Cannot find inference graph in tar archive.')
 41 
 42     with self.graph.as_default():
 43       tf.import_graph_def(graph_def, name='')
 44 
 45     self.sess = tf.Session(graph=self.graph)
 46 
 47   def run(self, image):
 48     width, height = image.size
 49     resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
 50     target_size = (int(resize_ratio * width), int(resize_ratio * height))
 51     resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
 52     batch_seg_map = self.sess.run(
 53         self.OUTPUT_TENSOR_NAME,
 54         feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
 55     seg_map = batch_seg_map[0]
 56     return resized_image, seg_map
 57 
 58 
 59 def create_pascal_label_colormap():
 60   return np.asarray([
 61     [128, 64, 128],
 62     [244, 35, 232],
 63     [70, 70, 70],
 64     [102, 102, 156],
 65     [190, 153, 153],
 66     [153, 153, 153],
 67     [250, 170, 30],
 68     [220, 220, 0],
 69     [107, 142, 35],
 70     [152, 251, 152],
 71     [70, 130, 180],
 72     [220, 20, 60],
 73     [255, 0, 0],
 74     [0, 0, 142],
 75     [0, 0, 70],
 76     [0, 60, 100],
 77     [0, 80, 100],
 78     [0, 0, 230],
 79     [119, 11, 32],
 80   ])
 81 
 82 
 83 def label_to_color_image(label):
 84   if label.ndim != 2:
 85     raise ValueError('Expect 2-D input label')
 86 
 87   colormap = create_pascal_label_colormap()
 88 
 89   if np.max(label) >= len(colormap):
 90     raise ValueError('label value too large.')
 91 
 92   return colormap[label]
 93 
 94 
 95 def vis_segmentation(image, seg_map):
 96   plt.figure(figsize=(15, 5))
 97   grid_spec = gridspec.GridSpec(1, 4, width_ratios=[6, 6, 6, 1])
 98 
 99   plt.subplot(grid_spec[0])
100   plt.imshow(image)
101   plt.axis('off')
102   plt.title('input image')
103 
104   plt.subplot(grid_spec[1])
105   seg_image = label_to_color_image(seg_map).astype(np.uint8)
106   plt.imshow(seg_image)
107   plt.axis('off')
108   plt.title('segmentation map')
109 
110   plt.subplot(grid_spec[2])
111   plt.imshow(image)
112   plt.imshow(seg_image, alpha=0.7)
113   plt.axis('off')
114   plt.title('segmentation overlay')
115 
116   unique_labels = np.unique(seg_map)
117   ax = plt.subplot(grid_spec[3])
118   plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation='nearest')
119   ax.yaxis.tick_right()
120   plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
121   plt.xticks([], [])
122   ax.tick_params(width=0.0)
123   plt.grid('off')
124   plt.show()
125 
126 
127 LABEL_NAMES = np.asarray([
128     'road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'traffic light',
129     'traffic sign', 'vegetation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck',
130     'bus', 'train', 'motorcycle', 'bicycle'
131 ])
132 
133 FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
134 FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
135 
136 
137 MODEL = DeepLabModel()
138 print('model loaded successfully!')
139 
140 
141 def run_visualization(path):
142     oringnal_im = Image.open(path)
143     print('running deeplab on image %s...' % path)
144     resized_im, seg_map = MODEL.run(oringnal_im)
145     vis_segmentation(resized_im, seg_map)
146 
147 run_visualization(TEST_IMAGE_PATH)
deeplab_demo.py测试代码
原文地址:https://www.cnblogs.com/yuanxiaochou/p/12800552.html