Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 39 additions & 39 deletions R/crunch.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,24 @@
#' @param to_combine Levels to combine.
#' @param other_level Name of the new level, e.g., "Other 3" if three levels are combined.
#' @returns A factor with combined levels.
combine_levels <- function(f, to_combine, other_level = "Other") {
if (length(to_combine) <= 2L) {
return(f)
}
old_levels <- lvl <- levels(f)
to_keep <- setdiff(lvl, to_combine)
if (other_level %in% to_keep) {
stop("The 'other_level' level is already present in 'f'")
}
new_levels <- c(to_keep, other_level)
old_levels[!(lvl %in% to_keep)] <- other_level

# like in forcats:::lvls_revalue()
out <- match(old_levels, new_levels)[f]
attributes(out) <- attributes(f)
attr(out, "levels") <- new_levels
return(out)
}
# combine_levels <- function(f, to_combine, other_level = "Other") {
# if (length(to_combine) <= 2L) {
# return(f)
# }
# old_levels <- lvl <- levels(f)
# to_keep <- setdiff(lvl, to_combine)
# if (other_level %in% to_keep) {
# stop("The 'other_level' level is already present in 'f'")
# }
# new_levels <- c(to_keep, other_level)
# old_levels[!(lvl %in% to_keep)] <- other_level
#
# # like in forcats:::lvls_revalue()
# out <- match(old_levels, new_levels)[f]
# attributes(out) <- attributes(f)
# attr(out, "levels") <- new_levels
# return(out)
# }

#' Lump rare factor levels (currently unused)
#'
Expand All @@ -42,27 +42,27 @@ combine_levels <- function(f, to_combine, other_level = "Other") {
#' A list with three elements: "f" is a factor with combined levels,
#' "combined" is a character vector with the combined levels, and "other_level"
#' is the name of the new level.
flump <- function(f, combine_m, w = NULL) {
if (is.null(w)) {
N <- collapse::fnobs(f, g = f)
} else {
N <- collapse::fsum(w, g = f, fill = TRUE)
}
to_combine <- levels(f)[order(N)][seq_len(combine_m)]
to_combine <- setdiff(to_combine, NA) # don't collapse explicit NA levels
m_other <- length(to_combine)
if (m_other <= 2L) {
return(list(f = f, combined = NULL, other_level = NULL))
}
other_level <- paste("Other", m_other)

out <- list(
f = combine_levels(f, to_combine = to_combine, other_level = other_level),
combined = to_combine,
other_level = other_level
)
return(out)
}
# flump <- function(f, combine_m, w = NULL) {
# if (is.null(w)) {
# N <- collapse::fnobs(f, g = f)
# } else {
# N <- collapse::fsum(w, g = f, fill = TRUE)
# }
# to_combine <- levels(f)[order(N)][seq_len(combine_m)]
# to_combine <- setdiff(to_combine, NA) # don't collapse explicit NA levels
# m_other <- length(to_combine)
# if (m_other <= 2L) {
# return(list(f = f, combined = NULL, other_level = NULL))
# }
# other_level <- paste("Other", m_other)
#
# out <- list(
# f = combine_levels(f, to_combine = to_combine, other_level = other_level),
# combined = to_combine,
# other_level = other_level
# )
# return(out)
# }


#' Prepares discrete feature for grouped operations of {collapse}
Expand Down
33 changes: 29 additions & 4 deletions tests/testthat/test-plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,45 @@ xvars <- colnames(iris)[-1]
M <- feature_effects(fit, v = xvars, data = iris, y = "Sepal.Length", breaks = 5)

test_that("plot() returns correct class", {
expect_s3_class(plot(M), "patchwork")
expect_s3_class(plot(M[1L]), "ggplot")
expect_s3_class(plot(M, rotate_x = 45, title = "multiple plots"), "patchwork")
expect_s3_class(plot(M, stats = "resid_mean", interval = "ci"), "patchwork")

expect_s3_class(plot(M[1L], rotate_x = 45), "ggplot")
expect_s3_class(plot(M[1L], stats = "resid_mean", interval = "ci"), "ggplot")

# Plotly
p <- plot(M, plotly = TRUE)
p <- plot(M, plotly = TRUE, title = "multiple plots")
expect_s3_class(p, "plotly")
expect_true("subplot" %in% names(p$x))

p <- plot(M[1L], plotly = TRUE)
expect_s3_class(p, "plotly")
expect_false("subplot" %in% names(p$x))

p <- plot(M, stats = "resid_mean", interval = "ci", plotly = TRUE)
expect_s3_class(p, "plotly")
expect_true("subplot" %in% names(p$x))

p <- plot(M[1L], stats = "resid_mean", interval = "ci", plotly = TRUE)
expect_s3_class(p, "plotly")
expect_false("subplot" %in% names(p$x))
})

test_that("plot() returns correct class with single ALE line", {
expect_s3_class(plot(M[1:2], stats = "ale"), "patchwork")
expect_s3_class(plot(M[1L], stats = "ale"), "ggplot")

# Plotly
p <- plot(M[1:2], plotly = TRUE, stats = "ale")
expect_s3_class(p, "plotly")
expect_true("subplot" %in% names(p$x))

p <- plot(M[1L], plotly = TRUE, stats = "ale")
expect_s3_class(p, "plotly")
expect_false("subplot" %in% names(p$x))
})


test_that("ncols has an effect", {
# How to do with patchwork??

Expand Down Expand Up @@ -59,4 +85,3 @@ test_that("y axis can be shared", {
expect_null(p1$x$layout$yaxis$range)
expect_equal(p2$x$layout$yaxis$range, p2$x$layout$yaxis3$range)
})