-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest.lua
More file actions
91 lines (58 loc) · 2.12 KB
/
test.lua
File metadata and controls
91 lines (58 loc) · 2.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
local knn = require 'knn'
local function time(name, f)
local t = torch.Timer()
local r = f()
print(string.format("%s = %s: %.2f", name, tostring(r), t:time().real))
collectgarbage()
end
local test = {}
local makeData = function(n, size)
return torch.FloatTensor():range(1, size * n):reshape(n, size)
end
test.benchmark = function(k, size, q, n)
local data = makeData(n, size)
local query = makeData(q, size)
time(string.format("knn (k: %d) (features: %d) (query: %d) (data: %d)", k, query:size(2), query:size(1), data:size(1)), function ()
knn.knn(data, query, k)
end)
end
test.knn = function(k, size, q, n)
local data = torch.FloatTensor():rand(n, size)
local query = torch.FloatTensor(q, size):zero()
local inds = {}
for i = 1, q do
local index = torch.random(n)
inds[i] = index
query[i] = data[index]
end
local dists, indices = knn.knn(data, query, k)
-- print(data, query, dists, indices)
for i = 1, q do
assert(dists[i][1] == 0, "distance should be zero, was: "..tostring(dists[i][1]))
assert(indices[i][1] == inds[i], "indices aren't correct, was: "..tostring(indices[i][1]).." should be: "..inds[i])
end
print(string.format("test passed, knn (k: %d) (features: %d) (query: %d) (data: %d)", k, query:size(2), query:size(1), data:size(1)))
end
test.lookup = function(n)
for i = 1, n do
local n1 = torch.random(100)
local n2 = torch.random(10)
local l = torch.random(20)
local table = torch.LongTensor():range(1, l)
local indices = torch.IntTensor(n1, n2):random(1, l)
local r = knn.lookup(table, indices):int()
assert(r:eq(indices):min() == 1, string.format("lookup failed table = %d indices = (%d, %d)", l, n1, n2))
end
print(string.format("lookup passed %d tests", n))
end
test.lookup(10000)
test.knn(2, 5, 10, 10)
test.knn(4, 1280, 100, 10000)
test.knn(4, 100, 200, 10000)
test.knn(16, 1024, 2000, 70000)
test.benchmark(2, 128, 10000, 10000)
test.benchmark(4, 128, 10000, 50000)
test.benchmark(8, 128, 50000, 50000)
test.benchmark(24, 128, 10000, 100000)
test.benchmark(16, 1024, 10000, 50000)
return test