extract_parameter_set_dials() fails for self-written gbm engine for boost_tree()

116 views Asked by At

I followed this vignette from the tidymodels documentation to register gbm as an engine for the boost_tree() function.

I came up with the following for regression tasks:

library(tidyverse)
library(tidymodels)
library(workflows) # the package name is plural; `workflow` does not exist
library(tune)
# gbm only has to be installed, not attached: parsnip calls it as gbm::gbm().
# The package name must be a quoted string, and the install is skipped when
# the package is already available.
if (!requireNamespace("gbm", quietly = TRUE)) {
  install.packages("gbm")
}
{
    # Register "gbm" as an engine of boost_tree() for the regression mode.
    set_model_engine(
      model = "boost_tree", 
      mode  = "regression", 
      eng  = "gbm"
    )
    # Declare that the engine needs the gbm package.
    set_dependency("boost_tree", eng = "gbm", pkg = "gbm")
    # Map the parsnip argument `trees` to gbm's `n.trees`.
    # NOTE(review): `func` names gbm::gbm here; the accepted answer further
    # down this page uses the matching dials parameter function instead and
    # sets has_submodel = FALSE.
    set_model_arg(
      model        = "boost_tree", 
      eng          = "gbm", 
      parsnip      = "trees", 
      original     = "n.trees",
      func         = list(pkg = "gbm", fun = "gbm"),
      has_submodel = TRUE
    )
    # Map the parsnip argument `tree_depth` to gbm's `interaction.depth`.
    set_model_arg(
      model        = "boost_tree", 
      eng          = "gbm", 
      parsnip      = "tree_depth", 
      original     = "interaction.depth", 
      func         = list(pkg = "gbm", fun = "gbm"), 
      has_submodel = TRUE
    )
    # Map the parsnip argument `learn_rate` to gbm's `shrinkage`.
    set_model_arg(
      model        = "boost_tree", 
      eng          = "gbm", 
      parsnip      = "learn_rate", 
      original     = "shrinkage", 
      func         = list(pkg = "gbm", fun = "gbm"), 
      has_submodel = TRUE
    )
    
    # Encoding options registered for this engine/mode combination; see
    # parsnip::set_encoding() for the meaning of each option.
    set_encoding(
      model = "boost_tree",
      eng = "gbm",
      mode = "regression",
      options = list(
        predictor_indicators = "none",
        compute_intercept = FALSE,
        remove_intercept = FALSE,
        allow_sparse_x = FALSE
      )
    )
    
    # Build a bare model specification object classed via make_classes("gbm").
    #
    # mode:       must be "regression"; any other value raises an error.
    # trees, tree_depth, learn_rate: main arguments, stored unevaluated as
    #             quosures in the `args` slot.
    # cv.folds:   accepted for the signature but not used in this body.
    #
    # Returns a list with slots `args`, `eng_args`, `mode`, `method` and
    # `engine`; the last three besides `mode` are left NULL for later steps.
    gbm <- function(mode = "regression", trees = NULL, 
                    tree_depth = NULL, learn_rate = NULL, cv.folds = 1) {
      # guard clause: only the regression mode is supported
      if (mode != "regression") {
        stop("`mode` should be 'regression'", call. = FALSE)
      }

      structure(
        list(
          # defer evaluation of the main arguments
          args = list(
            trees      = rlang::enquo(trees),
            tree_depth = rlang::enquo(tree_depth),
            learn_rate = rlang::enquo(learn_rate)
          ),
          # placeholders filled in by later parts of the specification
          eng_args = NULL,
          mode     = mode,
          method   = NULL,
          engine   = NULL
        ),
        class = make_classes("gbm")
      )
    }
    
    # Describe how a boost_tree() regression model is fitted with the gbm
    # engine: through the formula interface, by calling gbm::gbm().
    set_fit(
      model = "boost_tree", 
      eng   = "gbm", 
      mode  = "regression", 
      value = list(
        interface = "formula", # other possible values are "data.frame", "matrix"
        protect = c("formula", "data"), # arguments the user may not override
        func = c(pkg = "gbm", fun = "gbm"), # call gbm::gbm()
        defaults = list(
          distribution = "gaussian", 
          n.cores = NULL, 
          verbose = FALSE
        ) # default engine arguments; the user may change these
      )
    )
    
    # Describe how numeric predictions are produced for the gbm engine:
    # predict() is called on the fitted gbm object with the new data, using
    # all trees that were fitted.
    set_pred(
      model = "boost_tree", 
      eng   = "gbm", 
      mode  = "regression", 
      type  = "numeric", 
      value = list(
        pre = NULL, 
        post = NULL, 
        func = c(fun = "predict"), 
        args = list(
          object = expr(object$fit), 
          newdata = expr(new_data), 
          n.trees = expr(object$fit$n.trees),
          type = "response", 
          # FALSE (the predict.gbm default) returns the ensemble prediction;
          # TRUE would return the output of the n.trees-th tree only.
          single.tree = FALSE
        )
      )
    )
    

}

But if I try to use this engine to tune the hyperparameters with tune_bayes() from the tune package, my code fails to extract the parameter set from the workflow:

rec <- recipe(mpg ~ ., data = mtcars)

# The custom engine is only used when it is selected explicitly with
# set_engine(); the boost_tree() call also needs its closing parenthesis.
model_tune <- parsnip::boost_tree(
  mode = "regression",
  trees = 1000,
  tree_depth = tune(),
  learn_rate = tune()
) %>%
  set_engine("gbm")

model_wflow <- workflow() %>%
  add_model(model_tune) %>%
  add_recipe(rec)


