
Commit 09a2f12

Author: Subhasis Das

python port of features matlab code
1 parent 4921991 commit 09a2f12

File tree

2 files changed: +119 -0 lines changed


python_features/README.md

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
This directory contains a Python port of the Matlab code in the matlab_features_reference/ directory.

- This code uses [Caffe](http://caffe.berkeleyvision.org/) and its Python wrapper.
- I use VGG Net, which can be found in the [Model Zoo](https://github.com/BVLC/caffe/wiki/Model-Zoo) under the title *Models used by the VGG team in ILSVRC-2014*. I use the [16-layer version](https://gist.github.com/ksimonyan/211839e770f7b538e2d8#file-readme-md).
- Note that I also provide my _features deploy network def, which is exactly what you see on that page but with the softmax chopped off to get the 4096-D codes below.
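For orientation, a hypothetical invocation of the script might look like the sketch below, written as a Python subprocess call. Only the flags come from extract_features.py; every path and filename is a placeholder.

# Hypothetical invocation of extract_features.py; all paths below are placeholders.
import subprocess

subprocess.check_call([
    'python', 'python_features/extract_features.py',
    '--caffe', '/path/to/caffe',                   # caffe root; its python/ dir gets appended to sys.path
    '--model_def', 'vgg_features_deploy.prototxt', # the "_features" deploy def with the softmax removed
    '--model', 'VGG_ILSVRC_16_layers.caffemodel',  # VGG-16 weights from the Model Zoo
    '--files', 'image_list.txt',                   # one image path per line
    '--gpu',                                       # optional; omit to run on CPU
    '--out', 'vgg_feats.pkl',                      # pickle file that will hold the feature array
])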

python_features/extract_features.py

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
import sys
import argparse

import numpy as np
from scipy.misc import imread, imresize

import cPickle as pickle

parser = argparse.ArgumentParser()
parser.add_argument('--caffe',
                    help='path to caffe installation')
parser.add_argument('--model_def',
                    help='path to model definition prototxt')
parser.add_argument('--model',
                    help='path to model parameters')
parser.add_argument('--files',
                    help='path to a file containing a list of images')
parser.add_argument('--gpu',
                    action='store_true',
                    help='whether to use the gpu')
parser.add_argument('--out',
                    help='name of the pickle file in which to store the features')

args = parser.parse_args()

caffepath = args.caffe + '/python'
sys.path.append(caffepath)

import caffe


def predict(in_data, net):
    """
    Get the features for a batch of data using the network

    Inputs:
    in_data: data batch
    """

    out = net.forward(**{net.inputs[0]: in_data})
    features = out[net.outputs[0]].squeeze(axis=(2, 3))
    return features


def batch_predict(filenames, net):
    """
    Get the features for all images from filenames using a network

    Inputs:
    filenames: a list of names of image files

    Returns:
    an array of feature vectors for the images in that file
    """

    N, C, H, W = net.blobs[net.inputs[0]].data.shape
    F = net.blobs[net.outputs[0]].data.shape[1]
    Nf = len(filenames)
    Hi, Wi, _ = imread(filenames[0]).shape  # dimensions of the first image
    allftrs = np.zeros((Nf, F))
    for i in range(0, Nf, N):
        in_data = np.zeros((N, C, H, W), dtype=np.float32)

        batch_range = range(i, min(i + N, Nf))
        batch_filenames = [filenames[j] for j in batch_range]
        Nb = len(batch_range)

        batch_images = np.zeros((Nb, 3, H, W))
        for j, fname in enumerate(batch_filenames):
            im = imread(fname)
            # replicate grayscale images across 3 channels
            if len(im.shape) == 2:
                im = np.tile(im[:, :, np.newaxis], (1, 1, 3))
            # RGB -> BGR
            im = im[:, :, (2, 1, 0)]
            # mean subtraction
            im = im - np.array([103.939, 116.779, 123.68])
            # resize
            im = imresize(im, (H, W))
            # get channel in correct dimension
            im = np.transpose(im, (2, 0, 1))
            batch_images[j, :, :, :] = im

        # insert into correct place
        in_data[0:len(batch_range), :, :, :] = batch_images

        # predict features
        ftrs = predict(in_data, net)

        for j in range(len(batch_range)):
            allftrs[i + j, :] = ftrs[j, :]

        print 'Done %d/%d files' % (i + len(batch_range), len(filenames))

    return allftrs


if args.gpu:
    caffe.set_mode_gpu()
else:
    caffe.set_mode_cpu()

net = caffe.Net(args.model_def, args.model)
caffe.set_phase_test()

filenames = []
with open(args.files) as fp:
    for line in fp:
        filename = line.strip().split()[0]
        filenames.append(filename)

allftrs = batch_predict(filenames, net)

# store the features in a pickle file
with open(args.out, 'w') as fp:
    pickle.dump(allftrs, fp)
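Assuming the script ran to completion, the saved features could be read back with a sketch like the following (the filename vgg_feats.pkl is a placeholder for whatever was passed via --out):

# Minimal sketch of loading the saved features (Python 2, matching the script above).
import cPickle as pickle

with open('vgg_feats.pkl') as fp:  # placeholder name passed via --out
    allftrs = pickle.load(fp)

# allftrs is a numpy array of shape (num_images, F), where F is the size of the
# network's output blob -- 4096 for the VGG-16 features deploy def described in the README.
print allftrs.shape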
