I want to compute the FIRM importance scores for a model fit with a tidymodels workflow. For a reprex, I will use the iris dataset and try to predict whether an observation is setosa or not.
```r
library(tidymodels)
library(vip)

set.seed(123)  # make the split and the CV folds reproducible

# clean data: binary outcome, setosa vs. everything else
iris <- iris %>%
  mutate(class = case_when(Species == 'setosa' ~ 'setosa',
                           TRUE ~ 'other'))
iris$class = as.factor(iris$class)
iris <- subset(iris, select = -c(Species))

# split data into training and testing
iris_split = initial_split(iris, prop = 0.8)
cv_splits = vfold_cv(training(iris_split), v = 5)

# preprocessing: define the recipe on the training data and leave it
# unprepped; the workflow estimates it during tuning and fitting
iris_recipe = recipe(class ~ ., data = training(iris_split)) %>%
  step_center(Sepal.Length)

# specify random forest model
model = rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 50
) %>%
  set_engine("ranger", importance = "impurity")

# tuning parameters
tuning_grid = grid_regular(mtry(range = c(1, 4)), levels = 4)

iris_wkfl = workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(model)

iris_tune = tune_grid(iris_wkfl,
                      resamples = cv_splits,
                      grid = tuning_grid,
                      metrics = metric_set(accuracy))

best_params = iris_tune %>%
  select_best(metric = "accuracy")

# pull_workflow_fit() is superseded by extract_fit_parsnip() in newer workflows releases
best_model = finalize_workflow(iris_wkfl, best_params) %>%
  parsnip::fit(data = training(iris_split)) %>%
  pull_workflow_fit()

vip(best_model, method = "firm")
```
The last line produces an error from the pdp package.
```
Error in get_training_data.default(object) :
  The training data could not be extracted from object. Please supply the raw training data using the `train` argument in the call to `partial`.
```
Is the following line correct? Or do I need to supply the transformed training data, i.e. apply my recipe to it first? I want to make sure that vip applies my recipe when computing the importance scores. I know the error says "raw training data", but I am unsure whether pdp knows about my workflow.
```r
vip(best_model, method = "firm", train = training(iris_split))
```
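On the raw-versus-transformed question: the model returned by `pull_workflow_fit()` was trained on the recipe's output, so any data handed to it directly is expected to be in that preprocessed form. Below is a minimal sketch of producing that form from the fitted workflow itself. It assumes a workflows release recent enough to have `extract_recipe()`, that the ranger package is installed, and fixes `mtry = 2` only to skip tuning; `iris_df`, `iris_train`, `iris_fit` and `iris_train_baked` are illustrative names, not from the original post.

```r
library(tidymodels)

set.seed(123)

# Same binary outcome as in the question (illustrative object names)
iris_df <- iris %>%
  mutate(class = factor(if_else(Species == "setosa", "setosa", "other"))) %>%
  select(-Species)
iris_split <- initial_split(iris_df, prop = 0.8)
iris_train <- training(iris_split)

# Untrained recipe, defined on the training data only
iris_recipe <- recipe(class ~ ., data = iris_train) %>%
  step_center(Sepal.Length)

# mtry fixed at 2 purely to keep this sketch short (assumption, no tuning)
rf_spec <- rand_forest(mode = "classification", mtry = 2, trees = 50) %>%
  set_engine("ranger", importance = "impurity")

iris_fit <- workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(rf_spec) %>%
  fit(data = iris_train)

# The recipe as estimated inside fit(); bake() applies the same centering
# that the pulled model saw during training
prepped <- extract_recipe(iris_fit)
iris_train_baked <- bake(prepped, new_data = iris_train)
```

Because `step_center()` only shifts `Sepal.Length`, the baked and raw training data differ in that one column here; with heavier preprocessing the two would diverge further.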
You'll want to take the same approach that I outlined in this answer.
Start by tuning and then training your model on the training data:
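A minimal sketch of this tuning-then-training step, assuming the ranger package is installed; the seed and the names `iris_df`, `iris_train`, `rf_spec` and `iris_fit` are illustrative choices rather than the answer's own code. The recipe is defined on the training set and left untrained, and the finalized workflow is kept whole instead of being reduced to its model fit right away.

```r
library(tidymodels)

set.seed(123)

iris_df <- iris %>%
  mutate(class = factor(if_else(Species == "setosa", "setosa", "other"))) %>%
  select(-Species)

iris_split <- initial_split(iris_df, prop = 0.8)
iris_train <- training(iris_split)
cv_splits <- vfold_cv(iris_train, v = 5)

# Leave the recipe untrained: the workflow preps it on each resample
iris_recipe <- recipe(class ~ ., data = iris_train) %>%
  step_center(Sepal.Length)

rf_spec <- rand_forest(mode = "classification", mtry = tune(), trees = 50) %>%
  set_engine("ranger", importance = "impurity")

iris_wkfl <- workflow() %>%
  add_recipe(iris_recipe) %>%
  add_model(rf_spec)

iris_tune <- tune_grid(
  iris_wkfl,
  resamples = cv_splits,
  grid = grid_regular(mtry(range = c(1, 4)), levels = 4),
  metrics = metric_set(accuracy)
)

best_params <- select_best(iris_tune, metric = "accuracy")

# Keep the whole fitted workflow (prepped recipe + model), so predict()
# on it still applies the preprocessing to raw data
iris_fit <- iris_wkfl %>%
  finalize_workflow(best_params) %>%
  fit(data = iris_train)
```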
Your model is now trained and you can compute model-independent variable importance scores like FIRM. There are a couple of steps:
- `pull()` the fitted model out of the workflow.
- … `class`.
- … `predict()` (but unfortunately for ranger it is `predictions()`).

Created on 2020-12-10 by the reprex package (v0.3.0.9001)
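As a cross-check that does not depend on pdp locating the training data, FIRM for a continuous feature can also be computed by hand: it is the standard deviation of that feature's partial dependence curve, which is vip's default for continuous features. The sketch below swaps in this manual computation in place of the `vip()` call. Predicting with the fitted workflow means the recipe is applied to the raw training data automatically. `partial_dependence()`, the 20-point grid, `mtry = 2` and the other object names are hypothetical choices, the predicted probability `.pred_setosa` is used as the response, and ranger is assumed to be installed.

```r
library(tidymodels)

set.seed(123)

iris_df <- iris %>%
  mutate(class = factor(if_else(Species == "setosa", "setosa", "other"))) %>%
  select(-Species)
iris_train <- training(initial_split(iris_df, prop = 0.8))

iris_fit <- workflow() %>%
  add_recipe(recipe(class ~ ., data = iris_train) %>% step_center(Sepal.Length)) %>%
  add_model(rand_forest(mode = "classification", mtry = 2, trees = 50) %>%
              set_engine("ranger")) %>%
  fit(data = iris_train)

# Hypothetical helper: partial dependence of P(class == "setosa") on one
# feature, averaged over the raw training rows. predict() on the workflow
# runs the prepped recipe before the ranger model sees the data.
partial_dependence <- function(fit, data, feature, grid_size = 20) {
  grid <- seq(min(data[[feature]]), max(data[[feature]]), length.out = grid_size)
  vapply(grid, function(value) {
    data[[feature]] <- rep(value, nrow(data))  # set the feature to one grid value
    mean(predict(fit, new_data = data, type = "prob")$.pred_setosa)
  }, numeric(1))
}

# FIRM for continuous features: standard deviation of the PD curve
features <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
firm <- vapply(features,
               function(f) sd(partial_dependence(iris_fit, iris_train, f)),
               numeric(1))
sort(firm, decreasing = TRUE)
```

The numbers need not match `vip()` exactly, since the grid and the prediction scale both affect the standard deviation, but the ranking is a useful sanity check that the preprocessing was applied consistently.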