diff --git a/DESCRIPTION b/DESCRIPTION index a95507b..26c5fe8 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: kernelshap Title: Kernel SHAP -Version: 0.9.0 +Version: 0.9.1 Authors@R: c( person("Michael", "Mayer", , "mayermichael79@gmail.com", role = c("aut", "cre"), comment = c(ORCID = "0009-0007-2540-9629")), diff --git a/NEWS.md b/NEWS.md index d6c3903..8209be7 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,9 @@ +# kernelshap 0.9.1 + +### Speed and memory improvements + +- More pre-calculations for exact part of the methods ([#175](https://github.com/ModelOriented/kernelshap/pull/175)). + # kernelshap 0.9.0 ### Bug fix diff --git a/R/kernelshap.R b/R/kernelshap.R index 4ddc690..785ef77 100644 --- a/R/kernelshap.R +++ b/R/kernelshap.R @@ -242,16 +242,19 @@ kernelshap.default <- function( p = p, deg = hybrid_degree, feature_names = feature_names ) } - m_exact <- nrow(precalc[["Z"]]) + Z <- precalc[["Z"]] + m_exact <- nrow(Z) prop_exact <- sum(precalc[["w"]]) - precalc[["bg_X_exact"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)) + precalc[["bg_exact_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)) + g <- rep_each(m_exact, each = bg_n) + precalc[["Z_exact_rep"]] <- Z[g, , drop = FALSE] } else { precalc <- list() m_exact <- 0L prop_exact <- 0 } if (!exact) { - precalc[["bg_X_m"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m)) + precalc[["bg_sampling_rep"]] <- rep_rows(bg_X, rep.int(seq_len(bg_n), m)) } if (max(m, m_exact) * bg_n > 2e5) { @@ -276,6 +279,7 @@ kernelshap.default <- function( max_iter = max_iter, v0 = v0, precalc = precalc, + bg_n = bg_n, ... ) } else { @@ -298,6 +302,7 @@ kernelshap.default <- function( max_iter = max_iter, v0 = v0, precalc = precalc, + bg_n = bg_n, ... ) if (verbose && n >= 2L) { diff --git a/R/permshap.R b/R/permshap.R index 398bcf8..d25782f 100644 --- a/R/permshap.R +++ b/R/permshap.R @@ -143,23 +143,27 @@ permshap.default <- function( # Pre-calculations that are identical for each row to be explained if (exact) { Z <- exact_Z(p, feature_names = feature_names) - m_exact <- nrow(Z) - 2L # We won't evaluate vz for first and last row + Z_no_extremes <- Z[2L:(nrow(Z) - 1L), , drop = FALSE] + m_exact <- nrow(Z_no_extremes) # 2^p - 2 m_eval <- 0L # for consistency with sampling case + g <- rep_each(m_exact, each = bg_n) precalc <- list( - Z = Z, - bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)), + Z_exact_rep = Z_no_extremes[g, , drop = FALSE], + bg_exact_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)), positions = positions_for_exact(Z), shapley_w = shapley_weights(p, ell = rowSums(Z) - 1) # how many other players? ) } else { max_iter <- as.integer(ceiling(max_iter / p) * p) # should be multiple of p - m_exact <- 2L * p - m <- 2L * (p - 3L) # inner loop + Z <- exact_Z_balanced(p, feature_names) + m_exact <- nrow(Z) # 2L * p + m <- 2L * (p - 3L) # for inner loop m_eval <- if (low_memory) m else m * p # outer loop + g <- rep_each(m_exact, each = bg_n) precalc <- list( - bg_X_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval)), - bg_X_balanced = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)), - Z_balanced = exact_Z_balanced(p, feature_names) + Z_balanced_rep = Z[g, , drop = FALSE], + bg_balanced_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_exact)), + bg_sampling_rep = rep_rows(bg_X, rep.int(seq_len(bg_n), m_eval)) ) } @@ -184,6 +188,7 @@ permshap.default <- function( low_memory = low_memory, tol = tol, max_iter = max_iter, + bg_n = bg_n, ... ) } else { @@ -205,6 +210,7 @@ permshap.default <- function( low_memory = low_memory, tol = tol, max_iter = max_iter, + bg_n = bg_n, ... ) if (verbose && n >= 2L) { diff --git a/R/utils.R b/R/utils.R index 6aac1c4..a53da38 100644 --- a/R/utils.R +++ b/R/utils.R @@ -73,36 +73,32 @@ exact_Z <- function(p, feature_names) { #' #' @inheritParams kernelshap #' @param x Row to be explained. -#' @param bg Background data stacked m times. -#' @param Z A logical (m x p) matrix with on-off values. +#' @param bg_rep Background data stacked m times. +#' @param Z_rep A logical ((m * bg_n) x p) matrix with on-off values. #' @param w A vector with case weights (of the same length as the unstacked #' background data). +#' @param bg_n Size of background dataset (unstacked). #' @returns A (m x K) matrix with vz values. -get_vz <- function(x, bg, Z, object, pred_fun, w, ...) { - m <- nrow(Z) - n_bg <- nrow(bg) / m # because bg was replicated m times - - # Replicate Z, so that bg and Z are of dimension (m*n_bg x p) - g <- rep_each(m, each = n_bg) - Z_rep <- Z[g, , drop = FALSE] - - for (v in colnames(Z)) { +get_vz <- function(x, bg_rep, Z_rep, object, pred_fun, w, bg_n, ...) { + for (v in colnames(Z_rep)) { s <- Z_rep[, v] if (is.matrix(x)) { - bg[s, v] <- x[, v] + bg_rep[s, v] <- x[, v] } else { - bg[[v]][s] <- x[[v]] + bg_rep[[v]][s] <- x[[v]] } } - preds <- align_pred(pred_fun(object, bg, ...)) + preds <- align_pred(pred_fun(object, bg_rep, ...)) # Aggregate (distinguishing fast 1-dim case) + m <- nrow(Z_rep) %/% bg_n if (ncol(preds) == 1L) { return(wrowmean_vector(preds, ngroups = m, w = w)) } + g <- rep_each(m, each = bg_n) if (is.null(w)) { - return(rowsum(preds, group = g, reorder = FALSE) / n_bg) + return(rowsum(preds, group = g, reorder = FALSE) / bg_n) } rowsum(preds * w, group = g, reorder = FALSE) / sum(w) } diff --git a/R/utils_kernelshap.R b/R/utils_kernelshap.R index ed33654..5bb916e 100644 --- a/R/utils_kernelshap.R +++ b/R/utils_kernelshap.R @@ -15,6 +15,7 @@ kernelshap_one <- function( max_iter, v0, precalc, + bg_n, ...) { p <- length(feature_names) K <- ncol(v1) @@ -22,18 +23,24 @@ kernelshap_one <- function( # Calculate A_exact and b_exact if (exact || deg >= 1L) { - A_exact <- precalc[["A"]] # (p x p) - bg_X_exact <- precalc[["bg_X_exact"]] # (m_ex*n_bg x p) - Z <- precalc[["Z"]] # (m_ex x p) + A_exact <- precalc$A # (p x p) + Z <- precalc$Z # (m_ex x p) m_exact <- nrow(Z) v0_m_exact <- v0[rep.int(1L, m_exact), , drop = FALSE] # (m_ex x K) # Most expensive part vz <- get_vz( - x = x, bg = bg_X_exact, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ... + x = x, + bg_rep = precalc$bg_exact_rep, # (m_ex*bg_n x p) + Z_rep = precalc$Z_exact_rep, # (m_ex*bg_n x p) + object = object, + pred_fun = pred_fun, + w = bg_w, + bg_n = bg_n, + ... ) # Note: w is correctly replicated along columns of (vz - v0_m_exact) - b_exact <- crossprod(Z, precalc[["w"]] * (vz - v0_m_exact)) # (p x K) + b_exact <- crossprod(Z, precalc$w * (vz - v0_m_exact)) # (p x K) # Some of the hybrid cases are exact as well if (exact || trunc(p / 2) == deg) { @@ -43,7 +50,8 @@ kernelshap_one <- function( } # Iterative sampling part, always using A_exact and b_exact to fill up the weights - bg_X_m <- precalc[["bg_X_m"]] # (m*n_bg x p) + g <- rep_each(m, each = bg_n) + v0_m <- v0[rep.int(1L, m), , drop = FALSE] # (m x K) est_m <- array( data = 0, dim = c(max_iter, p, K), dimnames = list(NULL, feature_names, K_names) @@ -62,16 +70,23 @@ kernelshap_one <- function( while (!converged && n_iter < max_iter) { n_iter <- n_iter + 1L input <- input_sampling(p = p, m = m, deg = deg, feature_names = feature_names) - Z <- input[["Z"]] + Z <- input$Z # Expensive # (m x K) vz <- get_vz( - x = x, bg = bg_X_m, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ... + x = x, + bg_rep = precalc$bg_sampling_rep, # (m*bg_n x p) + Z_rep = Z[g, , drop = FALSE], + object = object, + pred_fun = pred_fun, + w = bg_w, + bg_n = bg_n, + ... ) - # The sum of weights of A_exact and input[["A"]] is 1, same for b - A_temp <- A_exact + input[["A"]] # (p x p) - b_temp <- b_exact + crossprod(Z, input[["w"]] * (vz - v0_m)) # (p x K) + # The sum of weights of A_exact and input$A is 1, same for b + A_temp <- A_exact + input$A # (p x p) + b_temp <- b_exact + crossprod(Z, input$w * (vz - v0_m)) # (p x K) A_sum <- A_sum + A_temp # (p x p) b_sum <- b_sum + b_temp # (p x K) diff --git a/R/utils_permshap.R b/R/utils_permshap.R index 1f7ed88..422c9ed 100644 --- a/R/utils_permshap.R +++ b/R/utils_permshap.R @@ -26,26 +26,25 @@ permshap_one <- function( low_memory, tol, max_iter, + bg_n, ...) { - bg <- precalc$bg_X_rep - p <- length(feature_names) K <- ncol(v1) K_names <- colnames(v1) beta_n <- matrix(0, nrow = p, ncol = K, dimnames = list(feature_names, K_names)) if (exact) { - Z <- precalc$Z # ((m_ex+2) x K) - vz <- get_vz( # (m_ex x K) + vz <- get_vz( x = x, - bg = bg, - Z = Z[2L:(nrow(Z) - 1L), , drop = FALSE], # (m_ex x p) + bg_rep = precalc$bg_exact_rep, + Z_rep = precalc$Z_exact_rep, object = object, pred_fun = pred_fun, w = bg_w, + bg_n = bg_n, ... ) - vz <- rbind(v0, vz, v1) # we add the cheaply calculated v0 and v1 + vz <- rbind(v0, vz, v1) for (j in seq_len(p)) { pos <- precalc$positions[[j]] @@ -68,11 +67,12 @@ permshap_one <- function( # Pre-calculate part of Z with rowsum 1 or p - 1 vz_balanced <- get_vz( # (2p x K) x = x, - bg = precalc$bg_X_balanced, - Z = precalc$Z_balanced, + bg_rep = precalc$bg_balanced_rep, + Z_rep = precalc$Z_balanced_rep, object = object, pred_fun = pred_fun, w = bg_w, + bg_n = bg_n, ... ) @@ -83,24 +83,35 @@ permshap_one <- function( from_balanced <- c(2L, 2L + p, p, 2L * p) from_iter <- c(3L:(p - 1L), (p + 3L):(2L * p - 1L)) + bg_sampling_rep <- precalc$bg_sampling_rep + g <- rep_each(nrow(bg_sampling_rep) %/% bg_n, each = bg_n) + while (!converged && n_iter < max_iter) { chains <- balanced_chains(p) Z <- lapply(chains, sample_Z_from_chain, feature_names = feature_names) if (!low_memory) { # predictions for all chains at once Z <- do.call(rbind, Z) vz <- get_vz( - x = x, bg = bg, Z = Z, object = object, pred_fun = pred_fun, w = bg_w, ... + x = x, + bg_rep = bg_sampling_rep, + Z_rep = Z[g, , drop = FALSE], + object = object, + pred_fun = pred_fun, + w = bg_w, + bg_n = bg_n, + ... ) } else { # predictions for each chain separately vz <- vector("list", length = p) for (j in seq_len(p)) { vz[[j]] <- get_vz( x = x, - bg = bg, - Z = Z[[j]], + bg_rep = bg_sampling_rep, + Z_rep = Z[[j]][g, , drop = FALSE], object = object, pred_fun = pred_fun, w = bg_w, + bg_n = bg_n, ... ) } diff --git a/backlog/compare_with_python.R b/backlog/compare_with_python.R index 936464f..bd8006c 100644 --- a/backlog/compare_with_python.R +++ b/backlog/compare_with_python.R @@ -25,23 +25,23 @@ ks # [1,] -2.050074 -0.28048747 0.1281222 0.01587382 # [2,] -2.085838 0.04050415 0.1283010 0.03731644 -# Pure sampling version takes a bit longer (6.6 seconds) +# Pure sampling version takes a bit longer (5.6 seconds) system.time( ks2 <- kernelshap(fit, X_small, bg_X = bg_X, exact = FALSE, hybrid_degree = 0) ) ks2 bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F)) -# 2.17s 1.64GB -> 1.79s 1.43GB +# 1.66s 1.4GB -> 1.79s 1.43GB bench::mark(kernelshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F, hybrid_degree = 1)) -# 4.88s 2.79GB -> 4.38s 2.48GB +# 4.58s 2.45GB -> 4.38s 2.48GB bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F)) -# 1.97s 1.64GB -> 1.9s 1.43GB +# 1.75s 1.4GB -> 1.9s 1.43GB bench::mark(permshap(fit, X_small, bg_X = bg_X, verbose=F, exact=F)) -# 3.04s 1.88GB -> 2.8s 1.64GB +# 3.97s 1.63GB -> 2.8s 1.64GB # SHAP values of first 2 observations: # carat clarity color cut diff --git a/backlog/compare_with_python2.R b/backlog/compare_with_python2.R index 8e23801..40ab234 100644 --- a/backlog/compare_with_python2.R +++ b/backlog/compare_with_python2.R @@ -18,8 +18,8 @@ pf <- function(model, newdata) { } ks <- kernelshap(pf, head(X), bg_X = X, pred_fun = pf) ks # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252 -es <- permshap(pf, head(X), bg_X = X, pred_fun = pf) -es # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252 +ps <- permshap(pf, head(X), bg_X = X, pred_fun = pf) +ps # -1.196216 -1.241848 -0.9567848 3.879420 -0.33825 0.5456252 set.seed(10) kss <- kernelshap( @@ -61,3 +61,15 @@ ksh2 <- kernelshap( tol = 0.0001 ) ksh2 # 1.195976 -1.241107 -0.9565121 3.878891 -0.3384621 0.5451118 + +set.seed(1) +pss <- permshap( + pf, + head(X, 1), + bg_X = X, + pred_fun = pf, + exact = FALSE, + max_iter = 40000, + tol = 0.0001 +) +pss # -1.222608 -1.252001 -0.9312635 3.890444 -0.33825 0.5456252 non-convergence diff --git a/packaging.R b/packaging.R index cc9ab08..cc4b6d4 100644 --- a/packaging.R +++ b/packaging.R @@ -15,7 +15,7 @@ library(usethis) use_description( fields = list( Title = "Kernel SHAP", - Version = "0.9.0", + Version = "0.9.1", Description = "Efficient implementation of Kernel SHAP (Lundberg and Lee, 2017, ) permutation SHAP, and additive SHAP for model interpretability. diff --git a/tests/testthat/test-basic.R b/tests/testthat/test-basic.R index eea4634..d3a273b 100644 --- a/tests/testthat/test-basic.R +++ b/tests/testthat/test-basic.R @@ -193,18 +193,7 @@ test_that("exact hybrid kernelshap() is similar to exact (non-hybrid)", { expect_equal(s1$S, shap[[1L]]$S) }) -# test_that("kernelshap works for large p (hybrid case)", { -# set.seed(9L) -# X <- data.frame(matrix(rnorm(20000L), ncol = 100L)) -# y <- X[, 1L] * X[, 2L] * X[, 3L] -# fit <- lm(y ~ X1:X2:X3 + ., data = cbind(y = y, X)) -# s <- kernelshap(fit, X[1L, ], bg_X = X, verbose = FALSE) -# -# expect_equal(s$baseline, mean(y)) -# expect_equal(rowSums(s$S) + s$baseline, unname(predict(fit, X[1L, ]))) -# }) - -test_that("kernelshap and permshap work for models with high-order interactions", { +test_that("exact kernelshap and permshap agree with Python on model with high-order interactions", { # Expected: Python output # import numpy as np # import shap # 0.47.2 @@ -263,10 +252,36 @@ test_that("kernelshap and permshap work for models with high-order interactions" ps <- permshap(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE) expect_equal(unname(ps$S), expected) +}) - # Sampling versions of KernelSHAP is quite close +test_that("Test that sampling versions are close to exact with interaction of order 3", { + # We use a different example with less multi-collinearity, but still there is strong + # collinearity between the first two features + n <- 100 set.seed(1) - ksh2 <- kernelshap( + X <- data.frame( + x1 = 1:n, + x2 = sqrt(1:n), + x3 = cos(1:n), + x4 = rnorm(n), + x5 = rexp(n), + x6 = runif(n) + ) + + pf <- function(model, newdata) { + x <- newdata + x[, 1] * x[, 2] * x[, 3] + x[, 4] + } + ks <- kernelshap(pf, head(X, 1), bg_X = X, pred_fun = pf, verbose = FALSE) + ps <- permshap(pf, head(X, 1), bg_X = X, pred_fun = pf, verbose = FALSE) + expect_equal(ks$S, ps$S) + expect_equal( + c(ks$S), c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0), + tolerance = 0.0001 + ) + + # Sampling versions of KernelSHAP are very close + ks2 <- kernelshap( pf, head(X, 1), bg_X = X, @@ -276,16 +291,17 @@ test_that("kernelshap and permshap work for models with high-order interactions" m = 1000, max_iter = 100, tol = 0.001, - verbose = FALSE + verbose = FALSE, + seed = 1 ) + expect_true(ks2$converged) expect_equal( - c(ksh2$S), - c(-1.194878, -1.24747, -0.9596389, 3.883523, -0.3349787, 0.5453894), - tolerance = 1e-4 + c(ks2$S), c(-44.52355, -32.93511, 78.16014, -0.7202121, -0.02362539, 0.03880312), + # Exact c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0) + tolerance = 0.0001 ) - set.seed(1) - ksh1 <- kernelshap( + ks1 <- kernelshap( pf, head(X, 1), bg_X = X, @@ -294,37 +310,56 @@ test_that("kernelshap and permshap work for models with high-order interactions" exact = FALSE, m = 1000, max_iter = 1000, - tol = 0.002, - verbose = FALSE + tol = 0.001, + verbose = FALSE, + seed = 1 ) + expect_true(ks1$converged) expect_equal( - c(ksh1$S), - c(-1.196958, -1.256924, -0.9603291, 3.886163, -0.3277153, 0.5477104), - tolerance = 1e-3 + c(ks1$S), c(-44.8478, -32.81717, 78.46633, -0.8514861, 0.01075054, 0.03582543), + # Exact c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0) + tolerance = 0.001 ) - set.seed(1) - ksh0 <- suppressWarnings( - kernelshap( - pf, - head(X, 1), - bg_X = X, - pred_fun = pf, - hybrid_degree = 0, - exact = FALSE, - m = 10000, - max_iter = 10000, - tol = 0.003, - verbose = FALSE - ) + ks0 <- kernelshap( + pf, + head(X, 1), + bg_X = X, + pred_fun = pf, + hybrid_degree = 0, + exact = FALSE, + m = 1000, + max_iter = 1000, + tol = 0.005, + verbose = FALSE, + seed = 1 ) + expect_true(ks0$converged) expect_equal( - c(ksh0$S), - c(-1.18917, -1.2298, -0.9247995, 3.80673, -0.3144175, 0.5434034), - tolerance = 1e-3 + c(ks0$S), c(-44.29753, -33.39267, 78.67423, -0.290739, -0.3779175, -0.3189199), + # Exact c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0) + tolerance = 0.001 ) -}) + # Too slow for closer results, but we can see additive recovery for x4-x6 + pss <- permshap( + pf, + head(X, 1), + bg_X = X, + pred_fun = pf, + exact = FALSE, + max_iter = 100000, + tol = 0.005, + verbose = FALSE, + seed = 1 + ) + expect_true(pss$converged) + expect_equal( + c(pss$S), c(-44.36299, -32.79343, 77.88822, -0.7353412, 0, 0), + # Exact c(-44.71698, -32.89963, 78.34841, -0.7353412, 0, 0) + tolerance = 0.001 + ) +}) test_that("Random seed works", { n <- 100 @@ -344,9 +379,36 @@ test_that("Random seed works", { } for (algo in c(permshap, kernelshap)) { - s1a <- algo(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE, seed = 1, exact = FALSE, hybrid_degree = 0) - s1b <- algo(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE, seed = 1, exact = FALSE, hybrid_degree = 0) - s2 <- algo(pf, head(X, 2), bg_X = X, pred_fun = pf, verbose = FALSE, seed = 2, exact = FALSE, hybrid_degree = 0) + s1a <- algo( + pf, + head(X, 2), + bg_X = X, + pred_fun = pf, + verbose = FALSE, + seed = 1, + exact = FALSE, + hybrid_degree = 0 + ) + s1b <- algo( + pf, + head(X, 2), + bg_X = X, + pred_fun = pf, + verbose = FALSE, + seed = 1, + exact = FALSE, + hybrid_degree = 0 + ) + s2 <- algo( + pf, + head(X, 2), + bg_X = X, + pred_fun = pf, + verbose = FALSE, + seed = 2, + exact = FALSE, + hybrid_degree = 0 + ) expect_equal(s1a, s1b) expect_false(identical(s1a$S, s2$S)) }