首页 > 编程语言 > 详细

python caffe 在师兄的代码上修改成自己风格的代码

时间:2016-01-29 20:41:37      阅读:492      评论:0      收藏:0      [点我收藏+]

师兄的代码封装成类,流畅精美,容易调试;我的代码则是一段段堆起来的,被师兄嘲笑是在"写脚本"。好吧!我的代码只有我自己懂,哈哈!希望以后代码能写得更工整些,但现在还是先让自己看懂。这里我做了一个简单的任务:对 0、1、2 三个数字进行分类。

(部分)代码分为:

1 train_net.py

 1 #import some module
 2 import time
 3 import os
 4 import numpy as np
 5 import sys
 6 import cv2
 7 sys.path.append("/home/wang/Downloads/caffe-master/python")
 8 import caffe
 9 #from prepare_data import DataConfig
10 #from data_config import DataConfig
11 
12 #configure GPU mode
13 ‘‘‘ uncommend below line to use gpu ‘‘‘
14 caffe.set_mode_gpu()
15 
16 # about dataset
17 ##dataset = Dataset(‘/home/wang/Downloads/object/extract/‘)
18 ##dataset = dataset.Split(‘train‘)
19 ##data_config = DataConfig(dataset)
20 ##data_config.SetBatchSize(256)
21 data_config=/home/wang/Downloads/caffe-master/examples/myFig_recognition/data/train/
22 
23 
24 
25 #configure solve.prototxt
26 solver = caffe.SGDSolver(models/solver.prototxt)
27 
28 # load pretrain model
29 print(load pretrain model)
30 solver.net.copy_from(models/bvlc_reference_caffenet.caffemodel)
31 
32 solver.net.layers[0].SetDataConfig(data_config)
33 
34 for i in range(1, 10000):
35     # Make one SGD update
36     solver.step(5)
37     if i % 100 == 0:
38         solver.net.save(tmp.caffemodel)
39         ‘‘‘ TODO:  test code ‘‘‘  

2 test_net.py(还没写)

3 pre_data.py

 1 import os
 2 import numpy as np
 3 from random import randint
 4 import cv2
 5 from utils import PrepareImage,CatImage
 6 #class data:
 7 #path should be /home/
 8 def prepare_data(path,batchsize):
 9     #tmp_path=os.listdir(path)
10     img_list=[]
11     label = np.zeros(batchsize, dtype=np.float32)
12     for i in range(batchsize): 
13         #randomly select one file
14         idf=randint(0,2)
15         idf_str=str(idf)
16         path1=path+idf_str
17         tmp_path=os.listdir(path1)
18         
19         #randomly select one image    
20         idi=randint(0,len(tmp_path)-1)
21         #img = cv2.imread(imgPaths[idx])
22         img_path=path1+/+tmp_path[idi]
23         img=cv2.imread(img_path)
24 
25         img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
26         flip = randint(0, 1)>0
27         if flip > 0:
28             img = img[:, ::-1, :] # flip left to right
29  
30         img=PrepareImage(img, (227,227))
31         img_list.append(img)
32         label[i]=idf
33     imgData = CatImage(img_list)
34     return (imgData,label)

4 utils.py

 1 import os
 2 import cv2
 3 import numpy as np
 4 
 5 def PrepareImage(im, size):
 6     im = cv2.resize(im, (size[0], size[1]))
 7     im = im.transpose(2, 0, 1)
 8     im = im.astype(np.float32, copy=False)
 9     return im
10 
def CatImage(im_list, pixel_means=(102.9801, 115.9465, 122.7717)):
    """Stack channel-first images into one float32 batch blob.

    The blob is large enough for the biggest height and width found in
    ``im_list``. Each image is copied into the top-left corner of its slot,
    and every uncovered pixel keeps the per-channel fill value from
    ``pixel_means``. The defaults are the values this function always used.

    Args:
        im_list: non-empty iterable of arrays shaped (C, H, W), where
            C == len(pixel_means).
        pixel_means: per-channel value used to fill padding.

    Returns:
        float32 array of shape (len(im_list), C, max_H, max_W).

    Raises:
        ValueError: if ``im_list`` is empty.
    """
    # Materialise once so iterators are not exhausted by the shape scan below.
    im_list = list(im_list)
    if len(im_list) == 0:
        raise ValueError("CatImage needs at least one image")

    means = np.asarray(pixel_means, dtype=np.float32)
    max_shape = np.array([im.shape for im in im_list]).max(axis=0)
    blob = np.empty((len(im_list), means.size, max_shape[1], max_shape[2]),
                    dtype=np.float32)
    # Set every pixel to the per-channel mean value so padded areas hold it.
    blob[...] = means[:, None, None]
    for i, im in enumerate(im_list):
        blob[i, :, :im.shape[1], :im.shape[2]] = im
    return blob

5 layer/data_layer.py

 1 import caffe
 2 import numpy as np
 3 
 4 #import data_config
 5 #import prepare_data
 6 from pre_data import prepare_data
 7 
 8 class DataLayer(caffe.Layer):
 9 
10     def SetDataConfig(self, data_config):
11         self._data_config = data_config
12 
13     def GetDataConfig(self):
14         return self._data_config
15 
16     def setup(self, bottom, top):
17         # data blob
18         top[0].reshape(1, 3, 227, 227)
19         #top[0].reshape(1, 3, 34, 44)
20         # label type
21         top[1].reshape(1, 1)
22 
23     def reshape(self, bootom, top):
24         pass
25 
26     def forward(self, bottom, top):
27         #(imgs, label) = self._data_config.next()
28         path=self.GetDataConfig()
29         (imgs,label)=prepare_data(path,128)
30         (N, C, W, H) = imgs.shape
31         # image data
32         top[0].reshape(N, C, W, H)
33         top[0].data[...] = imgs
34         # object type label
35         top[1].reshape(N)
36         top[1].data[...] = label
37 
38     def backward(self, top, propagate_down, bottom):
39         pass

6 layer/__init__.py

import data_layer

还有一些 Caffe 中的常规文件(如 models/solver.prototxt 等配置文件)没有放进来。

数据:http://pan.baidu.com/s/1mgYQa6G(尚未分训练集和测试集)

 

python caffe 在师兄的代码上修改成自己风格的代码

原文:http://www.cnblogs.com/Wanggcong/p/5169737.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!