Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: irxforge
Title: Forging data for pharmacometric analyses
Version: 0.0.0.9000
Version: 0.0.0.9001
Authors@R: c(
person("Ron", "Keizer", email = "ron@insight-rx.com", role = c("cre", "aut")),
person("Michael", "McCarthy", email = "michael.mccarthy@insight-rx.com", role = "ctb"),
Expand Down
119 changes: 103 additions & 16 deletions R/sample_covariates_mvtnorm.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
#' Sample covariates from multivariate normal distributions
#'
#' Samples from a multivariate normal distribution either derived from observed
#' data or specified directly via `means` plus a covariance matrix (`sigma`) or
#' standard deviations (`sd`).
#'
#' @param data data.frame (n x p) containing the original, observed,
#' time-invariant covariates (ID should not be included) that will be used to
#' inform the imputation.
#' inform the imputation. Can be `NULL` when `means` and either `sigma` or
#' `sd` are supplied directly. Ignored with a warning when `means` is also
#' provided.
#' @param means named numeric vector of means for each covariate. When supplied,
#' the distribution is specified directly and `data` is ignored. Must be
#' supplied together with either `sigma` or `sd` when `data` is `NULL`.
#' @param sigma numeric matrix (p x p) giving the full covariance matrix.
#' Takes precedence over `sd` when both are provided. If both `sigma` and
#' `means` are named, the matrix is reordered to match the order of `means`;
#' names must refer to the same set of variables.
#' @param sd numeric vector of standard deviations. Used to construct a
#' diagonal covariance matrix (`diag(sd^2)`) when `sigma` is not provided.
#' If both `sd` and `means` are named, `sd` is reordered to match the order
#' of `means`; names must refer to the same set of variables.
#' @param cat_covs character vector containing the names of the categorical
#' covariates in orgCovs.
#' @param n_subjects number of simulated subjects, default is the number of
#' subjects in the data.
#' @param n_subjects number of simulated subjects. Defaults to `nrow(data)`
#' when `data` is provided; required (no default) when `data` is `NULL`.
#' @param exponential sample from exponential distribution? Default `FALSE`.
#' Only applies when means/covariance are derived from `data`.
#' @param conditional description...
#' @param seed integer random seed passed to [set.seed()] for reproducibility.
#' Default `NULL` does not set a seed.
Expand All @@ -18,23 +36,92 @@
#'
#' @note missing values in `data` must be coded as NA
#'
#' @examples
#' sample_covariates_mvtnorm(
#' means = c(WT = 70, HT = 170, AGE = 50),
#' sd = c(10, 20, 5),
#' n_subjects = 100
#' )
#' sample_covariates_mvtnorm(
#' means = c(WT = 70, HT = 170, AGE = 50),
#' sigma = matrix(
#' c(100, 20, 12,
#' 20, 400, 10,
#' 12, 10, 25), nrow = 3, ncol = 3),
#' n_subjects = 100
#' )
#'
#' @export
sample_covariates_mvtnorm <- function(
data,
data = NULL,
means = NULL,
sigma = NULL,
sd = NULL,
cat_covs = NULL,
n_subjects = nrow(data),
n_subjects = if (!is.null(data)) nrow(data) else stop("`n_subjects` must be specified when `data` is NULL."),
exponential = FALSE,
conditional = NULL,
seed = NULL,
...
) {
if (!is.null(seed)) set.seed(seed)

## Branch: distribution specified directly via means + sigma/sd
if (!is.null(means)) {
if (!is.null(data)) {
warning("`data` is ignored when `means` is provided.")
}
if (is.null(sigma) && is.null(sd)) {
stop("When `means` is supplied, either `sigma` or `sd` must also be provided.")
}
if (!is.null(sigma)) {
cov_mat <- sigma
if (!is.null(names(means)) && !is.null(rownames(cov_mat))) {
if (!setequal(rownames(cov_mat), names(means))) {
stop("Names of `sigma` must match `names(means)` (same set of names).")
}
cov_mat <- cov_mat[names(means), names(means), drop = FALSE]
}
} else {
if (length(sd) != length(means)) {
stop("`sd` must have the same length as `means`.")
}
if (!is.null(names(means)) && !is.null(names(sd))) {
if (!setequal(names(sd), names(means))) {
stop("Names of `sd` must match `names(means)` (same set of names).")
}
sd <- sd[names(means)]
}
cov_mat <- diag(sd^2)
if (!is.null(names(means))) {
rownames(cov_mat) <- names(means)
colnames(cov_mat) <- names(means)
} else if (!is.null(names(sd))) {
rownames(cov_mat) <- names(sd)
colnames(cov_mat) <- names(sd)
}
}
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
sigma = cov_mat,
...
) |>
as.data.frame()
if (!is.null(names(means))) names(out) <- names(means)
return(out)
}

## Branch: derive distribution from data
if (is.null(data)) {
stop("Either `data` or `means` with `sigma`/`sd` must be provided.")
}

if(!is.null(conditional)) {
for(key in names(conditional)) {
data <- dplyr::filter(
data,
.data[[key]] >= min(conditional[[key]]) &
.data[[key]] >= min(conditional[[key]]) &
.data[[key]] <= max(conditional[[key]])
)
}
Expand All @@ -44,33 +131,33 @@ sample_covariates_mvtnorm <- function(
# FIXME: This code does nothing currently... is this function intended to
# work with categorical covariates? or only continuous? If the latter, how
# do we handle categorical?
cont_covs <- setdiff(names(data), cat_covs)
cont_covs <- setdiff(names(data), cat_covs)
miss_vars <- names(data)[colSums(is.na(data)) > 0]

## Get distribution and sample
if(exponential) {
# FIXME: This fails if there are zeroes or negative numbers. Should add some
# safety rails.
means <- apply(data, 2, function(x) mean(log(x)))
cov_mat <- stats::cov(log(data))
data_means <- apply(data, 2, function(x) mean(log(x)))
cov_mat <- stats::cov(log(data))
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
n_subjects,
mean = data_means,
sigma = cov_mat
) |>
exp() |>
as.data.frame()
} else {
means <- apply(data, 2, mean)
data_means <- apply(data, 2, mean)
cov_mat <- stats::cov(data)
out <- mvtnorm::rmvnorm(
n_subjects,
mean = means,
n_subjects,
mean = data_means,
sigma = cov_mat
) |>
as.data.frame()
}

if (tibble::is_tibble(data)) out <- tibble::as_tibble(out)
out
out
}
53 changes: 46 additions & 7 deletions man/sample_covariates_mvtnorm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

125 changes: 125 additions & 0 deletions tests/testthat/test-sample_covariates_mvtnorm.R
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,128 @@ test_that("different seeds produce different output", {
out2 <- sample_covariates_mvtnorm(dat, n_subjects = 20, seed = 2)
expect_false(identical(out1, out2))
})

# --- Direct distribution specification (means + sigma / sds) ---

test_that("means + sigma samples correct number of rows and columns", {
mu <- c(x = 10, y = 20)
S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y")))
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100)
expect_equal(nrow(out), 100)
expect_equal(ncol(out), 2)
expect_named(out, c("x", "y"))
})

