Create branches with different subsets of data with mlr3 PipeOps


I want to train models on different subsets of data using mlr3, and I was wondering if there is a way to do this within a pipeline.

What I want to do is similar to the example in R for Data Science - Chapter 25: Many models. Say we use the same data set, gapminder, which contains variables such as GDP per capita and life expectancy for countries around the world. If I wanted to train a life expectancy model for each country, is there an easy way to create such a pipeline using mlr3?

Ideally, I want to use mlr3pipelines to create a branch in the graph for each subset (e.g. a separate branch for each country) with a model at the end. The final graph would then start at a single node and end either in n trained learners, one for each group (i.e. country) in the data set, or in a final node that aggregates the results. It should also work for new data: if we obtain data for 2020 in the future, I would want it to predict for each country using the model trained for that specific country.

All the mlr3 examples I have found seem to deal with models for the entire data set, or have models trained with all the groups in the training set.

Currently, I am just manually creating a separate task for each group of data, but it would be nice to have the data subsetting step incorporated into the modelling pipeline.
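One way that manual approach might look is sketched below. The question does not show the original code, so this is an illustration only: it assumes the gapminder and mlr3 packages are installed, and the names `features`, `groups`, `tasks` and `learners` are made up for the example.

```r
library(mlr3)

df <- gapminder::gapminder

# Keep only the columns the models should see (features plus the target)
features <- df[, c("year", "pop", "gdpPercap", "lifeExp")]

# split() returns one data frame per country; drop = TRUE skips unused factor levels
groups <- split(as.data.frame(features), df$country, drop = TRUE)

# One regression task per country, using the country name as the task id
tasks <- Map(
  function(id, d) TaskRegr$new(id = id, backend = d, target = "lifeExp"),
  names(groups), groups
)

# A fresh rpart learner per task; $train() returns the trained learner
learners <- lapply(tasks, function(task) lrn("regr.rpart")$train(task))
```

Because `names(groups)` is a character vector, `Map()` uses it to name the result, so `learners[["Belgium"]]` retrieves the Belgian model.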

Accepted answer by ekoam

You will need functions from two packages, dplyr and tidyr. The following code trains one model per country:

library(dplyr)
library(tidyr)

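df <- gapminder::gapminder

# One row per country; the remaining columns (year, lifeExp, pop, gdpPercap)
# are nested into the list-column `data`. `learn()` (defined below) must
# already exist when this runs.
by_country <- 
  df %>% 
  nest(data = -c(continent, country)) %>% 
  mutate(model = lapply(data, learn))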

Note that `learn` is a function that takes a single data frame as input and returns a trained learner. Its definition is shown further below, and it must be defined before the pipeline above is run. The pipeline returns the following data frame:

# A tibble: 142 x 4
   country     continent data              model     
   <fct>       <fct>     <list>            <list>    
 1 Afghanistan Asia      <tibble [12 x 4]> <LrnrRgrR>
 2 Albania     Europe    <tibble [12 x 4]> <LrnrRgrR>
 3 Algeria     Africa    <tibble [12 x 4]> <LrnrRgrR>
 4 Angola      Africa    <tibble [12 x 4]> <LrnrRgrR>
 5 Argentina   Americas  <tibble [12 x 4]> <LrnrRgrR>
 6 Australia   Oceania   <tibble [12 x 4]> <LrnrRgrR>
 7 Austria     Europe    <tibble [12 x 4]> <LrnrRgrR>
 8 Bahrain     Asia      <tibble [12 x 4]> <LrnrRgrR>
 9 Bangladesh  Asia      <tibble [12 x 4]> <LrnrRgrR>
10 Belgium     Europe    <tibble [12 x 4]> <LrnrRgrR>
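To see what `nest(data = -c(continent, country))` does, here is a sketch on a tiny made-up data frame (`toy` and its values are invented for illustration): every column except `country` and `continent` is collapsed into a per-country data frame stored in the list-column `data`.

```r
library(tidyr)

# Invented data: two countries, two years each
toy <- data.frame(
  country   = c("A", "A", "B", "B"),
  continent = c("X", "X", "Y", "Y"),
  year      = c(2000, 2005, 2000, 2005),
  lifeExp   = c(60, 62, 70, 71)
)

# Keep country and continent as ordinary columns; nest the rest into `data`
nested <- nest(toy, data = -c(continent, country))
nested
```

Each element of `nested$data` is a small data frame with the columns `year` and `lifeExp`, which is exactly what `learn()` receives for each country.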

The `learn` function follows the basic steps shown on the mlr3 website:

learn <- function(df) {
  # I create a regression task as the target `lifeExp` is a numeric variable.
  task <- mlr3::TaskRegr$new(id = "gapminder", backend = df, target = "lifeExp")
  # define the learner you want to use.
  learner <- mlr3::lrn("regr.rpart")
  # train the learner on the task; `$train()` returns the learner itself
  # (invisibly), so the trained model becomes the function's return value
  learner$train(task)
}
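The question also asks about predicting new data, such as 2020, with each country's own model. A minimal sketch follows. The helper `predict_by_country()` is hypothetical: it splits the new rows by country and sends each group to that country's learner with mlr3's `predict_newdata()`. The 2020 feature values in `new_2020` are made up for illustration, and `learn()` is repeated in condensed form so the block runs on its own.

```r
library(dplyr)
library(tidyr)

# Condensed copy of learn() from above
learn <- function(df) {
  task <- mlr3::TaskRegr$new(id = "gapminder", backend = df, target = "lifeExp")
  mlr3::lrn("regr.rpart")$train(task)
}

by_country <- gapminder::gapminder %>%
  nest(data = -c(continent, country)) %>%
  mutate(model = lapply(data, learn))

# Hypothetical helper: route each new row to the model trained for its country
predict_by_country <- function(models, newdata) {
  rows <- split(newdata, newdata$country, drop = TRUE)
  out <- lapply(names(rows), function(ctry) {
    learner <- models$model[[match(ctry, models$country)]]
    # pass only the feature columns; the missing target is filled in by mlr3
    pred <- learner$predict_newdata(rows[[ctry]][, c("year", "pop", "gdpPercap")])
    data.frame(country = ctry, year = rows[[ctry]]$year, lifeExp_pred = pred$response)
  })
  do.call(rbind, out)
}

# Made-up 2020 values; integer types match gapminder's year and pop columns
new_2020 <- data.frame(
  country   = c("Belgium", "Angola"),
  year      = 2020L,
  pop       = c(11500000L, 32800000L),
  gdpPercap = c(45000, 3000)
)

preds <- predict_by_country(by_country, new_2020)
preds
```

A country without a trained model makes `match()` return `NA`, and the lookup then fails, so real use would need a check for that case.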

I hope this solves your problem.

Update

The following steps split each country's data into a random training set and a test set, train a model on the training rows, and predict the held-out rows for each country:

create_task <- function(id, df, ratio) {
  # random training rows; round() makes the sample size an explicit integer
  train <- sample(nrow(df), round(ratio * nrow(df)))
  task <- mlr3::TaskRegr$new(id = as.character(id), backend = df, target = "lifeExp")
  list(task = task, train = train, test = seq_len(nrow(df))[-train])
}

model_task <- function(learner, task_list) {
  # learners are R6 objects, so this trains `learner` in place
  learner$train(task_list[["task"]], row_ids = task_list[["train"]])
}

predict_result <- function(learner, task_list) {
  learner$predict(task_list[["task"]], row_ids = task_list[["test"]])
}

by_country <- 
  df %>% 
  nest(data = -c(continent, country)) %>% 
  mutate(
    task_list = Map(create_task, as.character(country), data, 0.8), 
    # one fresh learner per country: `list(mlr3::lrn(...))` would be recycled
    # as the SAME R6 object, so every row would share (and overwrite) one model
    learner = lapply(task_list, function(x) mlr3::lrn("regr.rpart"))
  ) %>% 
  within({
    Map(model_task, learner, task_list)
    prediction <- Map(predict_result, learner, task_list)
  })
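Following the R4DS chapter, a natural next step is to compare held-out performance across countries. The sketch below assumes RMSE (mlr3's `msr("regr.rmse")`) is an acceptable metric. `fit_country()` is a hypothetical, condensed version of the helpers above that creates a fresh learner inside the function, so no learner object is shared between countries.

```r
library(dplyr)
library(tidyr)

set.seed(42)  # the train/test split is random

# Hypothetical condensed helper: split, train a fresh learner, predict held-out rows
fit_country <- function(d) {
  task <- mlr3::TaskRegr$new(id = "country", backend = d, target = "lifeExp")
  train <- sample(task$row_ids, round(0.8 * task$nrow))
  test <- setdiff(task$row_ids, train)
  learner <- mlr3::lrn("regr.rpart")$train(task, row_ids = train)
  learner$predict(task, row_ids = test)
}

by_country <- gapminder::gapminder %>%
  nest(data = -c(continent, country)) %>%
  mutate(
    prediction = lapply(data, fit_country),
    # score() returns a named numeric vector; keep the single RMSE value
    rmse = vapply(prediction, function(p) unname(p$score(mlr3::msr("regr.rmse"))), numeric(1))
  )

# Countries whose model does worst on its held-out rows
by_country %>% arrange(desc(rmse)) %>% select(country, continent, rmse) %>% head()
```

With only about 10 training rows per country, `regr.rpart` with rpart's default `minsplit = 20` will typically not split at all and simply predict the training mean, so a linear model learner such as `regr.lm` from the mlr3learners package may be more informative here.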