How do we calculate multiclass probabilities in yardstick?


I have a multiclass classification problem and want to build a precision-recall curve using pr_curve() from the yardstick package in R. This function requires a tibble with a predicted probability column for each class, like the one in data(hpc_cv). How do I get there from my classification results, which are stored as columns in a tibble?

library(yardstick)
library(dplyr)  # provides tibble(), %>%, filter() and summarise()
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")), 
               expected  = as.factor(c("A", "B", "B", "C", "A", "C")))
data %>% conf_mat(truth = expected, estimate = predicted)

I have not found a function in yardstick (or elsewhere) to calculate those.

I am not sure how class probabilities are calculated; I am thinking along these lines:

data %>% filter(predicted == "A") %>% summarise(n = n() / 6)

Is this correct? If so, is there a clean way to do it without looping over each class in each fold, and to end up with a tibble like hpc_cv?
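Computed for all classes at once (no loop needed), the idea above gives the share of rows predicted as each class. This small sketch reuses the toy data from the question; note that the result is one number per class, not one probability per row, so it cannot fill the per-row probability columns that pr_curve() expects.

```r
library(dplyr)

# Toy data from the question
data <- tibble(predicted = as.factor(c("A", "A", "B", "B", "C", "C")),
               expected  = as.factor(c("A", "B", "B", "C", "A", "C")))

# Share of rows assigned to each predicted class (one row per class)
shares <- data %>%
  count(predicted) %>%
  mutate(share = n / sum(n))

shares
```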

Answer by topepo (best answer)

"I am not sure how class probs are calculated"

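Class probabilities are not derived from the hard predictions after the fact; they are generated by a specific model, which gives each individual data point a probability for every class. Counting how often each class was predicted only gives overall class frequencies, not per-row probabilities.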
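To make that concrete, here is a minimal sketch that is not part of the original answer. It uses the built-in iris data as a stand-in (an assumption, since the question's toy data has no predictors) and nnet::multinom(), which returns one probability per class for every row. Each row sums to 1, and the hard class is simply the most probable column.

```r
library(nnet)

# Multinomial logistic regression on iris (stand-in data, for illustration only)
fit <- multinom(Species ~ Sepal.Length + Sepal.Width, data = iris, trace = FALSE)

# One row per observation, one probability column per class
probs <- predict(fit, newdata = iris, type = "probs")
head(round(probs, 3))

# Each row of probabilities sums to 1
summary(rowSums(probs))

# The hard class prediction is the column with the largest probability
hard_class <- factor(colnames(probs)[max.col(probs)],
                     levels = levels(iris$Species))
table(predicted = hard_class, truth = iris$Species)
```

These per-row probability columns are what pr_curve() needs, in the same shape as the VF, F, M and L columns of hpc_cv.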

PR curves (and precision and recall) are defined for outcomes with two classes. For a multiclass outcome you can compute them for each class versus all the others and then average the results (multiclass averaging) to get an overall PR AUC.
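A minimal base-R sketch of the one-vs-all idea with macro averaging, applied to the hard predictions from the question. It covers precision and recall at a single cut-off only; a PR curve or PR AUC additionally needs per-class probabilities. The object name one_vs_all is illustrative.

```r
# Toy data from the question (hard predictions only)
predicted <- factor(c("A", "A", "B", "B", "C", "C"))
expected  <- factor(c("A", "B", "B", "C", "A", "C"))

# Treat each class in turn as "positive" and all other classes as "negative"
one_vs_all <- sapply(levels(expected), function(lvl) {
  tp <- sum(predicted == lvl & expected == lvl)  # true positives
  fp <- sum(predicted == lvl & expected != lvl)  # false positives
  fn <- sum(predicted != lvl & expected == lvl)  # false negatives
  c(precision = tp / (tp + fp), recall = tp / (tp + fn))
})
one_vs_all

# Macro averaging: unweighted mean of the per-class results
rowMeans(one_vs_all)
```

For this toy data every class has precision and recall of 0.5, so both macro averages are 0.5.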

There is an example below, but I would advise reading the tidymodels book before going further.

library(nnet) # <- engine used by multinom_reg()
library(tidymodels)

tidymodels_prefer()

data(hpc_data, package = "modeldata")

set.seed(1)
hpc_split <- initial_split(hpc_data)
hpc_train <- training(hpc_split)
hpc_test  <- testing(hpc_split)

set.seed(2)
multinom_fit <- 
  multinom_reg() %>% 
  fit(class ~ iterations + compounds, data = hpc_train)

test_predictions <- augment(multinom_fit, new_data = hpc_test)

# examples of the hard class predictions and the 
# predicted probabilities: 
test_predictions %>% select(starts_with(".pred")) %>% head()
#> # A tibble: 6 × 5
#>   .pred_class .pred_VF .pred_F .pred_M .pred_L
#>   <fct>          <dbl>   <dbl>   <dbl>   <dbl>
#> 1 VF             0.641   0.279  0.0670  0.0128
#> 2 VF             0.640   0.280  0.0671  0.0128
#> 3 VF             0.628   0.287  0.0711  0.0138
#> 4 VF             0.628   0.287  0.0711  0.0138
#> 5 VF             0.626   0.288  0.0716  0.0139
#> 6 VF             0.626   0.288  0.0719  0.0140

# a confusion matrix
test_predictions %>% conf_mat(class, .pred_class)
#>           Truth
#> Prediction  VF   F   M   L
#>         VF 516 278  74  16
#>         F   18  46  36   4
#>         M    2   7  19  21
#>         L    0  11   7  28

# create some metrics:
cls_metrics <- metric_set(accuracy, precision, recall, pr_auc)
# precision, recall, and the PR AUC are calculated using macro averaging of 4
# different one-vs-all results.
# See https://yardstick.tidymodels.org/articles/multiclass.html

# evaluate them
test_predictions %>% 
  # See ?metric_set for more information. We pass the truth (class), all of the
  # predicted probability columns (.pred_VF:.pred_L), and the named hard class
  # predictions. 
  cls_metrics(class, .pred_VF:.pred_L, estimate = .pred_class)
#> # A tibble: 4 × 3
#>   .metric   .estimator .estimate
#>   <chr>     <chr>          <dbl>
#> 1 accuracy  multiclass     0.562
#> 2 precision macro          0.506
#> 3 recall    macro          0.411
#> 4 pr_auc    macro          0.481

Created on 2022-12-09 with reprex v2.0.2
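The last step in the question, going from a tibble of per-class probabilities to a PR curve, might look like the following sketch. It uses the hpc_cv data bundled with yardstick rather than the model above, so it runs on its own; with the model above you would pass class and .pred_VF:.pred_L from test_predictions instead of obs and VF:L.

```r
library(dplyr)
library(yardstick)

# hpc_cv ships with yardstick: obs is the true class, VF:L are the
# predicted class probabilities, Resample identifies the CV fold
data(hpc_cv)

# With a multiclass truth column, pr_curve() computes one-vs-all curves
# and adds a .level column naming the "positive" class of each curve
curves <- hpc_cv %>%
  group_by(Resample) %>%
  pr_curve(obs, VF:L)
curves

# Macro-averaged PR AUC, one value per fold
fold_auc <- hpc_cv %>%
  group_by(Resample) %>%
  pr_auc(obs, VF:L, estimator = "macro")
fold_auc
```

Grouping by Resample handles every fold and every class without an explicit loop.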