-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
29 lines (24 loc) · 888 Bytes
/
test.py
File metadata and controls
29 lines (24 loc) · 888 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import GPU_AC
import torch
import time
useGpu = True
# Run on the GPU when useGpu is set, otherwise on the CPU. The symbols and the
# pdf must live on the same device.
device = torch.device('cuda' if useGpu else 'cpu')

# Generate random symbols and a random pdf:
#   symgpu: int16 tensor of shape (symsNum, 1), values in [0, dim)
#   pdf:    float tensor of shape (symsNum, dim), each row summing to 1
# NOTE(review): an earlier comment said "uint16", but .short() is int16 and the
# pdf is floating point -- confirm the dtypes GPU_AC actually expects.
dim = 256
symsNum = 8192*100  # total symbols; presumably 8192 symbols are compressed per thread -- confirm
# Keep the samples as floats: casting torch.rand output (in [0, 1)) to an
# integer type truncates every value to 0 and would make the pdf uniform.
pdf = torch.rand(symsNum, dim, device=device) + 0.01  # +0.01 keeps every probability > 0
pdf = pdf / torch.sum(pdf, 1, keepdim=True)
symgpu = torch.randint(0, dim, (symsNum, 1), device=device).short()

t1 = time.time()
filebin = 'gpuac.bin'
# Encode to a bytestream stored in `filebin`.
encodsz = GPU_AC.encode(symgpu, pdf, filebin, useGpu=useGpu, interaction=True)
# encodsz is taken to be the number of bits in the stream, so this is bits/symbol.
print('real_bpp', encodsz / symsNum)
# Theoretical bits per symbol: NLLLoss averages -input[i, target[i]], so feeding
# log2(pdf) gives the mean of -log2 p(symbol), i.e. the cross-entropy in bits.
criterion = torch.nn.NLLLoss()
print('shannon entropy', criterion(torch.log2(pdf), symgpu.reshape(-1).long()).item())
# Decode from the bytestream.
symbols_dec = GPU_AC.decode(pdf, filebin, useGpu=useGpu)
# Round-trip check: the decoded symbols must equal the encoded ones.
# Flatten both sides and compare as int64 on the same device. torch.equal also
# returns False when the lengths differ.
# NOTE(review): assumes decode returns a tensor (or something torch.as_tensor
# accepts) -- confirm against GPU_AC.decode.
decoded = torch.as_tensor(symbols_dec).reshape(-1).to(device=symgpu.device, dtype=torch.long)
assert torch.equal(decoded, symgpu.reshape(-1).long()), 'decoded symbols differ from the encoded symbols'
if useGpu:
    # CUDA kernels run asynchronously; wait for them so the timing is complete.
    torch.cuda.synchronize()
print('time used', time.time() - t1)