What I'd like to do

I am trying to build a model in tidymodels that will predict the efficacy of drugs on cell lines (similar to bacterial strains). The model will rank drugs by efficacy for a given cell line, so I want to use Spearman's correlation (ρ) as the metric. In the example data set below, each cell line (column Sample) is represented by a letter, Q, R, S, ..., Z, and each sample was treated with 50 drugs.

When I split the data for cross-validation, the training/test splits for each fold will have >1 cell line (e.g. Q, R in the test split for fold 1), but in calculating the metric (ρ), I want to calculate it for each cell line individually and then take the average across all the cell lines in the test split, rather than for all the observations in aggregate. For example, if the test split for fold 1 consists of Q, R, then I want to calculate ρ for the 50 drugs tested against Q, then a separate ρ for the 50 drugs tested against R, average these two ρ, and have that average be the metric calculated for fold 1.
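Concretely, the per-fold calculation I have in mind looks like this in dplyr (a sketch; test_predictions and the .pred column are hypothetical stand-ins for a fold's test split with model predictions attached):

library(dplyr)

# One ρ per cell line, then the mean of those ρ as the fold's metric.
test_predictions %>%                   # columns: Sample, TargetVariable, .pred
  group_by(Sample) %>%
  summarise(rho = cor(TargetVariable, .pred, method = "spearman")) %>%
  summarise(mean_rho = mean(rho))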

What I've tried

I was thinking that I'd have to calculate the metric on a tibble/data.frame grouped by the Sample column, but I can't figure out how to pass that variable into tune_grid(). I don't think I can include the variable in add_formula() when creating the workflow object, since I don't want it as a predictor variable. I just discovered tidymodels yesterday, so maybe there's a straightforward solution I'm unaware of, but I haven't been able to find anything on Google so far. The code below is what I've tried, but obviously it doesn't work. Thank you in advance for any advice you can give.

Error

i Resample1: preprocessor 1/1
✓ Resample1: preprocessor 1/1
i Resample1: preprocessor 1/1, model 1/20
✓ Resample1: preprocessor 1/1, model 1/20
i Resample1: preprocessor 1/1, model 1/20 (predictions)
x Resample1: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample2: preprocessor 1/1
✓ Resample2: preprocessor 1/1
i Resample2: preprocessor 1/1, model 1/20
✓ Resample2: preprocessor 1/1, model 1/20
i Resample2: preprocessor 1/1, model 1/20 (predictions)
x Resample2: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample3: preprocessor 1/1
✓ Resample3: preprocessor 1/1
i Resample3: preprocessor 1/1, model 1/20
✓ Resample3: preprocessor 1/1, model 1/20
i Resample3: preprocessor 1/1, model 1/20 (predictions)
x Resample3: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample4: preprocessor 1/1
✓ Resample4: preprocessor 1/1
i Resample4: preprocessor 1/1, model 1/20
✓ Resample4: preprocessor 1/1, model 1/20
i Resample4: preprocessor 1/1, model 1/20 (predictions)
x Resample4: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
i Resample5: preprocessor 1/1
✓ Resample5: preprocessor 1/1
i Resample5: preprocessor 1/1, model 1/20
✓ Resample5: preprocessor 1/1, model 1/20
i Resample5: preprocessor 1/1, model 1/20 (predictions)
x Resample5: internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm ...
Warning message:
All models failed. See the `.notes` column. 

Upon running glmnet_tuning_results:

Warning message:
This tuning result has notes. Example notes on model fitting include:
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)
internal: Error: In metric: `spearman_cor`
unused arguments (truth = ~TargetVariable, estimate = ~.pred, na_rm = ~na_rm)

Code

Example data set

library(tidymodels)

# Ten cell lines (Q-Z) x 50 drugs each = 500 rows.
set.seed(1026)
data = tibble(
  Sample = rep(LETTERS[17:26], each = 50),
  TargetVariable = rnorm(500, mean = 0, sd = 1),
  PredictorVariable1 = rnorm(500, mean = 5, sd = 1),
  PredictorVariable2 = rpois(500, lambda = 5)
)

Model

# Splitting for cross-validation: group_vfold_cv() keeps each cell line's
# 50 rows together, so every assessment set contains whole cell lines.
set.seed(1026)
folds = group_vfold_cv(data, group = Sample, v = 5)

# Model specification.
glmnet_model = linear_reg(
  mode    = "regression", 
  penalty = tune(), 
  mixture = tune()
) %>%
  set_engine("glmnet")

# Workflow.
glmnet_wf = workflow() %>%
  add_model(glmnet_model) %>% 
  add_formula(TargetVariable ~ . - Sample)

# Grid specification.
glmnet_params = parameters(penalty(), mixture())
set.seed(1026)
glmnet_grid = grid_max_entropy(glmnet_params, size = 20)

# Hyperparameter tuning.
glmnet_tuning_results = tune_grid(
  glmnet_wf,
  resamples = folds,
  grid      = glmnet_grid,
  metrics   = metric_set(spearman_cor),
  control   = control_grid(verbose = TRUE)
)

glmnet_tuning_results %>% show_best(n = 10)

Custom metric

# Vector version: computes ρ on bare numeric vectors. metric_vec_template()
# handles the NA removal and input checking.
spearman_cor_vec = function(truth, estimate, na_rm = TRUE) {
  
  spearman_cor_impl = function(truth, estimate) {
    cor(truth, estimate, method = "spearman")
  }
  
  metric_vec_template(
    metric_impl = spearman_cor_impl,
    truth = truth, 
    estimate = estimate,
    na_rm = na_rm,
    cls = "numeric"
  )
}
# Data frame version: group by Sample so metric_summarizer() computes one ρ
# per cell line.
spearman_cor = function(data) {
  UseMethod("spearman_cor")
}

spearman_cor = new_numeric_metric(spearman_cor, direction = "maximize")

spearman_cor.data.frame = function(data, truth, estimate, na_rm = TRUE) {
  
  data_grouped = data %>%
    group_by(Sample)
  
  metric_summarizer(
    metric_nm = "spearman_cor",
    metric_fn = spearman_cor_vec,
    data = data_grouped,
    truth = !! enquo(truth),
    estimate = !! enquo(estimate), 
    na_rm = na_rm
  )
  
}

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 3.6.3 (2020-02-29)
#>  os       macOS Catalina 10.15.7      
#>  system   x86_64, darwin15.6.0        
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       America/Chicago             
#>  date     2021-08-25                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version date       lib source        
#>  backports     1.1.6   2020-04-05 [1] CRAN (R 3.6.2)
#>  cli           3.0.1   2021-07-17 [1] CRAN (R 3.6.2)
#>  crayon        1.3.4   2017-09-16 [1] CRAN (R 3.6.0)
#>  digest        0.6.25  2020-02-23 [1] CRAN (R 3.6.0)
#>  ellipsis      0.3.2   2021-04-29 [1] CRAN (R 3.6.2)
#>  evaluate      0.14    2019-05-28 [1] CRAN (R 3.6.0)
#>  fansi         0.4.1   2020-01-08 [1] CRAN (R 3.6.0)
#>  fs            1.3.1   2019-05-06 [1] CRAN (R 3.6.0)
#>  glue          1.4.0   2020-04-03 [1] CRAN (R 3.6.2)
#>  highr         0.8     2019-03-20 [1] CRAN (R 3.6.0)
#>  htmltools     0.5.1.1 2021-01-22 [1] CRAN (R 3.6.2)
#>  knitr         1.27    2020-01-16 [1] CRAN (R 3.6.0)
#>  lifecycle     1.0.0   2021-02-15 [1] CRAN (R 3.6.2)
#>  magrittr      2.0.1   2020-11-17 [1] CRAN (R 3.6.2)
#>  pillar        1.6.2   2021-07-29 [1] CRAN (R 3.6.2)
#>  pkgconfig     2.0.3   2019-09-22 [1] CRAN (R 3.6.0)
#>  purrr         0.3.4   2020-04-17 [1] CRAN (R 3.6.2)
#>  Rcpp          1.0.4.6 2020-04-09 [1] CRAN (R 3.6.1)
#>  reprex        2.0.1   2021-08-05 [1] CRAN (R 3.6.2)
#>  rlang         0.4.10  2020-12-30 [1] CRAN (R 3.6.2)
#>  rmarkdown     2.1     2020-01-20 [1] CRAN (R 3.6.0)
#>  rstudioapi    0.13    2020-11-12 [1] CRAN (R 3.6.2)
#>  sessioninfo   1.1.1   2018-11-05 [1] CRAN (R 3.6.0)
#>  stringi       1.4.5   2020-01-11 [1] CRAN (R 3.6.0)
#>  stringr       1.4.0   2019-02-10 [1] CRAN (R 3.6.0)
#>  styler        1.5.1   2021-07-13 [1] CRAN (R 3.6.2)
#>  tibble        3.1.3   2021-07-23 [1] CRAN (R 3.6.2)
#>  utf8          1.1.4   2018-05-24 [1] CRAN (R 3.6.0)
#>  vctrs         0.3.8   2021-04-29 [1] CRAN (R 3.6.2)
#>  withr         2.4.2   2021-04-18 [1] CRAN (R 3.6.2)
#>  xfun          0.12    2020-01-13 [1] CRAN (R 3.6.0)
#>  yaml          2.2.0   2018-07-25 [1] CRAN (R 3.6.0)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/3.6/Resources/library

There is 1 answer

Answer from Julia Silge:

To make your custom metric work, you were just missing some `...` (dots) in the generic, so that arguments could be passed through:

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip

spearman_cor_vec <- function(truth, estimate, na_rm = TRUE) {
    
    spearman_cor_impl <- function(truth, estimate) {
        cor(truth, estimate, method = "spearman")
    }
    
    metric_vec_template(
        metric_impl = spearman_cor_impl,
        truth = truth, 
        estimate = estimate,
        na_rm = na_rm,
        cls = "numeric"
    )
}


spearman_cor <- function(data, ...) {    ## these dots were missing
    UseMethod("spearman_cor")
}

spearman_cor <- new_numeric_metric(spearman_cor, direction = "maximize")

spearman_cor.data.frame <- function(data, truth, estimate, na_rm = TRUE) {
    
    data_grouped <- data %>%
        group_by(Sample)
    
    metric_summarizer(
        metric_nm = "spearman_cor",
        metric_fn = spearman_cor_vec,
        data = data_grouped,
        truth = !! enquo(truth),
        estimate = !! enquo(estimate), 
        na_rm = na_rm
    )
    
}
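The "unused arguments" error in the question comes straight from S3 dispatch: every argument passes through the generic first, and a generic without `...` rejects any argument it does not name itself. A minimal illustration (toy functions, nothing to do with yardstick):

f_no_dots <- function(data) UseMethod("f_no_dots")
f_no_dots.data.frame <- function(data, x) x

f_dots <- function(data, ...) UseMethod("f_dots")
f_dots.data.frame <- function(data, x) x

df0 <- data.frame(a = 1)
# f_no_dots(df0, x = 1)  # Error: unused argument (x = 1)
f_dots(df0, x = 1)       # works: dispatch forwards x to the method
#> [1] 1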

With that change, you can use the metric on a data frame like so:


df <- tibble(
    Sample = rep(LETTERS[17:26], each = 50),
    TargetVariable = rnorm(500, mean = 0, sd = 1),
    Pred1 = rnorm(500, mean = 5, sd = 1),
    Pred2 = rpois(500, lambda = 5)
)


df %>% 
    mutate(.pred = TargetVariable + rnorm(500, mean = 0, sd = 0.2)) %>% 
    spearman_cor(TargetVariable, .pred)
#> # A tibble: 10 × 4
#>    Sample .metric      .estimator .estimate
#>    <chr>  <chr>        <chr>          <dbl>
#>  1 Q      spearman_cor standard       0.980
#>  2 R      spearman_cor standard       0.975
#>  3 S      spearman_cor standard       0.983
#>  4 T      spearman_cor standard       0.985
#>  5 U      spearman_cor standard       0.978
#>  6 V      spearman_cor standard       0.963
#>  7 W      spearman_cor standard       0.975
#>  8 X      spearman_cor standard       0.979
#>  9 Y      spearman_cor standard       0.987
#> 10 Z      spearman_cor standard       0.969

Created on 2021-08-31 by the reprex package (v2.0.1)
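If you want the single averaged value described in the question rather than one row per cell line, you can collapse the grouped result yourself (a small extension of the example above, not something the metric does internally):

df %>% 
    mutate(.pred = TargetVariable + rnorm(500, mean = 0, sd = 0.2)) %>% 
    spearman_cor(TargetVariable, .pred) %>% 
    summarise(mean_spearman = mean(.estimate))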

However, this doesn't totally solve your problem because for the tuning functions, we typically only pass the predictors and outcomes, not any extra variables with other roles. I worked with this a little bit and couldn't quite figure out a way to get the tuning function to have a variable only for computing metrics and not for fitting. I don't believe we support this right now; you might want to create a reprex, explain your use case, and post an issue on the tune repo so we can prioritize a new feature like this.
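In the meantime, one manual workaround (a sketch on my part, not a built-in feature): tune with any stock metric but keep the assessment-set predictions, then recover Sample through the .row index that collect_predictions() returns and score per cell line by hand:

glmnet_tuning_results <- tune_grid(
    glmnet_wf,
    resamples = folds,
    grid      = glmnet_grid,
    metrics   = metric_set(rmse),               # placeholder; real scoring below
    control   = control_grid(save_pred = TRUE)  # keep out-of-sample predictions
)

glmnet_tuning_results %>% 
    collect_predictions() %>%                   # .row indexes rows of `data`
    mutate(Sample = data$Sample[.row]) %>% 
    group_by(.config, id, Sample) %>% 
    summarise(rho = cor(TargetVariable, .pred, method = "spearman"),
              .groups = "drop_last") %>% 
    summarise(mean_rho = mean(rho), .groups = "drop")  # one value per candidate per fold

Averaging mean_rho over the folds (id) then gives a per-candidate score you can rank on.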