-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path extractor.py
More file actions
67 lines (57 loc) · 2.25 KB
/
extractor.py
File metadata and controls
67 lines (57 loc) · 2.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import tensorflow as tf
import numpy as np
import csv
import pickle
from utils import ImageLoader
# Location of the training images, relative to the working directory.
# NOTE(review): nothing in this file reads it; presumably kept for other
# modules that import it — confirm before removing.
data_dir = "../image_captioning/train/images"
class ImageFeatureExtractor(object):
    """Extract CNN image features with a frozen TensorFlow (1.x) graph.

    The serialized graph is imported into the current default graph, and a
    single session is kept open and reused by every ``extract_features`` call.
    """

    def __init__(self, model_path):
        """Load a frozen TensorFlow CNN model.

        Args:
          model_path: path to a serialized ``GraphDef`` (``.pb``) file.

        Raises:
          IOError: if ``model_path`` does not exist.
        """
        # Explicit check rather than ``assert`` so the validation still runs
        # when Python is started with ``-O``.
        if not os.path.exists(model_path):
            raise IOError('File does not exist %s' % model_path)
        self.model_path = model_path
        # Read the serialized graph definition from disk.
        with tf.gfile.FastGFile(model_path, 'rb') as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        # name='' imports without a scope prefix, so tensors keep their
        # original names (e.g. 'input:0' used by extract_features).
        tf.import_graph_def(graph_def, name='')
        # One session for feature extraction, reused across calls.
        self.session = tf.Session()
        # NOTE(review): not used in this file; presumably kept for external
        # callers — confirm before removing.
        self.writer = None

    def extract_features(self, image, tensor_name='InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0'):
        """Extract the features of a single JPEG image file.

        Only file paths are supported: the file is loaded with
        ``ImageLoader.load_imgs`` as a one-image batch and fed to the graph's
        ``'input:0'`` tensor.

        Args:
          image: path to a JPEG image file.
          tensor_name: name of the graph tensor whose value is returned.

        Returns:
          The value of ``tensor_name`` with all size-1 axes squeezed out,
          converted to a Python list over its first remaining axis (for the
          default pooling tensor this is presumably a 2048-element feature
          vector — TODO confirm against the model).

        Raises:
          IOError: if ``image`` does not exist.
        """
        # Validate the input before touching the session so a bad path fails
        # fast (and survives ``python -O``, unlike ``assert``).
        if not os.path.exists(image):
            raise IOError('File does not exist %s' % image)
        sess = self.session
        feat_tensor = sess.graph.get_tensor_by_name(tensor_name)
        image_batch = ImageLoader().load_imgs([image])
        features = sess.run(feat_tensor, {'input:0': image_batch})
        return list(np.squeeze(features))
def dump_chunk(features, chunk_index):
    """Pickle one chunk of features to ./train/train_img2048_<chunk_index>.pkl."""
    with open('./train/train_img2048_%d.pkl' % chunk_index, 'wb') as f:
        pickle.dump(features, f)


def extract_in_chunks(rows, extract, save, chunk_size=10000, report_every=100):
    """Extract features for annotation rows and save them in fixed-size chunks.

    Args:
      rows: iterable of CSV rows. A row whose column 1 equals 'caption' is
        treated as the header and skipped; column 2 is the image path and
        column 3 the integer image id used as the dict key.
      extract: callable mapping an image path to its features.
      save: callable ``save(features_by_id, chunk_index)``. It is called with
        index 1, 2, ... each time ``chunk_size`` rows have been processed, and
        once more at the end with the remainder (index
        ``count // chunk_size + 1``; the remainder may be empty).
      chunk_size: number of successfully processed rows per saved chunk.
      report_every: print a progress line every this many processed rows.

    Returns:
      The number of rows whose features were extracted.
    """
    count = 0
    chunk = {}
    for row in rows:
        # Best-effort per row: malformed rows and images that fail to load
        # are reported and skipped. Only per-row work is guarded, so save
        # failures and Ctrl+C are no longer silently swallowed.
        try:
            if row[1] == 'caption':
                continue
            image_path = row[2]
            image_id = int(row[3])
            feature = extract(image_path)
        except Exception as e:
            print('skipping row %r: %s' % (row, e))
            continue
        # NOTE(review): rows sharing an image id (e.g. several captions per
        # image) re-extract the same image; the dict keeps one entry per id.
        chunk[image_id] = feature
        count += 1
        if count % report_every == 0:
            print('%d finished' % count)
        if count % chunk_size == 0:
            save(chunk, count // chunk_size)
            chunk = {}
    save(chunk, count // chunk_size + 1)
    return count


if __name__ == '__main__':
    ife = ImageFeatureExtractor('model/inception_v3_2016_08_28_frozen.pb')
    # ``with`` guarantees the annotation file is closed once processing ends.
    with open("train/anns.csv") as ann_file:
        extract_in_chunks(csv.reader(ann_file), ife.extract_features, dump_chunk)