From 9a0167cdfebae3634028b9f13e516de7ee5e1ca5 Mon Sep 17 00:00:00 2001 From: Ron Keizer Date: Mon, 16 Mar 2026 08:31:23 -0700 Subject: [PATCH 1/6] feat: allow sampling from a directly specified multivariate distribution Add `means`, `sigma`, and `sd` parameters to `sample_covariates_mvtnorm()` so callers can sample from a fully specified distribution without needing observed data. Passing both `data` and `means` issues a warning that `data` is ignored. Co-Authored-By: Claude Sonnet 4.6 --- DESCRIPTION | 2 +- R/sample_covariates_mvtnorm.R | 90 +++++++++++++++---- man/sample_covariates_mvtnorm.Rd | 35 ++++++-- .../testthat/test-sample_covariates_mvtnorm.R | 78 ++++++++++++++++ 4 files changed, 181 insertions(+), 24 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index b892b46..a20d495 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: irxforge Title: Forging data for pharmacometric analyses -Version: 0.0.0.9000 +Version: 0.0.0.9001 Authors@R: c( person("Ron", "Keizer", email = "ron@insight-rx.com", role = c("cre", "aut")), person("Michael", "McCarthy", email = "michael.mccarthy@insight-rx.com", role = "ctb"), diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R index 14317a4..66f560b 100644 --- a/R/sample_covariates_mvtnorm.R +++ b/R/sample_covariates_mvtnorm.R @@ -1,13 +1,29 @@ #' Sample covariates from multivariate normal distributions #' +#' Samples from a multivariate normal distribution either derived from observed +#' data or specified directly via `means` plus a covariance matrix (`sigma`) or +#' standard deviations (`sd`). +#' #' @param data data.frame (n x p) containing the original, observed, #' time-invariant covariates (ID should not be included) that will be used to -#' inform the imputation. +#' inform the imputation. Can be `NULL` when `means` and either `sigma` or +#' `sd` are supplied directly. Ignored with a warning when `means` is also +#' provided. +#' @param means named numeric vector of means for each covariate. When supplied, +#' the distribution is specified directly and `data` is ignored. Must be +#' supplied together with either `sigma` or `sd` when `data` is `NULL`. +#' @param sigma named numeric matrix (p x p) giving the full covariance matrix. +#' Takes precedence over `sd` when both are provided. Column and row names +#' must match the names in `means`. +#' @param sd named numeric vector of standard deviations. Used to construct a +#' diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided. +#' Names must match the names in `means`. #' @param cat_covs character vector containing the names of the categorical #' covariates in orgCovs. -#' @param n_subjects number of simulated subjects, default is the number of -#' subjects in the data. +#' @param n_subjects number of simulated subjects. Defaults to `nrow(data)` +#' when `data` is provided; required (no default) when `data` is `NULL`. #' @param exponential sample from exponential distribution? Default `FALSE`. +#' Only applies when means/covariance are derived from `data`. #' @param conditional description... #' @param seed integer random seed passed to [set.seed()] for reproducibility. #' Default `NULL` does not set a seed. @@ -20,9 +36,12 @@ #' #' @export sample_covariates_mvtnorm <- function( - data, + data = NULL, + means = NULL, + sigma = NULL, + sd = NULL, cat_covs = NULL, - n_subjects = nrow(data), + n_subjects = if (!is.null(data)) nrow(data) else stop("`n_subjects` must be specified when `data` is NULL."), exponential = FALSE, conditional = NULL, seed = NULL, @@ -30,11 +49,50 @@ sample_covariates_mvtnorm <- function( ) { if (!is.null(seed)) set.seed(seed) + ## Branch: distribution specified directly via means + sigma/sd + if (!is.null(means)) { + if (!is.null(data)) { + warning("`data` is ignored when `means` is provided.") + } + if (is.null(sigma) && is.null(sd)) { + stop("When `means` is supplied, either `sigma` or `sd` must also be provided.") + } + if (!is.null(sigma)) { + cov_mat <- sigma + } else { + if (length(sd) != length(means)) { + stop("`sd` must have the same length as `means`.") + } + cov_mat <- diag(sd^2) + if (!is.null(names(sd))) { + rownames(cov_mat) <- names(sd) + colnames(cov_mat) <- names(sd) + } else if (!is.null(names(means))) { + rownames(cov_mat) <- names(means) + colnames(cov_mat) <- names(means) + } + } + out <- mvtnorm::rmvnorm( + n_subjects, + mean = means, + sigma = cov_mat, + ... + ) |> + as.data.frame() + if (!is.null(names(means))) names(out) <- names(means) + return(out) + } + + ## Branch: derive distribution from data + if (is.null(data)) { + stop("Either `data` or `means` with `sigma`/`sd` must be provided.") + } + if(!is.null(conditional)) { for(key in names(conditional)) { data <- dplyr::filter( data, - .data[[key]] >= min(conditional[[key]]) & + .data[[key]] >= min(conditional[[key]]) & .data[[key]] <= max(conditional[[key]]) ) } @@ -44,33 +102,33 @@ sample_covariates_mvtnorm <- function( # FIXME: This code does nothing currently... is this function intended to # work with categorical covariates? or only continuous? If the latter, how # do we handle categorical? - cont_covs <- setdiff(names(data), cat_covs) + cont_covs <- setdiff(names(data), cat_covs) miss_vars <- names(data)[colSums(is.na(data)) > 0] ## Get distribution and sample if(exponential) { # FIXME: This fails if there are zeroes or negative numbers. Should add some # safety rails. - means <- apply(data, 2, function(x) mean(log(x))) - cov_mat <- stats::cov(log(data)) + data_means <- apply(data, 2, function(x) mean(log(x))) + cov_mat <- stats::cov(log(data)) out <- mvtnorm::rmvnorm( - n_subjects, - mean = means, + n_subjects, + mean = data_means, sigma = cov_mat ) |> exp() |> as.data.frame() } else { - means <- apply(data, 2, mean) + data_means <- apply(data, 2, mean) cov_mat <- stats::cov(data) out <- mvtnorm::rmvnorm( - n_subjects, - mean = means, + n_subjects, + mean = data_means, sigma = cov_mat ) |> as.data.frame() } - + if (tibble::is_tibble(data)) out <- tibble::as_tibble(out) - out + out } diff --git a/man/sample_covariates_mvtnorm.Rd b/man/sample_covariates_mvtnorm.Rd index f7e61a4..fbe0bd2 100644 --- a/man/sample_covariates_mvtnorm.Rd +++ b/man/sample_covariates_mvtnorm.Rd @@ -5,9 +5,13 @@ \title{Sample covariates from multivariate normal distributions} \usage{ sample_covariates_mvtnorm( - data, + data = NULL, + means = NULL, + sigma = NULL, + sd = NULL, cat_covs = NULL, - n_subjects = nrow(data), + n_subjects = if (!is.null(data)) nrow(data) else + stop("`n_subjects` must be specified when `data` is NULL."), exponential = FALSE, conditional = NULL, seed = NULL, @@ -17,15 +21,30 @@ sample_covariates_mvtnorm( \arguments{ \item{data}{data.frame (n x p) containing the original, observed, time-invariant covariates (ID should not be included) that will be used to -inform the imputation.} +inform the imputation. Can be \code{NULL} when \code{means} and either \code{sigma} or +\code{sd} are supplied directly. Ignored with a warning when \code{means} is also +provided.} + +\item{means}{named numeric vector of means for each covariate. When supplied, +the distribution is specified directly and \code{data} is ignored. Must be +supplied together with either \code{sigma} or \code{sd} when \code{data} is \code{NULL}.} + +\item{sigma}{named numeric matrix (p x p) giving the full covariance matrix. +Takes precedence over \code{sd} when both are provided. Column and row names +must match the names in \code{means}.} + +\item{sd}{named numeric vector of standard deviations. Used to construct a +diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided. +Names must match the names in \code{means}.} \item{cat_covs}{character vector containing the names of the categorical covariates in orgCovs.} -\item{n_subjects}{number of simulated subjects, default is the number of -subjects in the data.} +\item{n_subjects}{number of simulated subjects. Defaults to \code{nrow(data)} +when \code{data} is provided; required (no default) when \code{data} is \code{NULL}.} -\item{exponential}{sample from exponential distribution? Default \code{FALSE}.} +\item{exponential}{sample from exponential distribution? Default \code{FALSE}. +Only applies when means/covariance are derived from \code{data}.} \item{conditional}{description...} @@ -39,7 +58,9 @@ a data.frame with the simulated covariates, with \code{n_subjects} rows and \code{p} columns } \description{ -Sample covariates from multivariate normal distributions +Samples from a multivariate normal distribution either derived from observed +data or specified directly via \code{means} plus a covariance matrix (\code{sigma}) or +standard deviations (\code{sd}). } \note{ missing values in \code{data} must be coded as NA diff --git a/tests/testthat/test-sample_covariates_mvtnorm.R b/tests/testthat/test-sample_covariates_mvtnorm.R index 1f53e00..406d9ff 100644 --- a/tests/testthat/test-sample_covariates_mvtnorm.R +++ b/tests/testthat/test-sample_covariates_mvtnorm.R @@ -92,3 +92,81 @@ test_that("different seeds produce different output", { out2 <- sample_covariates_mvtnorm(dat, n_subjects = 20, seed = 2) expect_false(identical(out1, out2)) }) + +# --- Direct distribution specification (means + sigma / sds) --- + +test_that("means + sigma samples correct number of rows and columns", { + mu <- c(x = 10, y = 20) + S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y"))) + out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100) + expect_equal(nrow(out), 100) + expect_equal(ncol(out), 2) + expect_named(out, c("x", "y")) +}) + +test_that("means + sigma samples near the specified mean (large n)", { + mu <- c(AGE = 40, WT = 70) + S <- diag(c(25, 100)) + out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 5000, seed = 42) + expect_equal(mean(out$AGE), 40, tolerance = 1) + expect_equal(mean(out$WT), 70, tolerance = 2) +}) + +test_that("means + sd constructs diagonal covariance and samples correctly", { + mu <- c(x = 5, y = 50) + sds <- c(x = 1, y = 10) + out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 5000, seed = 7) + expect_named(out, c("x", "y")) + expect_equal(mean(out$x), 5, tolerance = 0.2) + expect_equal(mean(out$y), 50, tolerance = 2) + expect_equal(sd(out$x), 1, tolerance = 0.1) + expect_equal(sd(out$y), 10, tolerance = 1) +}) + +test_that("means + sd: length mismatch raises error", { + expect_error( + sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(1, 2, 3), n_subjects = 10), + "`sd` must have the same length" + ) +}) + +test_that("means without sigma or sd raises error", { + expect_error( + sample_covariates_mvtnorm(means = c(x = 1), n_subjects = 10), + "either `sigma` or `sd` must also be provided" + ) +}) + +test_that("data = NULL without means raises error", { + expect_error( + sample_covariates_mvtnorm(data = NULL, n_subjects = 10), + "Either `data` or `means`" + ) +}) + +test_that("warning is issued when both data and means are provided", { + dat <- data.frame(x = rnorm(50), y = rnorm(50)) + mu <- c(x = 0, y = 0) + S <- diag(2) + expect_warning( + sample_covariates_mvtnorm(data = dat, means = mu, sigma = S, n_subjects = 10), + "`data` is ignored" + ) +}) + +test_that("n_subjects is required when data is NULL", { + mu <- c(x = 0) + S <- matrix(1) + expect_error( + sample_covariates_mvtnorm(means = mu, sigma = S), + "n_subjects.*must be specified" + ) +}) + +test_that("seed produces reproducible output with means + sigma", { + mu <- c(x = 0, y = 1) + S <- diag(2) + out1 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5) + out2 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5) + expect_equal(out1, out2) +}) From 356603e3fafdd918914e0177d071f3a167366ab1 Mon Sep 17 00:00:00 2001 From: Ron Keizer Date: Mon, 16 Mar 2026 21:07:59 -0700 Subject: [PATCH 2/6] add example to covariates sampler --- R/sample_covariates_mvtnorm.R | 15 +++++++++++++++ man/sample_covariates_mvtnorm.Rd | 16 ++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R index 66f560b..56c55eb 100644 --- a/R/sample_covariates_mvtnorm.R +++ b/R/sample_covariates_mvtnorm.R @@ -34,6 +34,21 @@ #' #' @note missing values in `data` must be coded as NA #' +#' @examples +#' sample_covariates_mvtnorm( +#' means = c(WT = 70, HT = 170, AGE = 50), +#' sd = c(10, 20, 5), +#' n_subjects = 100 +#' ) +#' sample_covariates_mvtnorm( +#' means = c(WT = 70, HT = 170, AGE = 50), +#' sigma = matrix( +#' c(100, 20, 12, +#' 20, 400, 10, +#' 12, 10, 25), nrow = 3, ncol = 3), +#' n_subjects = 100 +#' ) +#' #' @export sample_covariates_mvtnorm <- function( data = NULL, diff --git a/man/sample_covariates_mvtnorm.Rd b/man/sample_covariates_mvtnorm.Rd index fbe0bd2..da3f273 100644 --- a/man/sample_covariates_mvtnorm.Rd +++ b/man/sample_covariates_mvtnorm.Rd @@ -65,3 +65,19 @@ standard deviations (\code{sd}). \note{ missing values in \code{data} must be coded as NA } +\examples{ +sample_covariates_mvtnorm( + means = c(WT = 70, HT = 170, AGE = 50), + sd = c(10, 20, 5), + n_subjects = 100 +) +sample_covariates_mvtnorm( + means = c(WT = 70, HT = 170, AGE = 50), + sigma = matrix( + c(100, 20, 12, + 20, 400, 10, + 12, 10, 25), nrow = 3, ncol = 3), + n_subjects = 100 +) + +} From beec91b198e17bed9114b63349d9c1320c579a70 Mon Sep 17 00:00:00 2001 From: roninsightrx Date: Mon, 16 Mar 2026 21:10:26 -0700 Subject: [PATCH 3/6] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- R/sample_covariates_mvtnorm.R | 40 +++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R index 56c55eb..2a13789 100644 --- a/R/sample_covariates_mvtnorm.R +++ b/R/sample_covariates_mvtnorm.R @@ -74,17 +74,49 @@ sample_covariates_mvtnorm <- function( } if (!is.null(sigma)) { cov_mat <- sigma + # Validate dimensions of `sigma` against `means` + if (!is.matrix(cov_mat)) { + stop("`sigma` must be a covariance matrix when `means` is supplied.") + } + if (nrow(cov_mat) != length(means) || ncol(cov_mat) != length(means)) { + stop("Dimensions of `sigma` must match the length of `means`.") + } + mean_names <- names(means) + row_nms <- rownames(cov_mat) + col_nms <- colnames(cov_mat) + if (!is.null(mean_names)) { + if (is.null(row_nms) || is.null(col_nms)) { + stop("When `means` is named, `sigma` must have row and column names.") + } + if (!setequal(row_nms, mean_names) || !setequal(col_nms, mean_names)) { + stop("Names of `sigma` must match `names(means)` (same set of names).") + } + # Reorder covariance matrix to match the order of `names(means)` + cov_mat <- cov_mat[mean_names, mean_names, drop = FALSE] + } } else { if (length(sd) != length(means)) { stop("`sd` must have the same length as `means`.") } + mean_names <- names(means) + sd_names <- names(sd) + # If both `sd` and `means` are named, validate and reorder `sd` + if (!is.null(sd_names) && !is.null(mean_names)) { + if (!setequal(sd_names, mean_names)) { + stop("Names of `sd` must match `names(means)` (same set of names).") + } + sd <- sd[mean_names] + } else if (!is.null(sd_names) && is.null(mean_names)) { + # Drop names to avoid relying on an arbitrary ordering when `means` is unnamed + sd <- unname(sd) + } cov_mat <- diag(sd^2) - if (!is.null(names(sd))) { + if (!is.null(mean_names)) { + rownames(cov_mat) <- mean_names + colnames(cov_mat) <- mean_names + } else if (!is.null(names(sd))) { rownames(cov_mat) <- names(sd) colnames(cov_mat) <- names(sd) - } else if (!is.null(names(means))) { - rownames(cov_mat) <- names(means) - colnames(cov_mat) <- names(means) } } out <- mvtnorm::rmvnorm( From 1a726e8c408cf9d30245e6853b698321e673e430 Mon Sep 17 00:00:00 2001 From: roninsightrx Date: Mon, 16 Mar 2026 21:11:10 -0700 Subject: [PATCH 4/6] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- tests/testthat/test-sample_covariates_mvtnorm.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/testthat/test-sample_covariates_mvtnorm.R b/tests/testthat/test-sample_covariates_mvtnorm.R index 406d9ff..9f6a7f1 100644 --- a/tests/testthat/test-sample_covariates_mvtnorm.R +++ b/tests/testthat/test-sample_covariates_mvtnorm.R @@ -107,20 +107,20 @@ test_that("means + sigma samples correct number of rows and columns", { test_that("means + sigma samples near the specified mean (large n)", { mu <- c(AGE = 40, WT = 70) S <- diag(c(25, 100)) - out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 5000, seed = 42) - expect_equal(mean(out$AGE), 40, tolerance = 1) - expect_equal(mean(out$WT), 70, tolerance = 2) + out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 2000, seed = 42) + expect_equal(mean(out$AGE), 40, tolerance = 1.5) + expect_equal(mean(out$WT), 70, tolerance = 3) }) test_that("means + sd constructs diagonal covariance and samples correctly", { mu <- c(x = 5, y = 50) sds <- c(x = 1, y = 10) - out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 5000, seed = 7) + out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7) expect_named(out, c("x", "y")) - expect_equal(mean(out$x), 5, tolerance = 0.2) - expect_equal(mean(out$y), 50, tolerance = 2) - expect_equal(sd(out$x), 1, tolerance = 0.1) - expect_equal(sd(out$y), 10, tolerance = 1) + expect_equal(mean(out$x), 5, tolerance = 0.25) + expect_equal(mean(out$y), 50, tolerance = 2.5) + expect_equal(sd(out$x), 1, tolerance = 0.12) + expect_equal(sd(out$y), 10, tolerance = 1.2) }) test_that("means + sd: length mismatch raises error", { From bb8cf68ae9529e1e459940e98de55e6c13ee61c1 Mon Sep 17 00:00:00 2001 From: Ron Keizer Date: Mon, 16 Mar 2026 21:17:37 -0700 Subject: [PATCH 5/6] revert: remove name validation requirements for sigma and sd Co-Authored-By: Claude Sonnet 4.6 --- R/sample_covariates_mvtnorm.R | 48 +++++--------------------------- man/sample_covariates_mvtnorm.Rd | 10 +++---- 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R index 2a13789..4d445a6 100644 --- a/R/sample_covariates_mvtnorm.R +++ b/R/sample_covariates_mvtnorm.R @@ -12,12 +12,10 @@ #' @param means named numeric vector of means for each covariate. When supplied, #' the distribution is specified directly and `data` is ignored. Must be #' supplied together with either `sigma` or `sd` when `data` is `NULL`. -#' @param sigma named numeric matrix (p x p) giving the full covariance matrix. -#' Takes precedence over `sd` when both are provided. Column and row names -#' must match the names in `means`. -#' @param sd named numeric vector of standard deviations. Used to construct a +#' @param sigma numeric matrix (p x p) giving the full covariance matrix. +#' Takes precedence over `sd` when both are provided. +#' @param sd numeric vector of standard deviations. Used to construct a #' diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided. -#' Names must match the names in `means`. #' @param cat_covs character vector containing the names of the categorical #' covariates in orgCovs. #' @param n_subjects number of simulated subjects. Defaults to `nrow(data)` @@ -74,49 +72,17 @@ sample_covariates_mvtnorm <- function( } if (!is.null(sigma)) { cov_mat <- sigma - # Validate dimensions of `sigma` against `means` - if (!is.matrix(cov_mat)) { - stop("`sigma` must be a covariance matrix when `means` is supplied.") - } - if (nrow(cov_mat) != length(means) || ncol(cov_mat) != length(means)) { - stop("Dimensions of `sigma` must match the length of `means`.") - } - mean_names <- names(means) - row_nms <- rownames(cov_mat) - col_nms <- colnames(cov_mat) - if (!is.null(mean_names)) { - if (is.null(row_nms) || is.null(col_nms)) { - stop("When `means` is named, `sigma` must have row and column names.") - } - if (!setequal(row_nms, mean_names) || !setequal(col_nms, mean_names)) { - stop("Names of `sigma` must match `names(means)` (same set of names).") - } - # Reorder covariance matrix to match the order of `names(means)` - cov_mat <- cov_mat[mean_names, mean_names, drop = FALSE] - } } else { if (length(sd) != length(means)) { stop("`sd` must have the same length as `means`.") } - mean_names <- names(means) - sd_names <- names(sd) - # If both `sd` and `means` are named, validate and reorder `sd` - if (!is.null(sd_names) && !is.null(mean_names)) { - if (!setequal(sd_names, mean_names)) { - stop("Names of `sd` must match `names(means)` (same set of names).") - } - sd <- sd[mean_names] - } else if (!is.null(sd_names) && is.null(mean_names)) { - # Drop names to avoid relying on an arbitrary ordering when `means` is unnamed - sd <- unname(sd) - } cov_mat <- diag(sd^2) - if (!is.null(mean_names)) { - rownames(cov_mat) <- mean_names - colnames(cov_mat) <- mean_names - } else if (!is.null(names(sd))) { + if (!is.null(names(sd))) { rownames(cov_mat) <- names(sd) colnames(cov_mat) <- names(sd) + } else if (!is.null(names(means))) { + rownames(cov_mat) <- names(means) + colnames(cov_mat) <- names(means) } } out <- mvtnorm::rmvnorm( diff --git a/man/sample_covariates_mvtnorm.Rd b/man/sample_covariates_mvtnorm.Rd index da3f273..a06e06d 100644 --- a/man/sample_covariates_mvtnorm.Rd +++ b/man/sample_covariates_mvtnorm.Rd @@ -29,13 +29,11 @@ provided.} the distribution is specified directly and \code{data} is ignored. Must be supplied together with either \code{sigma} or \code{sd} when \code{data} is \code{NULL}.} -\item{sigma}{named numeric matrix (p x p) giving the full covariance matrix. -Takes precedence over \code{sd} when both are provided. Column and row names -must match the names in \code{means}.} +\item{sigma}{numeric matrix (p x p) giving the full covariance matrix. +Takes precedence over \code{sd} when both are provided.} -\item{sd}{named numeric vector of standard deviations. Used to construct a -diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided. -Names must match the names in \code{means}.} +\item{sd}{numeric vector of standard deviations. Used to construct a +diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided.} \item{cat_covs}{character vector containing the names of the categorical covariates in orgCovs.} From 1c506459006b9fd740babc177e6d324bc493c809 Mon Sep 17 00:00:00 2001 From: Ron Keizer Date: Mon, 16 Mar 2026 21:19:21 -0700 Subject: [PATCH 6/6] feat: optional name-based reordering for sigma and sd When both `means` and `sigma`/`sd` are named, reorder to match the order of `means` and validate that the name sets match. Names remain optional: unnamed inputs are assumed to be in the correct order. Co-Authored-By: Claude Sonnet 4.6 --- R/sample_covariates_mvtnorm.R | 26 ++++++++-- man/sample_covariates_mvtnorm.Rd | 8 +++- .../testthat/test-sample_covariates_mvtnorm.R | 47 +++++++++++++++++++ 3 files changed, 74 insertions(+), 7 deletions(-) diff --git a/R/sample_covariates_mvtnorm.R b/R/sample_covariates_mvtnorm.R index 4d445a6..67aa689 100644 --- a/R/sample_covariates_mvtnorm.R +++ b/R/sample_covariates_mvtnorm.R @@ -13,9 +13,13 @@ #' the distribution is specified directly and `data` is ignored. Must be #' supplied together with either `sigma` or `sd` when `data` is `NULL`. #' @param sigma numeric matrix (p x p) giving the full covariance matrix. -#' Takes precedence over `sd` when both are provided. +#' Takes precedence over `sd` when both are provided. If both `sigma` and +#' `means` are named, the matrix is reordered to match the order of `means`; +#' names must refer to the same set of variables. #' @param sd numeric vector of standard deviations. Used to construct a #' diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided. +#' If both `sd` and `means` are named, `sd` is reordered to match the order +#' of `means`; names must refer to the same set of variables. #' @param cat_covs character vector containing the names of the categorical #' covariates in orgCovs. #' @param n_subjects number of simulated subjects. Defaults to `nrow(data)` @@ -72,17 +76,29 @@ sample_covariates_mvtnorm <- function( } if (!is.null(sigma)) { cov_mat <- sigma + if (!is.null(names(means)) && !is.null(rownames(cov_mat))) { + if (!setequal(rownames(cov_mat), names(means))) { + stop("Names of `sigma` must match `names(means)` (same set of names).") + } + cov_mat <- cov_mat[names(means), names(means), drop = FALSE] + } } else { if (length(sd) != length(means)) { stop("`sd` must have the same length as `means`.") } + if (!is.null(names(means)) && !is.null(names(sd))) { + if (!setequal(names(sd), names(means))) { + stop("Names of `sd` must match `names(means)` (same set of names).") + } + sd <- sd[names(means)] + } cov_mat <- diag(sd^2) - if (!is.null(names(sd))) { - rownames(cov_mat) <- names(sd) - colnames(cov_mat) <- names(sd) - } else if (!is.null(names(means))) { + if (!is.null(names(means))) { rownames(cov_mat) <- names(means) colnames(cov_mat) <- names(means) + } else if (!is.null(names(sd))) { + rownames(cov_mat) <- names(sd) + colnames(cov_mat) <- names(sd) } } out <- mvtnorm::rmvnorm( diff --git a/man/sample_covariates_mvtnorm.Rd b/man/sample_covariates_mvtnorm.Rd index a06e06d..201b817 100644 --- a/man/sample_covariates_mvtnorm.Rd +++ b/man/sample_covariates_mvtnorm.Rd @@ -30,10 +30,14 @@ the distribution is specified directly and \code{data} is ignored. Must be supplied together with either \code{sigma} or \code{sd} when \code{data} is \code{NULL}.} \item{sigma}{numeric matrix (p x p) giving the full covariance matrix. -Takes precedence over \code{sd} when both are provided.} +Takes precedence over \code{sd} when both are provided. If both \code{sigma} and +\code{means} are named, the matrix is reordered to match the order of \code{means}; +names must refer to the same set of variables.} \item{sd}{numeric vector of standard deviations. Used to construct a -diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided.} +diagonal covariance matrix (\code{diag(sd^2)}) when \code{sigma} is not provided. +If both \code{sd} and \code{means} are named, \code{sd} is reordered to match the order +of \code{means}; names must refer to the same set of variables.} \item{cat_covs}{character vector containing the names of the categorical covariates in orgCovs.} diff --git a/tests/testthat/test-sample_covariates_mvtnorm.R b/tests/testthat/test-sample_covariates_mvtnorm.R index 9f6a7f1..3ba6ffd 100644 --- a/tests/testthat/test-sample_covariates_mvtnorm.R +++ b/tests/testthat/test-sample_covariates_mvtnorm.R @@ -130,6 +130,53 @@ test_that("means + sd: length mismatch raises error", { ) }) +test_that("means + sd: reorders sd to match means when both are named", { + mu <- c(y = 50, x = 5) # y first + sds <- c(x = 1, y = 10) # x first — should be reordered + out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7) + expect_named(out, c("y", "x")) + expect_equal(sd(out$y), 10, tolerance = 1.2) + expect_equal(sd(out$x), 1, tolerance = 0.12) +}) + +test_that("means + sd: name mismatch raises error", { + expect_error( + sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(a = 1, c = 2), n_subjects = 10), + "Names of `sd` must match" + ) +}) + +test_that("means + sd: unnamed sd is assumed to be in the right order", { + mu <- c(x = 5, y = 50) + out <- sample_covariates_mvtnorm(means = mu, sd = c(1, 10), n_subjects = 2000, seed = 7) + expect_named(out, c("x", "y")) + expect_equal(sd(out$x), 1, tolerance = 0.12) + expect_equal(sd(out$y), 10, tolerance = 1.2) +}) + +test_that("means + sigma: reorders sigma to match means when both are named", { + mu <- c(y = 20, x = 10) # y first + S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y"))) + out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1) + expect_named(out, c("y", "x")) +}) + +test_that("means + sigma: name mismatch raises error", { + mu <- c(a = 1, b = 2) + S <- matrix(c(1, 0, 0, 1), nrow = 2, dimnames = list(c("a","c"), c("a","c"))) + expect_error( + sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 10), + "Names of `sigma` must match" + ) +}) + +test_that("means + sigma: unnamed sigma is assumed to be in the right order", { + mu <- c(x = 10, y = 20) + S <- diag(2) # no dimnames + out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1) + expect_named(out, c("x", "y")) +}) + test_that("means without sigma or sd raises error", { expect_error( sample_covariates_mvtnorm(means = c(x = 1), n_subjects = 10),