-
Notifications
You must be signed in to change notification settings - Fork 29
Expand file tree
/
Copy pathdata_ops.py
More file actions
executable file
·107 lines (82 loc) · 3.12 KB
/
data_ops.py
File metadata and controls
executable file
·107 lines (82 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
'''
Operations for loading and preprocessing image data.
Largely adapted from https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py
'''
from __future__ import division
from __future__ import absolute_import
from scipy import misc
from skimage import color
import collections
import tensorflow as tf
import numpy as np
import math
import time
import random
import glob
import os
import fnmatch
import cPickle as pickle
# Result of loadData(): batched tensors plus dataset bookkeeping.
#   paths           -- batch of source file paths read from the file queue
#   inputs          -- batch of grayscale input images (float32)
#   targets         -- batch of RGB target images (float32)
#   count           -- number of image paths the pipeline was built from
#   steps_per_epoch -- ceil(count / batch_size)
# The typename matches the module-level name so repr() is accurate and
# instances can be pickled (pickle looks the class up by this name).
Data = collections.namedtuple('Data', 'paths, inputs, targets, count, steps_per_epoch')
def getPaths(data_dir, ext='jpg'):
    """Recursively collect files under ``data_dir`` whose names match ``*.<ext>``.

    Matching uses shell-style wildcards via ``fnmatch``, so case sensitivity
    follows the platform's filename rules.

    Args:
        data_dir: root directory to walk; a missing directory yields ``[]``.
        ext: file extension to match, without the leading dot.

    Returns:
        List of ``os.path.join(dirpath, filename)`` strings in ``os.walk`` order.
    """
    wanted = '*.' + ext
    return [
        os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(data_dir)
        for name in fnmatch.filter(names, wanted)
    ]
def _load_train_paths(data_dir, pkl_train_file):
    """Return the shuffled list of training image paths, cached in a pickle file.

    If ``pkl_train_file`` exists and holds a non-empty list, that list is
    returned unchanged. Otherwise ``data_dir`` is scanned with getPaths(), the
    result is shuffled, and (only when non-empty) written to ``pkl_train_file``.
    """
    if os.path.isfile(pkl_train_file):
        print('Found pickle file')
        # NOTE: pickle must only be loaded from trusted files; this cache is
        # written by this function, never use one obtained from elsewhere.
        with open(pkl_train_file, 'rb') as pf:
            train_paths = pickle.load(pf)
        if train_paths:
            return train_paths
        # An empty cache (e.g. from an older run) is ignored and rebuilt.
    train_paths = getPaths(data_dir)
    random.shuffle(train_paths)
    # Never cache an empty scan: it would make the "no image files" error
    # stick even after images are added to data_dir.
    if train_paths:
        with open(pkl_train_file, 'wb') as pf:
            pickle.dump(train_paths, pf)
    return train_paths


def loadData(data_dir, batch_size, train=True, pkl_train_file='pokemon.pkl'):
    """Build a TF1 queue-based pipeline yielding batches of (grayscale, RGB) images.

    Args:
        data_dir: with ``train=True``, a directory searched recursively (via
            getPaths) for image files; with ``train=False``, the path of a
            single image file.
        batch_size: number of examples per batch.
        train: if True, the file queue is shuffled and each image is randomly
            flipped left/right and resized to 160x144; if False, images are
            only resized to 160x160.
        pkl_train_file: pickle cache of the shuffled training path list
            (training only). The cache is keyed only by this file name, not by
            ``data_dir``; delete it or pass another name when the dataset changes.

    Returns:
        Data namedtuple with batched ``paths``, grayscale ``inputs``, RGB
        ``targets``, the path ``count`` and ``steps_per_epoch``.

    Raises:
        Exception: if ``data_dir`` does not exist or no image files are found.
    """
    if data_dir is None or not os.path.exists(data_dir):
        raise Exception('data_dir does not exist')

    if train:
        input_paths = _load_train_paths(data_dir, pkl_train_file)
    else:
        input_paths = [data_dir]

    if not input_paths:
        raise Exception('data_dir contains no image files')
    # Single-string print: identical output on Python 2 and Python 3.
    print('Found %d images!' % len(input_paths))

    with tf.name_scope('load_images'):
        path_queue = tf.train.string_input_producer(input_paths, shuffle=train)
        reader = tf.WholeFileReader()
        paths, contents = reader.read(path_queue)
        # channels=3 forces RGB so grayscale or RGBA files match the
        # 3-channel shape assumed below.
        raw_input_ = tf.image.decode_image(contents, channels=3)
        raw_input_ = tf.image.convert_image_dtype(raw_input_, dtype=tf.float32)
        # decode_image gives no static shape; declare it for the ops below.
        raw_input_.set_shape([None, None, 3])
        inputs = tf.image.rgb_to_grayscale(raw_input_)
        targets = raw_input_

    height = 160
    width = 144
    # Both transform() calls use the same op seed so the input and the target
    # receive the same random flip.
    seed = random.randint(0, 2**31 - 1)

    def transform(image):
        """Randomly flip left/right, then area-resize to (height, width)."""
        flipped = tf.image.random_flip_left_right(image, seed=seed)
        return tf.image.resize_images(flipped, [height, width],
                                      method=tf.image.ResizeMethod.AREA)

    if train:
        input_images = transform(inputs)
        target_images = transform(targets)
    else:
        # NOTE(review): inference resizes to 160x160 while training uses
        # 160x144 -- confirm this difference is intended.
        input_images = tf.image.resize_images(inputs, [160, 160],
                                              method=tf.image.ResizeMethod.AREA)
        target_images = tf.image.resize_images(targets, [160, 160],
                                               method=tf.image.ResizeMethod.AREA)

    paths_batch, inputs_batch, targets_batch = tf.train.batch(
        [paths, input_images, target_images], batch_size=batch_size)
    # float() keeps this a true division even without `from __future__ import division`.
    steps_per_epoch = int(math.ceil(len(input_paths) / float(batch_size)))
    return Data(
        paths=paths_batch,
        inputs=inputs_batch,
        targets=targets_batch,
        count=len(input_paths),
        steps_per_epoch=steps_per_epoch,
    )