I am trying to build a catboost model within the tidymodels framework. A minimal reproducible example is given below. I am able to use the DALEX and modelStudio packages to get model explanations, but I also want to create variable importance (VIP) plots and SHAP summary plots for this catboost model. I have tried packages like fastshap and SHAPforxgboost without any luck. I realise that I have to extract the variable importance and SHAP values from the model object and use them to produce these plots, but I don't know how to do that. Is there a way to get this done in R?
library(tidymodels)
library(treesnip)
library(catboost)
library(modelStudio)
library(DALEXtra)
library(DALEX)
data <- structure(list(Age = c(74, 60, 57, 53, 72, 72, 71, 77, 50, 66), StatusofNation0developed = structure(c(2L, 2L, 2L, 2L, 2L,
1L, 2L, 1L, 1L, 2L), .Label = c("0", "1"), class = "factor"),
treatment = structure(c(2L, 1L, 2L, 2L, 2L, 1L, 1L, 3L, 1L,
2L), .Label = c("0", "1", "2"), class = "factor"), InHospitalMortalityMortality = c(0,
0, 1, 1, 1, 0, 0, 1, 1, 0)), row.names = c(NA, 10L), class = "data.frame")
split <- initial_split(data, strata = InHospitalMortalityMortality)
train <- training(split)
test <- testing(split)
train$InHospitalMortalityMortality <- as.factor(train$InHospitalMortalityMortality)
rec <- recipe(InHospitalMortalityMortality ~ ., data = train)
clf <- boost_tree() %>%
  set_engine("catboost") %>%
  set_mode("classification")
wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(clf)
model <- wflow %>% fit(data = train)
explainer <- explain_tidymodels(model,
                                data = test,
                                y = test$InHospitalMortalityMortality,
                                label = "catboost")
new_observation <- test[1:2,]
modelStudio(explainer, new_observation)
The link above provides an answer, but it is incomplete. Here it is completed, following an identical workflow.
As indicated: first, install the R packages {fastshap} and {reticulate}. Next, set up a virtual environment for Python use with {reticulate}. Setting up a virtual environment is relatively straightforward when using RStudio. Please check their reference material for step-by-step instructions.
Then, pip install {shap} and {matplotlib} in the virtual environment. Note that matplotlib 3.2.2 seems to be necessary for summary plots (see GitHub issues for greater detail).
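If you prefer to do the setup from the R console rather than through the RStudio dialogs, something like the following should work. This is only a sketch: the environment name "r-shap" is my own placeholder, and pinning matplotlib to 3.2.2 may fail on recent Python versions, in which case you would have to pick an older Python for the environment.

```r
library(reticulate)

# "r-shap" is a hypothetical environment name; use whatever you like
virtualenv_create("r-shap")

# Install the Python packages into that environment
# (the matplotlib pin follows the note above and may need adjusting)
virtualenv_install("r-shap", packages = c("shap", "matplotlib==3.2.2"))

# Tell reticulate to use this environment; call this before any Python code runs
use_virtualenv("r-shap", required = TRUE)
```

Restart the R session before calling use_virtualenv() if reticulate has already initialised a different Python.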
The workflow (from treesnip docs):
Fit the workflow:
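A minimal sketch of these two steps, assuming the diamonds data from {ggplot2} (as used in the treesnip docs) and the catboost engine from the question. The object names (diamonds_small, rec, spec, wf, fit_wf), the sample size and the number of trees are my own choices for illustration, not something the docs prescribe.

```r
library(tidymodels)
library(treesnip)   # registers the "catboost" engine for boost_tree()
library(catboost)

data("diamonds", package = "ggplot2")

# Work on a small random sample so everything runs quickly
set.seed(123)
diamonds_small <- dplyr::slice_sample(diamonds, n = 1000)

# Recipe: predict price from all other columns
rec <- recipe(price ~ ., data = diamonds_small)

# Model specification (regression, because price is numeric)
spec <- boost_tree(trees = 200) %>%
  set_engine("catboost") %>%
  set_mode("regression")

# Bundle recipe and model into a workflow
wf <- workflow() %>%
  add_recipe(rec) %>%
  add_model(spec)

# Fit the workflow
fit_wf <- fit(wf, data = diamonds_small)
```

For the classification model in the question, the same pattern applies with set_mode("classification") and a factor outcome.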
With a fitted workflow, we can now compute SHAP values with {fastshap} and plot them with {fastshap} and {reticulate}.
First, the force plots: to do this, we need to create a prediction function for the pred_wrapper argument.
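A prediction wrapper for {fastshap} takes two arguments, object and newdata, and has to return a plain numeric vector with one prediction per row. For a fitted regression workflow, a sketch could look like this (the name pred_wrapper is just my choice):

```r
# Prediction wrapper for fastshap::explain():
# must accept `object` and `newdata` and return a numeric vector
pred_wrapper <- function(object, newdata) {
  # predict() on a fitted workflow returns a tibble; for regression
  # the predictions are in the `.pred` column
  predict(object, new_data = newdata)$.pred
}
```

For a classification workflow like the one in the question, I would expect the body to become predict(object, new_data = newdata, type = "prob")$.pred_1, where .pred_1 assumes the positive class level is "1".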
Now we want the mean prediction values for the baseline argument.
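Roughly like this, assuming the fitted workflow is stored in fit_wf, the training sample in diamonds_small, and the prediction wrapper is called pred_wrapper (all of these are placeholder names of mine):

```r
# Feature data only: drop the outcome column and use a plain data frame
X <- as.data.frame(dplyr::select(diamonds_small, -price))

# Baseline for the force plot: the average prediction over the data
baseline <- mean(pred_wrapper(fit_wf, X))
```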
Here, create the shap values:
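A sketch of this step with fastshap::explain(), again assuming the placeholder objects fit_wf (fitted workflow), X (feature data frame without the outcome) and pred_wrapper (prediction function). The value nsim = 10 is an arbitrary choice; more simulations give more stable estimates but take longer.

```r
library(fastshap)

# Monte Carlo approximation of SHAP values, so fix the seed for reproducibility
set.seed(2021)
shap_values <- fastshap::explain(
  fit_wf,
  X = X,                       # features only, no outcome column
  pred_wrapper = pred_wrapper, # returns a numeric vector of predictions
  nsim = 10                    # number of Monte Carlo repetitions (arbitrary here)
)
```

The result has one row per observation and one column per feature.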
Now, for the force plot:
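Something along these lines, for a single observation. Note the assumption: as far as I know, force_plot() is only exported by {fastshap} releases before 0.1.0, so this sketch will not run on newer versions. The objects shap_values, X and baseline are the placeholder names for the SHAP values, the feature data and the mean prediction.

```r
# Assumes fastshap < 0.1.0, where force_plot() is available,
# and a Python environment with {shap} visible to {reticulate}
fastshap::force_plot(
  object = shap_values[1, ],   # SHAP values for the first observation
  feature_values = X[1, ],     # matching feature values
  baseline = baseline,         # mean prediction
  display = "viewer"           # show in the RStudio viewer; "html" also exists
)
```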
Now for the summary plot: use {reticulate} to access the Python function directly:
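A sketch of how that can look, assuming shap_values holds the {fastshap} output and X the feature data frame (placeholder names), and that {shap} is installed in the Python environment {reticulate} is using:

```r
library(reticulate)

# Import the Python shap module
shap <- import("shap")

# Convert to plain numeric matrices so reticulate passes them as numpy arrays
shap_mat <- as.matrix(as.data.frame(shap_values))
feat_mat <- data.matrix(X)   # factors become their integer codes

# Call the Python plotting function directly
shap$summary_plot(shap_mat, feat_mat, feature_names = colnames(X))
```

Because data.matrix() replaces factor levels with integer codes, the colour scale for categorical features reflects the level order, not any real magnitude.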
The same would work for dependence plots, for example.
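For instance, a dependence plot for the most important feature could be sketched as follows. I use the "rank(0)" selector here instead of a feature name; shap_values and X are again placeholder names for the SHAP values and the feature data frame.

```r
library(reticulate)

shap <- import("shap")

# Plain numeric matrices, as for the summary plot
shap_mat <- as.matrix(as.data.frame(shap_values))
feat_mat <- data.matrix(X)

# "rank(0)" selects the feature with the largest mean absolute SHAP value
shap$dependence_plot(
  "rank(0)",
  shap_mat,
  feat_mat,
  feature_names = colnames(X)
)
```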
Final note: rendering the plots repeatedly results in buggy visualizations. Naming a feature directly (e.g., "cut") in dependence_plot threw an error for me.