diff --git a/R/mp_tmb_calibrator.R b/R/mp_tmb_calibrator.R index 16177861..7edb1e2e 100644 --- a/R/mp_tmb_calibrator.R +++ b/R/mp_tmb_calibrator.R @@ -488,19 +488,27 @@ TMBCalDataStruc = function(data, time) { } FALSE } - - data = rename_synonyms(data - , time = c( - "time", "Time", "ID", "time_id", "id", "date", "Date" - , "time_step", "timeStep", "TimeStep" - ) - , matrix = c( - "matrix", "Matrix", "mat", "Mat", "variable", "var", "Variable", "Var" - ) - , row = c("row", "Row") - , col = c("col", "Col", "column", "Column") - , value = c("value", "Value", "val", "Val", "default", "Default") - ) + + syns <- list(time = c("time", "Time", "ID", "time_id", "id", + "date", "Date", + "time_step", "timeStep", "TimeStep"), + matrix = c("matrix", "Matrix", "mat", "Mat", + "variable", "var", "Variable", "Var"), + row = c("row", "Row"), + col = c("col", "Col", "column", "Column"), + value = c("value", "Value", "val", "Val", "default", "Default")) + data = do.call(rename_synonyms, c(list(data), syns)) + + ## check presence (row/col not required?) + for (m in setdiff(names(syns), c("row", "col"))) { + if (is.null(data[[m]])) { + stop( + "Supplied data did not contain a column called '", m, "' ", + "(or its synonyms: ", + paste(sprintf("'%s'", syns[[m]]), collapse = ", "), ")" + ) + } + } time_column_test_value = data$time if (is.character(data$time)) { original_coercer = as.character @@ -1270,9 +1278,23 @@ TMBTraj.character = function( ## Depended upon to create a character vector of output variables to fit to self$outputs = function() names(self$list) + # ff = function(dat) { + # xx = split(dat, dat$row) + # time_ids = lapply(xx, getElement, "time_ids") + # row = lapply(xx, getElement, "row") |> lapply(unique) + # if (!all(vapply(row, length, integer(1L)) == 1L)) { + # stop("Ca") + # } + # if (any()) + # list(time_ids, row) + # } + ## implemented methods self$obs = function() lapply(self$list, getElement, "value") - self$obs_times = function() lapply(self$list, getElement, "time_ids") + self$obs_times = function() { + #split(traj$list$infection, traj$list$infection$row) + lapply(self$list, getElement, "time_ids") + } self$distr_params = function() { switch( getOption("macpan2_default_loss")[1L] diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index d9de5a94..bdaa2fca 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -24,6 +24,7 @@ sims = list( , sir_10_I = mp_simulator(all_specs$sir, 10L, "I") , sir_50_infection = mp_simulator(all_specs$sir, 50L, "infection") , sir_50_I = mp_simulator(all_specs$sir, 50L, "I") + , sir_age_10_infection = mp_simulator(all_specs$sir_age, 10L, "infection") ) for (obj in names(sims)) { diff --git a/tests/testthat/test-calibrator-traj.R b/tests/testthat/test-calibrator-traj.R index 8a31733d..9b8f97b0 100644 --- a/tests/testthat/test-calibrator-traj.R +++ b/tests/testthat/test-calibrator-traj.R @@ -1,4 +1,3 @@ -library(macpan2); library(testthat); library(dplyr); library(tidyr); library(ggplot2) test_that("bad outputs give warnings", { sir = mp_tmb_library("starter_models", "sir", package = "macpan2") expect_warning( @@ -60,3 +59,29 @@ test_that("trajectories specified with likelihood distributions end up in calibr ) }) + +test_that("missing required columns in calibration data throw errors", { + sir = "SPEC-sir.rds" |> test_cache_read() + sir_sims = "TRAJ-sir_5_state.rds" |> test_cache_read() + err = "Supplied data did not contain a column called" + expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -time)), err) + expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -matrix)), err) + expect_error(mp_tmb_calibrator(sir, data = select(sir_sims, -value)), err) +}) + +test_that("vector-valued trajectories can be calibrated to", { + skip("Skipping because rbind_time is not working on non-scalars") + sir_age = "SPEC-sir_age.rds" |> test_cache_read() + sir_age_sims = "TRAJ-sir_age_10_infection.rds" |> test_cache_read() + sir_age_cal = mp_tmb_calibrator(sir_age + , data = sir_age_sims + , par = "tau" + , traj = "infection" + , outputs = "sim_infection" + ) + sir_age_cal + expect_equal( + sir_age_cal$cal_spec |> mp_simulator(10, "infection") |> mp_trajectory() |> pull(value) + , sir_age_cal$cal_spec |> mp_simulator(10, "sim_infection") |> mp_final() |> pull(value) + ) +})