diff --git a/DESCRIPTION b/DESCRIPTION
index b892b46..a20d495 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -1,6 +1,6 @@
 Package: irxforge
 Title: Forging data for pharmacometric analyses
-Version: 0.0.0.9000
+Version: 0.0.0.9001
 Authors@R: c(
     person("Ron", "Keizer", email = "ron@insight-rx.com", role = c("cre", "aut")),
     person("Michael", "McCarthy", email = "michael.mccarthy@insight-rx.com", role = "ctb"),
diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R
index 14317a4..67aa689 100644
--- a/R/sample_covariates_mvtnorm.R
+++ b/R/sample_covariates_mvtnorm.R
@@ -1,13 +1,31 @@
 #' Sample covariates from multivariate normal distributions
 #'
+#' Samples from a multivariate normal distribution either derived from observed
+#' data or specified directly via `means` plus a covariance matrix (`sigma`) or
+#' standard deviations (`sd`).
+#'
 #' @param data data.frame (n x p) containing the original, observed,
 #'   time-invariant covariates (ID should not be included) that will be used to
-#'   inform the imputation.
+#'   inform the imputation. Can be `NULL` when `means` and either `sigma` or
+#'   `sd` are supplied directly. Ignored with a warning when `means` is also
+#'   provided.
+#' @param means named numeric vector of means for each covariate. When supplied,
+#'   the distribution is specified directly and `data` is ignored. Must be
+#'   supplied together with either `sigma` or `sd` when `data` is `NULL`.
+#' @param sigma numeric matrix (p x p) giving the full covariance matrix.
+#'   Takes precedence over `sd` when both are provided. If both `sigma` and
+#'   `means` are named, the matrix is reordered to match the order of `means`;
+#'   names must refer to the same set of variables.
+#' @param sd numeric vector of standard deviations. Used to construct a
+#'   diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided.
+#'   If both `sd` and `means` are named, `sd` is reordered to match the order
+#'   of `means`; names must refer to the same set of variables.
 #' @param cat_covs character vector containing the names of the categorical
 #'   covariates in orgCovs.
-#' @param n_subjects number of simulated subjects, default is the number of
-#'   subjects in the data.
+#' @param n_subjects number of simulated subjects. Defaults to `nrow(data)`
+#'   when `data` is provided; required (no default) when `data` is `NULL`.
 #' @param exponential sample from exponential distribution? Default `FALSE`.
+#'   Only applies when means/covariance are derived from `data`.
 #' @param conditional description...
 #' @param seed integer random seed passed to [set.seed()] for reproducibility.
 #'   Default `NULL` does not set a seed.
@@ -18,11 +36,29 @@
 #'
 #' @note missing values in `data` must be coded as NA
 #'
+#' @examples
+#' sample_covariates_mvtnorm(
+#'   means = c(WT = 70, HT = 170, AGE = 50),
+#'   sd = c(10, 20, 5),
+#'   n_subjects = 100
+#' )
+#' sample_covariates_mvtnorm(
+#'   means = c(WT = 70, HT = 170, AGE = 50),
+#'   sigma = matrix(
+#'     c(100, 20, 12,
+#'       20, 400, 10,
+#'       12, 10, 25), nrow = 3, ncol = 3),
+#'   n_subjects = 100
+#' )
+#'
 #' @export
 sample_covariates_mvtnorm <- function(
-  data,
+  data = NULL,
+  means = NULL,
+  sigma = NULL,
+  sd = NULL,
   cat_covs = NULL,
-  n_subjects = nrow(data),
+  n_subjects = if (!is.null(data)) nrow(data) else stop("`n_subjects` must be specified when `data` is NULL."),
   exponential = FALSE,
   conditional = NULL,
   seed = NULL,
@@ -30,11 +66,63 @@ sample_covariates_mvtnorm <- function(
 ) {
   if (!is.null(seed)) set.seed(seed)
 
+  ## Branch: distribution specified directly via means + sigma/sd
+  if (!is.null(means)) {
+    if (!is.null(data)) {
+      warning("`data` is ignored when `means` is provided.")
+    }
+    if (is.null(sigma) && is.null(sd)) {
+      stop("When `means` is supplied, either `sigma` or `sd` must also be provided.")
+    }
+    if (!is.null(sigma)) {
+      cov_mat <- sigma
+      if (!is.null(names(means)) && !is.null(rownames(cov_mat))) {
+        if (!setequal(rownames(cov_mat), names(means))) {
+          stop("Names of `sigma` must match `names(means)` (same set of names).")
+        }
+        cov_mat <- cov_mat[names(means), names(means), drop = FALSE]
+      }
+    } else {
+      if (length(sd) != length(means)) {
+        stop("`sd` must have the same length as `means`.")
+      }
+      if (!is.null(names(means)) && !is.null(names(sd))) {
+        if (!setequal(names(sd), names(means))) {
+          stop("Names of `sd` must match `names(means)` (same set of names).")
+        }
+        sd <- sd[names(means)]
+      }
+      # `nrow` is required: diag() on a length-1 vector builds an identity matrix
+      cov_mat <- diag(sd^2, nrow = length(sd))
+      if (!is.null(names(means))) {
+        rownames(cov_mat) <- names(means)
+        colnames(cov_mat) <- names(means)
+      } else if (!is.null(names(sd))) {
+        rownames(cov_mat) <- names(sd)
+        colnames(cov_mat) <- names(sd)
+      }
+    }
+    out <- mvtnorm::rmvnorm(
+      n_subjects,
+      mean = means,
+      sigma = cov_mat,
+      ...
+    ) |>
+      as.data.frame()
+    if (!is.null(names(means))) names(out) <- names(means)
+    return(out)
+  }
+
+  ## Branch: derive distribution from data
+  if (is.null(data)) {
+    stop("Either `data` or `means` with `sigma`/`sd` must be provided.")
+  }
+
   if(!is.null(conditional)) {
     for(key in names(conditional)) {
       data <- dplyr::filter(
         data,
-        .data[[key]] >= min(conditional[[key]]) & 
+        .data[[key]] >= min(conditional[[key]]) &
         .data[[key]] <= max(conditional[[key]])
       )
     }
@@ -44,33 +132,33 @@ sample_covariates_mvtnorm <- function(
   # FIXME: This code does nothing currently... is this function intended to
   # work with categorical covariates? or only continuous? If the latter, how
   # do we handle categorical?
-  cont_covs <- setdiff(names(data), cat_covs) 
+  cont_covs <- setdiff(names(data), cat_covs)
   miss_vars <- names(data)[colSums(is.na(data)) > 0]
 
   ## Get distribution and sample
   if(exponential) {
     # FIXME: This fails if there are zeroes or negative numbers. Should add some
     # safety rails.
-    means <- apply(data, 2, function(x) mean(log(x)))
-    cov_mat <- stats::cov(log(data)) 
+    data_means <- apply(data, 2, function(x) mean(log(x)))
+    cov_mat <- stats::cov(log(data))
     out <- mvtnorm::rmvnorm(
-      n_subjects, 
-      mean = means,
+      n_subjects,
+      mean = data_means,
       sigma = cov_mat
     ) |>
       exp() |>
       as.data.frame()
   } else {
-    means <- apply(data, 2, mean)
+    data_means <- apply(data, 2, mean)
     cov_mat <- stats::cov(data)
     out <- mvtnorm::rmvnorm(
-      n_subjects, 
-      mean = means,
+      n_subjects,
+      mean = data_means,
       sigma = cov_mat
     ) |>
       as.data.frame()
   }
-  
+
   if (tibble::is_tibble(data)) out <- tibble::as_tibble(out)
-  out 
+  out
 }
diff --git a/man/sample_covariates_mvtnorm.Rd b/man/sample_covariates_mvtnorm.Rd
index f7e61a4..201b817 100644
--- a/man/sample_covariates_mvtnorm.Rd
+++ b/man/sample_covariates_mvtnorm.Rd
@@ -5,9 +5,13 @@
 \title{Sample covariates from multivariate normal distributions}
 \usage{
 sample_covariates_mvtnorm(
-  data,
+  data = NULL,
+  means = NULL,
+  sigma = NULL,
+  sd = NULL,
   cat_covs = NULL,
-  n_subjects = nrow(data),
+  n_subjects = if (!is.null(data)) nrow(data) else
+    stop("`n_subjects` must be specified when `data` is NULL."),
   exponential = FALSE,
   conditional = NULL,
   seed = NULL,
@@ -17,15 +21,32 @@ sample_covariates_mvtnorm(
 \arguments{
 \item{data}{data.frame (n x p) containing the original, observed,
 time-invariant covariates (ID should not be included) that will be used to
-inform the imputation.}
+inform the imputation. Can be \code{NULL} when \code{means} and either \code{sigma} or
+\code{sd} are supplied directly. Ignored with a warning when \code{means} is also
+provided.}
+
+\item{means}{named numeric vector of means for each covariate. When supplied,
+the distribution is specified directly and \code{data} is ignored. Must be
+supplied together with either \code{sigma} or \code{sd} when \code{data} is \code{NULL}.}
+
+\item{sigma}{numeric matrix (p x p) giving the full covariance matrix.
+Takes precedence over \code{sd} when both are provided. If both \code{sigma} and
+\code{means} are named, the matrix is reordered to match the order of \code{means};
+names must refer to the same set of variables.}
+
+\item{sd}{numeric vector of standard deviations. Used to construct a
+diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided.
+If both \code{sd} and \code{means} are named, \code{sd} is reordered to match the order
+of \code{means}; names must refer to the same set of variables.}
 
 \item{cat_covs}{character vector containing the names of the categorical
 covariates in orgCovs.}
 
-\item{n_subjects}{number of simulated subjects, default is the number of
-subjects in the data.}
+\item{n_subjects}{number of simulated subjects. Defaults to \code{nrow(data)}
+when \code{data} is provided; required (no default) when \code{data} is \code{NULL}.}
 
-\item{exponential}{sample from exponential distribution? Default \code{FALSE}.}
+\item{exponential}{sample from exponential distribution? Default \code{FALSE}.
+Only applies when means/covariance are derived from \code{data}.}
 
 \item{conditional}{description...}
 
@@ -39,8 +60,26 @@ a data.frame with the simulated covariates, with \code{n_subjects} rows and
 \code{p} columns
 }
 \description{
-Sample covariates from multivariate normal distributions
+Samples from a multivariate normal distribution either derived from observed
+data or specified directly via \code{means} plus a covariance matrix (\code{sigma}) or
+standard deviations (\code{sd}).
 }
 \note{
 missing values in \code{data} must be coded as NA
 }
+\examples{
+sample_covariates_mvtnorm(
+  means = c(WT = 70, HT = 170, AGE = 50),
+  sd = c(10, 20, 5),
+  n_subjects = 100
+)
+sample_covariates_mvtnorm(
+  means = c(WT = 70, HT = 170, AGE = 50),
+  sigma = matrix(
+    c(100, 20, 12,
+      20, 400, 10,
+      12, 10, 25), nrow = 3, ncol = 3),
+  n_subjects = 100
+)
+
+}
diff --git a/tests/testthat/test-sample_covariates_mvtnorm.R b/tests/testthat/test-sample_covariates_mvtnorm.R
index 1f53e00..3ba6ffd 100644
--- a/tests/testthat/test-sample_covariates_mvtnorm.R
+++ b/tests/testthat/test-sample_covariates_mvtnorm.R
@@ -92,3 +92,135 @@ test_that("different seeds produce different output", {
   out2 <- sample_covariates_mvtnorm(dat, n_subjects = 20, seed = 2)
   expect_false(identical(out1, out2))
 })
+
+# --- Direct distribution specification (means + sigma / sds) ---
+
+test_that("means + sigma samples correct number of rows and columns", {
+  mu <- c(x = 10, y = 20)
+  S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y")))
+  out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100)
+  expect_equal(nrow(out), 100)
+  expect_equal(ncol(out), 2)
+  expect_named(out, c("x", "y"))
+})
+
+test_that("means + sigma samples near the specified mean (large n)", {
+  mu <- c(AGE = 40, WT = 70)
+  S <- diag(c(25, 100))
+  out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 2000, seed = 42)
+  expect_lt(abs(mean(out$AGE) - 40), 1.5)
+  expect_lt(abs(mean(out$WT) - 70), 3)
+})
+
+test_that("means + sd constructs diagonal covariance and samples correctly", {
+  mu <- c(x = 5, y = 50)
+  sds <- c(x = 1, y = 10)
+  out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7)
+  expect_named(out, c("x", "y"))
+  expect_lt(abs(mean(out$x) - 5), 0.25)
+  expect_lt(abs(mean(out$y) - 50), 2.5)
+  expect_lt(abs(sd(out$x) - 1), 0.12)
+  expect_lt(abs(sd(out$y) - 10), 1.2)
+})
+
+test_that("means + sd: length mismatch raises error", {
+  expect_error(
+    sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(1, 2, 3), n_subjects = 10),
+    "`sd` must have the same length"
+  )
+})
+
+test_that("means + sd: reorders sd to match means when both are named", {
+  mu <- c(y = 50, x = 5)   # y first
+  sds <- c(x = 1, y = 10)  # x first — should be reordered
+  out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7)
+  expect_named(out, c("y", "x"))
+  expect_lt(abs(sd(out$y) - 10), 1.2)
+  expect_lt(abs(sd(out$x) - 1), 0.12)
+})
+
+test_that("means + sd: name mismatch raises error", {
+  expect_error(
+    sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(a = 1, c = 2), n_subjects = 10),
+    "Names of `sd` must match"
+  )
+})
+
+test_that("means + sd: unnamed sd is assumed to be in the right order", {
+  mu <- c(x = 5, y = 50)
+  out <- sample_covariates_mvtnorm(means = mu, sd = c(1, 10), n_subjects = 2000, seed = 7)
+  expect_named(out, c("x", "y"))
+  expect_lt(abs(sd(out$x) - 1), 0.12)
+  expect_lt(abs(sd(out$y) - 10), 1.2)
+})
+
+test_that("means + sigma: reorders sigma to match means when both are named", {
+  mu <- c(y = 20, x = 10)  # y first
+  S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y")))
+  out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1)
+  expect_named(out, c("y", "x"))
+})
+
+test_that("means + sigma: name mismatch raises error", {
+  mu <- c(a = 1, b = 2)
+  S <- matrix(c(1, 0, 0, 1), nrow = 2, dimnames = list(c("a","c"), c("a","c")))
+  expect_error(
+    sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 10),
+    "Names of `sigma` must match"
+  )
+})
+
+test_that("means + sigma: unnamed sigma is assumed to be in the right order", {
+  mu <- c(x = 10, y = 20)
+  S <- diag(2)  # no dimnames
+  out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1)
+  expect_named(out, c("x", "y"))
+})
+
+test_that("means without sigma or sd raises error", {
+  expect_error(
+    sample_covariates_mvtnorm(means = c(x = 1), n_subjects = 10),
+    "either `sigma` or `sd` must also be provided"
+  )
+})
+
+test_that("data = NULL without means raises error", {
+  expect_error(
+    sample_covariates_mvtnorm(data = NULL, n_subjects = 10),
+    "Either `data` or `means`"
+  )
+})
+
+test_that("warning is issued when both data and means are provided", {
+  dat <- data.frame(x = rnorm(50), y = rnorm(50))
+  mu <- c(x = 0, y = 0)
+  S <- diag(2)
+  expect_warning(
+    sample_covariates_mvtnorm(data = dat, means = mu, sigma = S, n_subjects = 10),
+    "`data` is ignored"
+  )
+})
+
+test_that("n_subjects is required when data is NULL", {
+  mu <- c(x = 0)
+  S <- matrix(1)
+  expect_error(
+    sample_covariates_mvtnorm(means = mu, sigma = S),
+    "n_subjects.*must be specified"
+  )
+})
+
+test_that("seed produces reproducible output with means + sigma", {
+  mu <- c(x = 0, y = 1)
+  S <- diag(2)
+  out1 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
+  out2 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
+  expect_equal(out1, out2)
+})
+
+test_that("means + sd: single covariate yields a 1 x 1 covariance matrix", {
+  out <- sample_covariates_mvtnorm(means = c(x = 5), sd = c(x = 2), n_subjects = 2000, seed = 7)
+  expect_named(out, "x")
+  expect_equal(nrow(out), 2000)
+  expect_lt(abs(sd(out$x) - 2), 0.24)
+})