ROC curves with mlr3::autoplot() for a benchmark with "holdout" resampling


I am using the mlr3 package and I want to plot ROC curves for different models. If I use cross-validation as explained in the documentation, it works perfectly well, but if I use "holdout" for the resampling, I get the error "Error: Invalid show_cb. Inconsistent with calc_avg of evalmod."

Here is the code:

library("mlr3")
library("mlr3learners")
library("mlr3viz")

# one task only
tasks = lapply(c("german_credit"), tsk)

# get some learners and, for all learners,
# predict probabilities (required for ROC curves)
learners = c("classif.featureless", "classif.rpart", "classif.ranger", "classif.kknn")
learners = lapply(learners, lrn,
                  predict_type = "prob")

# compare via an 80/20 holdout split (instead of 3-fold cross-validation)
resamplings = rsmp("holdout", ratio = .8) # holdout instead of cv

# create the benchmark design (all task/learner/resampling combinations)
design = benchmark_grid(tasks, learners, resamplings)
print(design)

bmr = benchmark(design)
autoplot(bmr, type = "roc")

Thanks for your help, Mathieu

Answer by maRmat (accepted)

In case someone else runs into the same problem, here is a solution. The error occurs because the argument calc_avg of precrec::evalmod() is TRUE by default, and mlr3viz::autoplot() calls evalmod() with that default. Averaging needs several dataset IDs (dsids) per model, for example one per fold in cross-validation. With holdout there is only a single train/test split, so as_precrec() returns an object with just one dsid per model. precrec therefore cannot average the curves, hence the error (even though, in principle, a single curve could be "averaged" trivially).
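To make the single-dsid situation concrete, here is a minimal sketch that builds comparable precrec input by hand. It assumes synthetic scores in place of real learner predictions; the model names "informative" and "noise" are hypothetical. It uses only precrec's own mmdata(), join_scores(), join_labels(), evalmod() and auc(), and is not what mlr3viz does internally: it just shows two models, each evaluated on one dataset, with averaging turned off.

```r
library(precrec)

set.seed(42)
n <- 200
# Synthetic stand-ins for holdout predictions (assumption, not real learners):
# 0/1 labels plus one informative and one uninformative score vector
labels <- rbinom(n, size = 1, prob = 0.3)
score_good <- labels + rnorm(n)
score_noise <- rnorm(n)

# Both models share the same dataset ID, as with a single holdout split
mdat <- mmdata(
  scores = join_scores(score_good, score_noise),
  labels = join_labels(labels, labels),
  modnames = c("informative", "noise"),
  dsids = c(1, 1)
)

# Only one dsid per model, so there is nothing to average over
curves <- evalmod(mdat, mode = "rocprc", calc_avg = FALSE)

# One ROC AUC and one PRC AUC per model
print(auc(curves))
```

With cross-validation, each model would instead contribute one dsid per fold, which is the situation in which averaged curves and confidence bands make sense.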

Here is some code that plots ROC curves with holdout (or any other type of resampling). Using the code in the answer, we can do the following:

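library("precrec")  # evalmod() and fortify() methods for precrec curves
library("ggplot2")
library("magrittr") # %>%

# bmr is the BenchmarkResult from the question
roc_data <- evalmod(as_precrec(bmr), mode = "rocprc", calc_avg = FALSE) %>% # setting calc_avg to FALSE is critical
  fortify() %>% # precrec objects have fortify() methods
  .[.$curvetype == "ROC", ] # both ROC and PRC curves are returned; keep only ROC

# Plot the curves
ggplot(
  data = roc_data,
  mapping = aes(x = x, y = y, color = modname)
) +
  geom_line()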

This approach has the added advantage of producing a ggplot object, so the plot can easily be customized with ggplot2, which is not the case with precrec::autoplot().
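As an illustration of that flexibility, here is a sketch of a few common ROC-plot customizations. To keep it self-contained it uses a small hypothetical data frame with the same columns the code above relies on (x, y, modname) instead of the real roc_data; the values are made up and are not benchmark results.

```r
library(ggplot2)

# Hypothetical stand-in for roc_data (same column names, made-up points)
roc_data <- data.frame(
  x = c(0, 0.2, 1, 0, 0.5, 1),
  y = c(0, 0.7, 1, 0, 0.5, 1),
  modname = rep(c("classif.ranger", "classif.featureless"), each = 3)
)

p <- ggplot(roc_data, aes(x = x, y = y, color = modname)) +
  geom_line() +
  # Diagonal reference line for a random classifier
  geom_abline(slope = 1, intercept = 0, linetype = "dashed", colour = "grey50") +
  # Same scale on both axes, since both are rates in [0, 1]
  coord_equal() +
  labs(
    x = "False positive rate (1 - specificity)",
    y = "True positive rate (sensitivity)",
    color = "Learner",
    title = "ROC curves, holdout (80/20)"
  ) +
  theme_minimal()

print(p)
```

Any other ggplot2 layer, scale or theme can be added in the same way.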