HP_set <- extract_parameter_set_dials(model_wflow, tree_depth(range = c(1,100)))
HP_set

The function extract_parameter_set_dials() always throws the following error:

Error in `mutate()`:
! Problem while computing `object = purrr::map(call_info, eval_call_info)`.
Caused by error in `.f()`:
! Error when calling gbm(): Error in terms.formula(formula, data = data) : 
  argument is not a valid model

Maybe this has something to do with the set_fit() options in the engine settings but that is just a wild guess.

How can I use the gbm engine for boost_tree() and tune the hyperparameter with tune_bayes()?

1

There is 1 answer

0
Julia Silge On BEST ANSWER

You were really close but there were a couple of issues:

  • the submodel trick was not working for this model, so I just turned it off (this would take some deeper exploration to see if it would work at all)
  • set_model_arg() calls should reference the functions in the dials package
library(tidymodels)

# Register gbm as a regression engine for boost_tree() and declare the
# package it depends on.
set_model_engine(eng = "gbm", model = "boost_tree", mode = "regression")
set_dependency("boost_tree", eng = "gbm", pkg = "gbm")

# Map each parsnip argument to its gbm counterpart. `func` must name the
# dials parameter function (not gbm::gbm); for these three arguments it has
# the same name as the parsnip argument. Submodels are turned off.
invisible(Map(
  function(parsnip_name, engine_name) {
    set_model_arg(
      model        = "boost_tree",
      eng          = "gbm",
      parsnip      = parsnip_name,
      original     = engine_name,
      func         = list(pkg = "dials", fun = parsnip_name),
      has_submodel = FALSE
    )
  },
  c("trees", "tree_depth", "learn_rate"),
  c("n.trees", "interaction.depth", "shrinkage")
))

# Encoding options registered for this engine/mode combination.
set_encoding(
  eng = "gbm",
  model = "boost_tree",
  mode = "regression",
  options = list(
    predictor_indicators = "none",
    compute_intercept = FALSE,
    remove_intercept = FALSE,
    allow_sparse_x = FALSE
  )
)

# Describe how a boost_tree() regression model is fitted with the gbm engine:
# through the formula interface, by calling gbm::gbm().
set_fit(
    model = "boost_tree",
    eng   = "gbm",
    mode  = "regression",
    value = list(
        # fitting interface; other possible values are "data.frame", "matrix"
        interface = "formula",
        # arguments the user may not override
        protect = c("formula", "data"),
        # function to call: gbm::gbm()
        func = c(pkg = "gbm", fun = "gbm"),
        # default engine arguments; the user may change these
        defaults = list(
            distribution = "gaussian",
            n.cores = NULL,
            verbose = FALSE
        )
    )
)

# Describe how numeric predictions are produced for the gbm engine:
# predict() is called on the fitted gbm object with the new data, using all
# trees that were fitted.
set_pred(
    model = "boost_tree",
    eng   = "gbm",
    mode  = "regression",
    type  = "numeric",
    value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args = list(
            object = expr(object$fit),
            newdata = expr(new_data),
            n.trees = expr(object$fit$n.trees),
            type = "response",
            # FALSE (the predict.gbm default) returns the ensemble prediction;
            # TRUE would return the output of the n.trees-th tree only.
            single.tree = FALSE
        )
    )
)


# Model specification: 1000 trees, with depth and learning rate marked for
# tuning, fitted through the custom gbm engine.
model_spec <- parsnip::boost_tree(
    mode = "regression",
    trees = 1000,
    tree_depth = tune(),
    learn_rate = tune()
) %>%
    set_engine("gbm")

data(Sacramento)

model_wflow <- workflow(price ~ beds + baths + sqft, model_spec) 
# extract_parameter_set_dials() ignores extra arguments, so a custom range
# has to be applied to the extracted parameter set with update().
extract_parameter_set_dials(model_wflow) %>%
    update(tree_depth = tree_depth(range = c(1, 100)))
#> Collection of 2 parameters for tuning
#> 
#>  identifier       type    object
#>  tree_depth tree_depth nparam[+]
#>  learn_rate learn_rate nparam[+]

# Bayesian optimisation of the tunable parameters: 3 search iterations,
# evaluated on 5 bootstrap resamples of the Sacramento data.
model_wflow %>%
    tune_bayes(
        iter = 3,
        resamples = bootstraps(Sacramento, times = 5)
    )
#> # Tuning results
#> # Bootstrap sampling 
#> # A tibble: 20 × 5
#>    splits            id         .metrics          .notes           .iter
#>    <list>            <chr>      <list>            <list>           <int>
#>  1 <split [932/341]> Bootstrap1 <tibble [10 × 6]> <tibble [0 × 3]>     0
#>  2 <split [932/348]> Bootstrap2 <tibble [10 × 6]> <tibble [0 × 3]>     0
#>  3 <split [932/336]> Bootstrap3 <tibble [10 × 6]> <tibble [0 × 3]>     0
#>  4 <split [932/348]> Bootstrap4 <tibble [10 × 6]> <tibble [0 × 3]>     0
#>  5 <split [932/359]> Bootstrap5 <tibble [10 × 6]> <tibble [0 × 3]>     0
#>  6 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     1
#>  7 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     1
#>  8 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     1
#>  9 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     1
#> 10 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     1
#> 11 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     2
#> 12 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     2
#> 13 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     2
#> 14 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     2
#> 15 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     2
#> 16 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     3
#> 17 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     3
#> 18 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     3
#> 19 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     3
#> 20 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     3

Created on 2022-05-03 by the reprex package (v2.0.1)