如何用Caffe训练自己的网络-探索与试验

现在一直都是用Caffe在跑别人写好的网络,如何运行自定义的网络和图片,是接下来要学习的一点。

1. 使用Caffe中自带的网络模型来运行自己的数据集

参考 [1] :http://www.cnblogs.com/denny402/p/5083300.html,下面几乎是全文转载,有部分对自己踩过的坑的补充,向原作者致敬!

一、准备数据

我去网上找了一些其它的图片来代替,共有500张图片,分为大巴车、恐龙、大象、鲜花和马五个类,每个类100张。需要的同学,可到我的网盘下载:http://pan.baidu.com/s/1nuqlTnN

编号分别以3,4,5,6,7开头,各为一类。我从其中每类选出20张作为测试,其余80张作为训练。因此最终训练图片400张,测试图片100张,共5类。我将图片放在caffe根目录下的data文件夹下面。即训练图片目录:data/re/train/ ,测试图片目录: data/re/test/

二、转换为lmdb格式

具体的转换过程,可参见我的前一篇博文:Caffe学习系列(11):图像数据转换成db(leveldb/lmdb)文件

首先,在examples下面创建一个myfile的文件夹,来用存放配置文件和脚本文件。然后编写一个脚本create_filelist.sh,用来生成train.txt和test.txt清单文件

# mkdir examples/myfile
# vi examples/myfile/create_filelist.sh

编辑此文件,写入如下代码,并保存

复制代码
#!/usr/bin/env sh
DATA=data/re/
MY=examples/myfile
echo "Create train.txt..." rm -rf $MY/train.txt for i in 3 4 5 6 7 do find $DATA/train -name $i*.jpg | cut -d '/' -f4-5 | sed "s/$/ $i/">>$MY/train.txt done echo "Create test.txt..." rm -rf $MY/test.txt for i in 3 4 5 6 7 do find $DATA/test -name $i*.jpg | cut -d '/' -f4-5 | sed "s/$/ $i/">>$MY/test.txt done echo "All done"
复制代码

然后,运行此脚本

# sh examples/myfile/create_filelist.sh

成功的话,就会在examples/myfile/ 文件夹下生成train.txt和test.txt两个文本文件,里面就是图片的列表清单。

接着再编写一个脚本文件,调用convert_imageset命令来转换数据格式。

# vi examples/myfile/create_lmdb.sh

插入并保存:(注意修改Caffe的绝对路径)

复制代码
#!/usr/bin/env sh
MY=examples/myfile

echo "Create train lmdb.."
rm -rf $MY/img_train_lmdb
build/tools/convert_imageset 
--shuffle 
--resize_height=256 
--resize_width=256 
/home/xxx/caffe/data/re/ 
$MY/train.txt 
$MY/img_train_lmdb

echo "Create test lmdb.."
rm -rf $MY/img_test_lmdb
build/tools/convert_imageset 
--shuffle 
--resize_width=256 
--resize_height=256 
/home/xxx/caffe/data/re/ 
$MY/test.txt 
$MY/img_test_lmdb

echo "All Done.."
复制代码

因为图片大小不一,因此我统一转换成256*256大小。

运行脚本:

# sh examples/myfile/create_lmdb.sh

运行成功后,会在 examples/myfile下面生成两个文件夹img_train_lmdb和img_test_lmdb,分别用于保存图片转换后的lmdb文件。

三、计算均值并保存

图片减去均值再训练,会提高训练速度和精度。因此,一般都会有这个操作。

caffe程序提供了一个计算均值的文件compute_image_mean.cpp,我们直接使用就可以了

# build/tools/compute_image_mean examples/myfile/img_train_lmdb examples/myfile/mean.binaryproto
compute_image_mean带两个参数,第一个参数是lmdb训练数据位置,第二个参数设定均值文件的名字及保存路径。
运行成功后,会在 examples/myfile/ 下面生成一个mean.binaryproto的均值文件。

四、创建模型并编写配置文件

模型就用程序自带的caffenet模型,位置在 models/bvlc_reference_caffenet/文件夹下, 将需要的两个配置文件,复制到myfile文件夹内

# cp models/bvlc_reference_caffenet/solver.prototxt examples/myfile/
# cp models/bvlc_reference_caffenet/train_val.prototxt examples/myfile/

修改其中的solver.prototxt

# sudo vi examples/myfile/solver.prototxt
 
复制代码
net: "examples/myfile/train_val.prototxt"
test_iter: 2
test_interval: 50
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 100
display: 20
max_iter: 500
momentum: 0.9
weight_decay: 0.005
solver_mode: GPU
复制代码

100个测试数据,batch_size为50,因此test_iter设置为2,就能全cover了。在训练过程中,调整学习率,逐步变小。

修改train_val.protxt,只需要修改两个阶段的data层就可以了,其它可以不用管。

复制代码
name: "CaffeNet"
layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TRAIN
  }
  transform_param {
    mirror: true
    crop_size: 227
    mean_file: "examples/myfile/mean.binaryproto"
  }
  data_param {
    source: "examples/myfile/img_train_lmdb"
    batch_size: 16 // 这里注意修改为16
    backend: LMDB
  }
}
layer {
  name: "data"
  type: "Data"
  top: "data"
  top: "label"
  include {
    phase: TEST
  }
  transform_param {
    mirror: false
    crop_size: 227
    mean_file: "examples/myfile/mean.binaryproto"
  }
  data_param {
    source: "examples/myfile/img_test_lmdb"
    batch_size: 50
    backend: LMDB
  }
}
复制代码

注意修改batch_size为16,不能是256

否则会出现Restarting data prefetching from start.一直重复出现

参考 [2] ,原因是

训练集太小了,于是我把batch_size改小了,从原来的256改成16,才OK了。

五、训练和测试

如果前面都没有问题,数据准备好了,配置文件也配置好了,这一步就比较简单了。

# sudo build/tools/caffe train -solver examples/myfile/solver.prototxt

运行时间和最后的精确度,会根据机器配置,参数设置的不同而不同。我的是gpu+cudnn运行500次,大约3分钟,精度为91%。

 总结一下,相当于用已有的网络架构在自己的训练集和测试集上进行测试,
网络架构比较难发明,所以暂时先不考虑怎么定义网络架构

接下来会探究如何定义自己的网络,
 
以及如何使用训练好的模型来判断一张即时抓取的图片属于哪一类
 
未完待续
 
 
=====================================
 

感兴趣的同学可以去研究一下TFlearning和Pytorch,比Caffe以及TF更容易学懂(两者都是在TF的基础上的高级API)

参考文献:

[1] http://www.cnblogs.com/denny402/p/5083300.html

[2] http://blog.csdn.net/iambool/article/details/69526089

原文地址:https://www.cnblogs.com/QingHuan/p/7306720.html