Python + Caffe:在师兄代码的基础上改写成自己风格的代码

首先,感谢师兄的帮助。师兄的代码封装成类,结构清晰、简洁优雅,也容易调试;而我的代码是一段段堆砌起来的,被师兄笑称是在“写脚本”。好吧,我的代码只有我自己看得懂,哈哈!希望以后能把代码写得更工整些,现在还是先把原理弄懂。这里我做了一个简单的任务:对数字 0、1、2 进行三分类,测试集上的准确率为 0.9806666666666667。

代码(部分)分为以下几个文件:

1 train_net.py

 1 #import some module
 2 import time
 3 import os
 4 import numpy as np
 5 import sys
 6 import cv2
 7 sys.path.append("/home/wang/Downloads/caffe-master/python")
 8 import caffe
 9 #from prepare_data import DataConfig
10 #from data_config import DataConfig
11 
12 #configure GPU mode
13 ''' uncommend below line to use gpu '''
14 caffe.set_mode_gpu()
15 
16 # about dataset
17 ##dataset = Dataset('/home/wang/Downloads/object/extract/')
18 ##dataset = dataset.Split('train')
19 ##data_config = DataConfig(dataset)
20 ##data_config.SetBatchSize(256)
21 data_config='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/'
22 
23 
24 
25 #configure solve.prototxt
26 solver = caffe.SGDSolver('models/solver.prototxt')
27 
28 # load pretrain model
29 print('load pretrain model')
30 solver.net.copy_from('models/bvlc_reference_caffenet.caffemodel')
31 
32 solver.net.layers[0].SetDataConfig(data_config)
33 
34 for i in range(1, 10000):
35     # Make one SGD update
36     solver.step(5)
37     if i % 100 == 0:
38         solver.net.save('tmp.caffemodel')
39         ''' TODO:  test code '''  

2 test_net.py

 1 #import setup
 2 import time
 3 import os
 4 import random
 5 import sys
 6 sys.path.append("/home/wang/Downloads/caffe-master/python")
 7 import caffe
 8 import cv2
 9 import numpy as np
10 import random
11 
12 
13 from utils import PrepareImage
14 #from dataset import Dataset
15 from test_data import test_data_pre 
16 
17 test_num_once=10
18 
19 
20 ''' uncommend below line to use gpu '''
21 # caffe.set_mode_gpu()
22 
23 # dataset
24 #dataset = Dataset('/home/wang/Downloads/object/extract/')
25 #dataset = dataset.Split('test')
26 
27 # load net
28 net = caffe.Net('models/deploy.prototxt', caffe.TEST)
29 
30 
31 # load train model
32 print('load pretrain model')
33 net.copy_from('tmp.caffemodel')
34 
35 #test all samples one by one
36 data_pre='/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/test/'
37 #(imgPaths, gt_label) = dataset[int(random.random()*num_obj)]
38 (imgPaths, gt_label)=test_data_pre(data_pre) 
39 num_img = len(imgPaths)
40 correct_num=0
41 for idx in range(num_img):
42     img = cv2.imread(imgPaths[idx])
43     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
44     tmp_img = img.copy() # for display
45     img = PrepareImage(img, (227, 227))
46     net.blobs['data'].reshape(test_num_once, 3, 227, 227)
47     net.blobs['data'].data[...] = img
48     #net.blobs['data'].data[i,:,:,:] = img
49     net.forward()
50     score = net.blobs['cls_prob'].data
51     if score.argmax()==gt_label[idx]:
52         correct_num=correct_num+1
53     if idx%100==0:
54         print("Please wait some minutes...")
55 correct_rate=correct_num*1.0/num_img
56 print('The correct rate is :',correct_rate)
57 
58 
59     

3 test_data.py

 1 import os
 2 import numpy as np
 3 from random import randint
 4 import cv2
 5 from utils import PrepareImage,CatImage
 6 #class data:
 7 #path should be /home/
 8 def test_data_pre(path):
 9     img_list=[]
10     image_num=len(os.listdir(path+'/0'))+len(os.listdir(path+'/1'))+len(os.listdir(path+'/2'))  
11     label = np.zeros(image_num, dtype=np.float32)  
12 
13     i=0
14     for idf in range(3): 
15         idf_str=str(idf)
16         path1=path+idf_str
17         tmp_path=os.listdir(path1)
18         for idi in range(len(tmp_path)):   
19             img_path=path1+'/'+tmp_path[idi] 
20             img_list.append(img_path)
21             label[i]=idf
22             i=i+1
23     return ( img_list,label)

4 pre_data.py

 1 import os
 2 import numpy as np
 3 from random import randint
 4 import cv2
 5 from utils import PrepareImage,CatImage
 6 #class data:
 7 #path should be /home/
 8 def prepare_data(path,batchsize):
 9     #tmp_path=os.listdir(path)
10     img_list=[]
11     label = np.zeros(batchsize, dtype=np.float32)
12     for i in range(batchsize): 
13         #randomly select one file
14         idf=randint(0,2)
15         idf_str=str(idf)
16         path1=path+idf_str
17         tmp_path=os.listdir(path1)
18         
19         #randomly select one image    
20         idi=randint(0,len(tmp_path)-1)
21         #img = cv2.imread(imgPaths[idx])
22         img_path=path1+'/'+tmp_path[idi]
23         img=cv2.imread(img_path)
24 
25         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
26         flip = randint(0, 1)>0
27         if flip > 0:
28             img = img[:, ::-1, :] # flip left to right
29  
30         img=PrepareImage(img, (227,227))
31         img_list.append(img)
32         label[i]=idf
33     imgData = CatImage(img_list)
34     return (imgData,label)

5 utils.py

 1 import os
 2 import cv2
 3 import numpy as np
 4 
 5 def PrepareImage(im, size):
 6     im = cv2.resize(im, (size[0], size[1]))
 7     im = im.transpose(2, 0, 1)
 8     im = im.astype(np.float32, copy=False)
 9     return im
10 
def CatImage(im_list):
    """Pack channel-first images into a single float32 batch array.

    The result has shape (len(im_list), 3, H_max, W_max), where H_max and
    W_max are the largest height and width in the list. Every slot is first
    filled with a per-channel constant (102.9801, 115.9465, 122.7717). These
    look like BGR ImageNet pixel means (presumably chosen as neutral padding).
    Each image is then copied into the top-left corner of its slot.

    The image pixels are copied unchanged; no mean is subtracted.
    """
    largest = np.array([img.shape for img in im_list]).max(axis=0)
    n_imgs = len(im_list)
    pad_value = np.array([102.9801, 115.9465, 122.7717], dtype=np.float32)
    batch = np.empty((n_imgs, 3, largest[1], largest[2]), dtype=np.float32)
    batch[...] = pad_value.reshape(1, 3, 1, 1)
    for slot in range(n_imgs):
        img = im_list[slot]
        height, width = img.shape[1], img.shape[2]
        batch[slot, :, :height, :width] = img
    return batch

6 layer/data_layer.py

 1 import caffe
 2 import numpy as np
 3 
 4 #import data_config
 5 #import prepare_data
 6 from pre_data import prepare_data
 7 
 8 class DataLayer(caffe.Layer):
 9 
10     def SetDataConfig(self, data_config):
11         self._data_config = data_config
12 
13     def GetDataConfig(self):
14         return self._data_config
15 
16     def setup(self, bottom, top):
17         # data blob
18         top[0].reshape(1, 3, 227, 227)
19         #top[0].reshape(1, 3, 34, 44)
20         # label type
21         top[1].reshape(1, 1)
22 
23     def reshape(self, bootom, top):
24         pass
25 
26     def forward(self, bottom, top):
27         #(imgs, label) = self._data_config.next()
28         path=self.GetDataConfig()
29         (imgs,label)=prepare_data(path,128)
30         (N, C, W, H) = imgs.shape
31         # image data
32         top[0].reshape(N, C, W, H)
33         top[0].data[...] = imgs
34         # object type label
35         top[1].reshape(N)
36         top[1].data[...] = label
37 
38     def backward(self, top, propagate_down, bottom):
39         pass

7 layer/__init__.py

import data_layer

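还有一些 Caffe 中的常规文件没有放进来,例如代码里用到的 models/solver.prototxt、models/deploy.prototxt 以及预训练模型 models/bvlc_reference_caffenet.caffemodel。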

代码和数据:请参见下方原文地址。

原文地址:https://www.cnblogs.com/Wanggcong/p/5169737.html