Produce a long-format dataset with 6 columns: an ID for each observation, the variable name, the SHAP value, the variable (feature) value, a standardized feature value for each observation (used to color the points), and the mean |SHAP| value of each variable. An example dataset in this format, shap_long_iris, is included in the package.
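The reshaping described above can be sketched in base R. This is a hypothetical helper (make_shap_long), not the package's implementation. It skips the ranking, top_n and var_cat handling. It assumes that the coloring column is a min-max rescaling of each feature to [0, 1], and that the column names match those of shap_long_iris.

```r
# Hypothetical helper (not the package's code): build the 6-column long format
# ID, variable, value (SHAP), rfvalue (raw feature value),
# stdfvalue (feature value rescaled for coloring), mean_value (mean |SHAP|).
make_shap_long <- function(shap_contrib, X_train) {
  shap <- as.data.frame(shap_contrib)
  vars <- colnames(shap)
  X <- as.data.frame(X_train)[, vars, drop = FALSE]  # align feature columns
  n <- nrow(shap)
  # Assumption: the color value is a min-max rescaling of each feature to [0, 1]
  rescale01 <- function(x) (x - min(x)) / (max(x) - min(x))
  long <- data.frame(
    ID        = rep(seq_len(n), times = length(vars)),  # 1..n for every variable
    variable  = factor(rep(vars, each = n), levels = vars),
    value     = unlist(shap, use.names = FALSE),         # columns stacked in order
    rfvalue   = unlist(X, use.names = FALSE),
    stdfvalue = unlist(lapply(X, rescale01), use.names = FALSE)
  )
  # mean |SHAP| of each variable, repeated on every row of that variable
  long$mean_value <- ave(abs(long$value), long$variable)
  long
}

shap <- matrix(c(0.5, -0.5, 1.0, 0, 0.2, -0.2), ncol = 2,
               dimnames = list(NULL, c("x1", "x2")))
X <- matrix(c(1, 2, 3, 10, 20, 30), ncol = 2,
            dimnames = list(NULL, c("x1", "x2")))
shap_long <- make_shap_long(shap, X)
```

Under this rescaling, a constant feature gives NaN in stdfvalue because max(x) - min(x) is zero. A real implementation would need to handle that case.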
shap.prep(xgb_model = NULL, shap_contrib = NULL, X_train, top_n = NULL, var_cat = NULL)
| Argument | Description |
|---|---|
| xgb_model | an XGBoost (or LightGBM) model object; the SHAP values are derived from it |
| shap_contrib | optional; directly supply a SHAP values dataset. If supplied, it overrides xgb_model, and the SHAP values are not derived from the model |
| X_train | the dataset of predictors used to calculate the SHAP values; it provides the feature values for the plot and must be supplied |
| top_n | optional; keep only the top_n variables ranked by mean \|SHAP\| |
| var_cat | optional; the name of a categorical variable in X_train. If supplied, the long-format data keep it as a grouping column |
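The top_n selection amounts to ranking the columns of the SHAP matrix by their mean absolute value and keeping the first top_n. A minimal base-R sketch, using a hypothetical helper name (rank_by_mean_abs_shap) that is not part of the package:

```r
# Hypothetical helper: names of the top_n features ranked by mean |SHAP|.
# With top_n = NULL every feature is kept, still in ranked order.
rank_by_mean_abs_shap <- function(shap_contrib, top_n = NULL) {
  mean_abs <- colMeans(abs(as.matrix(shap_contrib)))
  ranked <- names(sort(mean_abs, decreasing = TRUE))
  if (!is.null(top_n)) ranked <- head(ranked, top_n)
  ranked
}

shap <- cbind(a = c(0.1, -0.1), b = c(1, -1), c = c(0, 0.5))
rank_by_mean_abs_shap(shap)             # -> "b" "c" "a"
rank_by_mean_abs_shap(shap, top_n = 2)  # -> "b" "c"
```

Ranking by mean |SHAP| rather than by mean SHAP keeps large positive and negative contributions from cancelling out.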
A long-format data.table, named shap_long.
The ID variable is added for each observation in the shap_contrib dataset for better tracking. It is created as 1:nrow(shap_contrib) before shap_contrib is melted into long format.
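Creating the ID before melting is what lets each SHAP value be matched back to its feature value after both tables are in long format. The sketch below illustrates this in base R, using stack() and merge() in place of a data.table melt. The helper to_long is hypothetical, and this is not the package's own code.

```r
# Toy wide-format SHAP values and matching feature values (same rows, same columns)
shap_contrib <- data.frame(x1 = c(0.5, -0.5, 1.0), x2 = c(0, 0.2, -0.2))
X_train      <- data.frame(x1 = c(1, 2, 3), x2 = c(10, 20, 30))

# Hypothetical helper: attach ID = 1:nrow(df), then melt to long format
to_long <- function(df, value_name) {
  id <- seq_len(nrow(df))                    # the ID, created before melting
  long <- stack(df)                          # columns: values, ind (variable)
  out <- data.frame(ID = rep(id, times = ncol(df)),
                    variable = as.character(long$ind),
                    v = long$values)
  names(out)[3] <- value_name
  out
}

# Join SHAP values and feature values on (ID, variable)
shap_long <- merge(to_long(shap_contrib, "value"),
                   to_long(X_train, "rfvalue"),
                   by = c("ID", "variable"))
```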
```r
library("SHAPforxgboost")
data("iris")
X1 = as.matrix(iris[,-5])
mod1 = xgboost::xgboost(
  data = X1, label = iris$Species, gamma = 0, eta = 1,
  lambda = 0, nrounds = 1, verbose = FALSE)

# shap.values(model, X_dataset) returns the SHAP
# data matrix and ranked features by mean|SHAP|
shap_values <- shap.values(xgb_model = mod1, X_train = X1)
shap_values$mean_shap_score
#> Petal.Length  Petal.Width Sepal.Length  Sepal.Width
#>   0.62935975   0.21664035   0.02910357   0.00000000
shap_values_iris <- shap_values$shap_score

# shap.prep() returns the long-format SHAP data, either from the model
# or from a supplied SHAP values dataset
shap_long_iris <- shap.prep(xgb_model = mod1, X_train = X1)
# is the same as using a given shap_contrib:
shap_long_iris <- shap.prep(shap_contrib = shap_values_iris, X_train = X1)

# **SHAP summary plot**
shap.plot.summary(shap_long_iris, scientific = TRUE)

# Alternative options to make the same plot:
# option 1: from the xgboost model
shap.plot.summary.wrap1(mod1, X = as.matrix(iris[,-5]), top_n = 3)

# option 2: supply a self-made SHAP values dataset
# (e.g. sometimes as output from cross-validation)
shap.plot.summary.wrap2(shap_score = shap_values_iris, X = X1, top_n = 3)

####
# use `var_cat` to add a categorical variable; the long-format data
# then keep it as a grouping column:
library("data.table")
data("iris")
set.seed(123)
iris$Group <- 0
iris[sample(1:nrow(iris), nrow(iris)/2), "Group"] <- 1
data.table::setDT(iris)
X_train = as.matrix(iris[,c(colnames(iris)[1:4], "Group"), with = FALSE])
mod1 = xgboost::xgboost(
  data = X_train, label = iris$Species, gamma = 0, eta = 1,
  lambda = 0, nrounds = 1, verbose = FALSE)
shap_long2 <- shap.prep(xgb_model = mod1, X_train = X_train, var_cat = "Group")
# **SHAP summary plot**
shap.plot.summary(shap_long2, scientific = TRUE) +
  ggplot2::facet_wrap(~ Group)
```