test_that("means + sigma samples near the specified mean (large n)", {
mu <- c(AGE = 40, WT = 70)
S <- diag(c(25, 100))
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 2000, seed = 42)
expect_equal(mean(out$AGE), 40, tolerance = 1.5)
expect_equal(mean(out$WT), 70, tolerance = 3)
})

test_that("means + sd constructs diagonal covariance and samples correctly", {
mu <- c(x = 5, y = 50)
sds <- c(x = 1, y = 10)
out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7)
expect_named(out, c("x", "y"))
expect_equal(mean(out$x), 5, tolerance = 0.25)
expect_equal(mean(out$y), 50, tolerance = 2.5)
expect_equal(sd(out$x), 1, tolerance = 0.12)
expect_equal(sd(out$y), 10, tolerance = 1.2)
})

test_that("means + sd: length mismatch raises error", {
expect_error(
sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(1, 2, 3), n_subjects = 10),
"`sd` must have the same length"
)
})

test_that("means + sd: reorders sd to match means when both are named", {
mu <- c(y = 50, x = 5) # y first
sds <- c(x = 1, y = 10) # x first — should be reordered
out <- sample_covariates_mvtnorm(means = mu, sd = sds, n_subjects = 2000, seed = 7)
expect_named(out, c("y", "x"))
expect_equal(sd(out$y), 10, tolerance = 1.2)
expect_equal(sd(out$x), 1, tolerance = 0.12)
})

test_that("means + sd: name mismatch raises error", {
expect_error(
sample_covariates_mvtnorm(means = c(a = 1, b = 2), sd = c(a = 1, c = 2), n_subjects = 10),
"Names of `sd` must match"
)
})

test_that("means + sd: unnamed sd is assumed to be in the right order", {
mu <- c(x = 5, y = 50)
out <- sample_covariates_mvtnorm(means = mu, sd = c(1, 10), n_subjects = 2000, seed = 7)
expect_named(out, c("x", "y"))
expect_equal(sd(out$x), 1, tolerance = 0.12)
expect_equal(sd(out$y), 10, tolerance = 1.2)
})

test_that("means + sigma: reorders sigma to match means when both are named", {
mu <- c(y = 20, x = 10) # y first
S <- matrix(c(4, 1, 1, 9), nrow = 2, dimnames = list(c("x","y"), c("x","y")))
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1)
expect_named(out, c("y", "x"))
})

test_that("means + sigma: name mismatch raises error", {
mu <- c(a = 1, b = 2)
S <- matrix(c(1, 0, 0, 1), nrow = 2, dimnames = list(c("a","c"), c("a","c")))
expect_error(
sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 10),
"Names of `sigma` must match"
)
})

test_that("means + sigma: unnamed sigma is assumed to be in the right order", {
mu <- c(x = 10, y = 20)
S <- diag(2) # no dimnames
out <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 100, seed = 1)
expect_named(out, c("x", "y"))
})

test_that("means without sigma or sd raises error", {
expect_error(
sample_covariates_mvtnorm(means = c(x = 1), n_subjects = 10),
"either `sigma` or `sd` must also be provided"
)
})

test_that("data = NULL without means raises error", {
expect_error(
sample_covariates_mvtnorm(data = NULL, n_subjects = 10),
"Either `data` or `means`"
)
})

test_that("warning is issued when both data and means are provided", {
dat <- data.frame(x = rnorm(50), y = rnorm(50))
mu <- c(x = 0, y = 0)
S <- diag(2)
expect_warning(
sample_covariates_mvtnorm(data = dat, means = mu, sigma = S, n_subjects = 10),
"`data` is ignored"
)
})

test_that("n_subjects is required when data is NULL", {
mu <- c(x = 0)
S <- matrix(1)
expect_error(
sample_covariates_mvtnorm(means = mu, sigma = S),
"n_subjects.*must be specified"
)
})

test_that("seed produces reproducible output with means + sigma", {
mu <- c(x = 0, y = 1)
S <- diag(2)
out1 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
out2 <- sample_covariates_mvtnorm(means = mu, sigma = S, n_subjects = 30, seed = 5)
expect_equal(out1, out2)
})
Loading