Skip to content
50 changes: 36 additions & 14 deletions R/mp_tmb_calibrator.R
Original file line number Diff line number Diff line change
Expand Up @@ -488,19 +488,27 @@ TMBCalDataStruc = function(data, time) {
}
FALSE
}

data = rename_synonyms(data
, time = c(
"time", "Time", "ID", "time_id", "id", "date", "Date"
, "time_step", "timeStep", "TimeStep"
)
, matrix = c(
"matrix", "Matrix", "mat", "Mat", "variable", "var", "Variable", "Var"
)
, row = c("row", "Row")
, col = c("col", "Col", "column", "Column")
, value = c("value", "Value", "val", "Val", "default", "Default")
)

syns <- list(time = c("time", "Time", "ID", "time_id", "id",
"date", "Date",
"time_step", "timeStep", "TimeStep"),
matrix = c("matrix", "Matrix", "mat", "Mat",
"variable", "var", "Variable", "Var"),
row = c("row", "Row"),
col = c("col", "Col", "column", "Column"),
value = c("value", "Value", "val", "Val", "default", "Default"))
data = do.call(rename_synonyms, c(list(data), syns))

## check presence (row/col not required?)
for (m in setdiff(names(syns), c("row", "col"))) {
if (is.null(data[[m]])) {
stop(
"Supplied data did not contain a column called '", m, "' ",
"(or its synonyms: ",
paste(sprintf("'%s'", syns[[m]]), collapse = ", "), ")"
)
}
}
time_column_test_value = data$time
if (is.character(data$time)) {
original_coercer = as.character
Expand Down Expand Up @@ -1270,9 +1278,23 @@ TMBTraj.character = function(
## Depended upon to create a character vector of output variables to fit to
self$outputs = function() names(self$list)

# ff = function(dat) {
# xx = split(dat, dat$row)
# time_ids = lapply(xx, getElement, "time_ids")
# row = lapply(xx, getElement, "row") |> lapply(unique)
# if (!all(vapply(row, length, integer(1L)) == 1L)) {
# stop("Ca")
# }
# if (any())
# list(time_ids, row)
# }

## implemented methods
self$obs = function() lapply(self$list, getElement, "value")
self$obs_times = function() lapply(self$list, getElement, "time_ids")
self$obs_times = function() {
#split(traj$list$infection, traj$list$infection$row)
lapply(self$list, getElement, "time_ids")
}
self$distr_params = function() {
switch(
getOption("macpan2_default_loss")[1L]
Expand Down
1 change: 1 addition & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ sims = list(
, sir_10_I = mp_simulator(all_specs$sir, 10L, "I")
, sir_50_infection = mp_simulator(all_specs$sir, 50L, "infection")
, sir_50_I = mp_simulator(all_specs$sir, 50L, "I")
, sir_age_10_infection = mp_simulator(all_specs$sir_age, 10L, "infection")
)

for (obj in names(sims)) {
Expand Down
27 changes: 26 additions & 1 deletion tests/testthat/test-calibrator-traj.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
library(macpan2); library(testthat); library(dplyr); library(tidyr); library(ggplot2)
test_that("bad outputs give warnings", {
sir = mp_tmb_library("starter_models", "sir", package = "macpan2")
expect_warning(
Expand Down Expand Up @@ -60,3 +59,29 @@ test_that("trajectories specified with likelihood distributions end up in calibr
)

})

test_that("missing required columns in calibration data throw errors", {
sir = "SPEC-sir.rds" |> test_cache_read()
sir_sims = "TRAJ-sir_5_state.rds" |> test_cache_read()
err = "Supplied data did not contain a column called"
expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -time)), err)
expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -matrix)), err)
expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -value)), err)
})

test_that("vector-valued trajectories can be calibrated to", {
skip("Skipping because rbind_time is not working on non-scalars")
sir_age = "SPEC-sir_age.rds" |> test_cache_read()
sir_age_sims = "TRAJ-sir_age_10_infection.rds" |> test_cache_read()
sir_age_cal = mp_tmb_calibrator(sir_age
, data = sir_age_sims
, par = "tau"
, traj = "infection"
, outputs = "sim_infection"
)
sir_age_cal
expect_equal(
sir_age_cal$cal_spec |> mp_simulator(10, "infection") |> mp_trajectory() |> pull(value)
, sir_age_cal$cal_spec |> mp_simulator(10, "sim_infection") |> mp_final() |> pull(value)
)
})