-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathload_data.py
More file actions
47 lines (25 loc) · 877 Bytes
/
load_data.py
File metadata and controls
47 lines (25 loc) · 877 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import gzip
import pickle
import numpy as np
def loadData(path='data/mnist.pkl.gz'):
    """Load the pickled dataset tuple from a gzip-compressed file.

    Parameters
    ----------
    path : str or os.PathLike, optional
        Location of the gzip-compressed pickle. Defaults to
        ``'data/mnist.pkl.gz'``, resolved relative to the current working
        directory.

    Returns
    -------
    tuple
        ``(trainingData, validationData, testData)``, i.e. the three objects
        unpacked from the pickle, returned unchanged.

    Raises
    ------
    ValueError
        If the pickle does not contain exactly three items to unpack.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file -- only point this at a trusted dataset file.
    # `with` guarantees the file handle is closed even if unpickling raises.
    with gzip.open(path, 'rb') as f:
        # encoding='latin1' presumably because the pickle was written by
        # Python 2 (the pickle docs require it for Py2-pickled NumPy arrays)
        # -- TODO confirm against the data file's origin.
        trainingData, validationData, testData = pickle.load(f, encoding='latin1')
    return (trainingData, validationData, testData)
def loadDataWrapper():
    """Convert the result of ``loadData()`` into a more convenient format.

    Every input ``x`` is reshaped into a ``(784, 1)`` column vector and
    paired with its label.

    Returns
    -------
    tuple of list
        ``(trainingData, validationData, testData)`` where

        * ``trainingData`` is a list of ``(x, y)`` pairs. ``x`` is a
          ``(784, 1)`` array and ``y`` is the one-hot column vector produced
          by ``vectorizedResult``.
        * ``validationData`` and ``testData`` are lists of ``(x, y)`` pairs.
          ``x`` is a ``(784, 1)`` array and ``y`` is the label exactly as
          stored in the pickle.

    The three results are lists rather than ``zip`` objects. In Python 3,
    ``zip`` returns a one-shot iterator that is silently empty after the
    first pass and does not support ``len()`` or indexing.

    Raises
    ------
    ValueError
        If an input cannot be reshaped to ``(784, 1)``.
    """
    trD, vaD, teD = loadData()

    def _columnVectors(inputs):
        # Reshape each input (presumably a flat 784-pixel image -- confirm
        # against the data file) into a (784, 1) column vector.
        return [np.reshape(x, (784, 1)) for x in inputs]

    trainingData = list(zip(_columnVectors(trD[0]),
                            [vectorizedResult(y) for y in trD[1]]))
    validationData = list(zip(_columnVectors(vaD[0]), vaD[1]))
    testData = list(zip(_columnVectors(teD[0]), teD[1]))
    return (trainingData, validationData, testData)
def vectorizedResult(j, numClasses=10):
    """Return a one-hot column vector for label ``j``.

    Parameters
    ----------
    j : int
        Index of the position to set; expected to lie in ``[0, numClasses)``.
        This is not validated: NumPy indexing semantics apply, so a negative
        ``j`` counts from the end.
    numClasses : int, optional
        Length of the vector. Defaults to 10, the value this module
        originally hard-coded.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(numClasses, 1)`` with ``1.0`` at row ``j``
        and ``0.0`` everywhere else.

    Raises
    ------
    IndexError
        If ``j`` is out of range for ``numClasses``.
    """
    e = np.zeros((numClasses, 1))
    e[j] = 1.0
    return e