-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTesting-Digits.R
More file actions
97 lines (81 loc) · 2.53 KB
/
Testing-Digits.R
File metadata and controls
97 lines (81 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# AUTHOR: Alexandre Galvão Patriota
# IME-USP
# Strongly bias number formatting towards fixed (non-scientific) notation.
options(scipen = 999)
#require('stringdist')
#require('stringr')
# library() stops right away when a package is missing; require() would only
# return FALSE and let the script fail later with an unrelated error.
library(torch)
# Project files. Presumably these define `config`, GPT() and the helpers used
# further down (Intercalar, Encoder, Decoder, Generate) -- not visible here.
source('config.R')
source('GPT.R')
source('Generators.R')
library(gmp)
# Number of digits of each operand (change here to test other lengths).
config$digits <- 100
n0 <- config$digits
# Print the operands and both sums at the end?
Print <- TRUE
# Fixed seed so the operands are reproducible; comment it out to test with
# different numbers.
set.seed(10)
# Operands: two strings of n0 digits drawn with replacement (a leading zero
# is possible).
x <- paste0(sample(0:9, size = n0, replace = TRUE), collapse = "")
y <- paste0(sample(0:9, size = n0, replace = TRUE), collapse = "")
# File from which the model's state dict is loaded, if it exists.
model_save <- "Model.pt"
# Vocabulary: "P", "S", the ten digits, newline and "C".
Voc <- c("P", "S", as.character(0:9), "\n", "C")
# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
device0 <- if (torch::cuda_is_available()) {
  "cuda"
} else {
  "cpu"
}
# MODEL
# Build the network from the hyper-parameters stored in `config`; `nvoc` is
# set to the number of tokens in `Voc`.
model <- GPT(
  block_size = config$block_size,
  block_size_out = config$block_size_out,
  n_embd = config$n_embd,
  N_Layers = config$N_Layers,
  Head = config$n_head,
  nvoc = length(Voc),
  p0 = config$p0,
  p1 = config$p1
)
# Restore the saved state dict when the file exists. Otherwise the model keeps
# the weights it was constructed with, so say so explicitly instead of
# silently reporting results for a model that never loaded a checkpoint.
if (file.exists(model_save)) {
  model$load_state_dict(
    state_dict = torch::torch_load(model_save),
    .refer_to_state_dict = TRUE
  )
} else {
  warning(
    "Checkpoint '", model_save,
    "' not found; running with the model's initial weights.",
    call. = FALSE
  )
}
# Move the model to the selected device.
model <- model$to(device = device0)
# Exact reference sum using gmp big integers.
# Leading zeros are removed before conversion because, per the gmp
# documentation, as.bigz() reads a string with a leading "0" as octal. The
# look-ahead keeps the final digit, so an all-zero operand becomes "0" rather
# than "". This is done purely on the string: no lossy as.numeric() round-trip
# and no scalar ifelse().
x0 <- as.bigz(sub("^0+(?=[0-9])", "", x, perl = TRUE))
y0 <- as.bigz(sub("^0+(?=[0-9])", "", y, perl = TRUE))
real <- x0 + y0
# Build the model input string from "x+y" with the project helper Intercalar()
# (defined outside this file; presumably it interleaves the digits of the two
# operands -- TODO confirm) and split it into single characters.
num <- Intercalar(paste0(x, "+", y))
temp <- list()
num0 <- strsplit(num, "")[[1]]
# Number of characters left after the first two-character chunk.
N <- abs(2 - nchar(num))

# First chunk: the last two characters of the input, followed by "P" tokens
# up to config$block_size, then encoded.
num1 <- c(
  num0[(length(num0) - 1):length(num0)],
  rep("P", config$block_size - 2)
)
num1 <- c(Encoder(num1))
# NOTE(review): `config$max` may rely on `$` partial matching -- confirm the
# field name in config.R.
temp[[1]] <- Generate(
  num1,
  Model = model,
  block_size_out = config$block_size_out,
  max_new_tokens = config$max,
  print = FALSE
)

# Remaining chunks: walk towards the start of the input two characters at a
# time; each prompt is built from the previous output.
s <- 0
if (N >= 2) {
  for (l in seq_len(N %/% 2) * 2) {
    s <- s + 1
    num1 <- num0[(length(num0) - l - 1):(length(num0) - l)]
    # Previous output without its first element, then token 14, then the next
    # two encoded characters. 14 is the position of "C" in Voc, assuming
    # Encoder() maps a token to its 1-based index in Voc -- TODO confirm.
    num1 <- c(temp[[s]][-1], 14, Encoder(num1))
    nn <- length(num1)
    # Fill up to config$block_size with token 1 (position of "P" in Voc,
    # under the same assumption).
    num1 <- c(num1, rep(1, config$block_size - nn))
    temp[[s + 1]] <- Generate(
      num1,
      Model = model,
      block_size_out = config$block_size_out,
      max_new_tokens = config$max,
      print = FALSE
    )
    cli::cli_progress_message(paste0(" Digit #", s + 1, " "))
    # Every 100000 positions, run the garbage collector (twice).
    if (l %% 100000 == 0) {
      gc()
      gc()
    }
  }
}
# Last token of the s-th generated chunk.
aux <- function(s) temp[[s]][length(temp[[s]])]
# Chunks were produced starting from the end of the input, so read them in
# reverse order (last generated chunk first).
ind <- rev(seq_along(temp))
# Second-to-last token of the last generated chunk.
a <- temp[[length(temp)]]
a <- a[length(a) - 1]
# Ids 1, 2, 3, 13 and 14 are the Voc positions of "P", "S", "0", "\n" and "C"
# (assuming Decoder() uses 1-based Voc indices -- TODO confirm). When `a` is
# one of them it is not prepended; otherwise it is put in front of the decoded
# tokens (presumably the final carry digit).
if (a %in% c(1, 2, 3, 13, 14)) {
  pred <- paste(Decoder(c(sapply(ind, aux))), collapse = "")
} else {
  pred <- paste(Decoder(c(a, sapply(ind, aux))), collapse = "")
}
# Drop leading zeros but always keep the final digit, so a prediction made
# only of zeros normalises to "0". The previous
# ifelse(pred==0, pred, sub("^0+", "", pred)) turned e.g. "000" into "".
# The unused recomputation of `num1` that used to sit here was removed.
pred <- sub("^0+(?=[0-9])", "", pred, perl = TRUE)
# Report: first whether the predicted sum equals the exact one, then
# (when Print is TRUE) the operands and both sums.
cat("\n")
print(paste0("The output is ", real == pred))
if (Print) {
  print(paste0("x=", x))
  print(paste0("y=", y))
  print(paste0("Real x+y=", real))
  print(paste0("Pred x+y=", pred))
}