Improve parallel performance with batching in a static-dynamic branching pipeline

BLUF: I am struggling to understand how to use batching in the R targets package to improve performance in a pipeline that combines static and dynamic branching and is processed in parallel with tar_make_future(). I presume that I need to batch within each dynamic branch, but I am unsure how to go about it.

Here's a reprex that uses dynamic branching nested inside static branching, similar to what my actual pipeline does. It first branches statically over each value in all_types, and then branches dynamically within each type. This code produces 1,000 branches and 1,010 targets in total. In the actual workflow I obviously don't use replicate(), and the number of dynamic branches varies with the type value.

# _targets.R

library(targets)
library(tarchetypes)
library(future)
library(future.callr)

plan(callr)

all_types = data.frame(type = LETTERS[1:10])

tar_map(values = all_types, names = "type",
  tar_target(
    make_data,
    replicate(100,
      data.frame(x = seq(1000) + rnorm(1000, 0, 5),
                 y = seq(1000) + rnorm(1000, 20, 20)),
      simplify = FALSE
    ),
    iteration = "list"
  ),
  tar_target(
    fit_model,
    lm(make_data),
    pattern = map(make_data),
    iteration = "list"
  )
)

And here's a timing comparison of tar_make() vs tar_make_future() with eight workers:

# tar_destroy()
t1 <- system.time(tar_make())
# tar_destroy()
t2 <- system.time(tar_make_future(workers = 8))

rbind(serial = t1, parallel = t2)

##          user.self sys.self elapsed user.child sys.child
## serial        2.12     0.11   25.59         NA        NA
## parallel      2.07     0.24  184.68         NA        NA

I don't think the user or system fields are useful here, since the work is dispatched to separate R processes, but the elapsed time of the parallel run is about 7 times longer than that of the serial run.
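For reference, here is the arithmetic behind that ratio, plus a rough per-target figure. The per-target numbers assume the elapsed time is spread evenly over all 1,010 targets, which is only an approximation.

```r
# Elapsed times in seconds, copied from the timing table above.
elapsed <- c(serial = 25.59, parallel = 184.68)
n_targets <- 1010 # total number of targets in the reprex

# How many times slower the parallel run was than the serial run.
slowdown <- unname(elapsed["parallel"] / elapsed["serial"])
round(slowdown, 1)
#> [1] 7.2

# Rough average wall-clock seconds per target in each mode.
per_target <- elapsed / n_targets
round(per_target, 3)
#>   serial parallel
#>    0.025    0.183
```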

I presume this slowdown is caused by the large number of targets. Will batching improve performance in this case, and if so how can I implement batching within the dynamic branch?

Answer by landau (best answer):

You are on the right track with batching. In your case, that is a matter of breaking up your list of 100 datasets into groups of, say, 10 or so. You could do this with a nested list of datasets, but that's a lot of work. Luckily, there is an easier way.
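For reference, the manual nested-list approach could be sketched roughly as follows, in base R and outside of any pipeline. The names batch_id, batches, and fit_batch are made up for illustration, and the datasets are shrunk to 20 rows to keep it quick.

```r
# Build 100 small datasets with the same columns as in the question.
set.seed(1)
datasets <- replicate(
  100,
  data.frame(
    x = seq(20) + rnorm(20, 0, 5),
    y = seq(20) + rnorm(20, 20, 20)
  ),
  simplify = FALSE
)

# Assign each dataset to one of 10 batches of 10, then split the list.
batch_id <- rep(seq_len(10), each = 10)
batches <- split(datasets, batch_id)

# One unit of work per batch: fit every model in the batch in one call.
fit_batch <- function(batch) {
  lapply(batch, function(data) coef(lm(y ~ x, data = data)))
}
fits <- lapply(batches, fit_batch)

length(batches) # number of batches
lengths(fits)   # number of fitted models in each batch
```

In a pipeline, each element of batches would become one dynamic branch, so each type would get 10 branches of 10 models instead of 100 branches of one model each.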

Your question is actually really well-timed. I just wrote some new target factories in tarchetypes that could help. To access them, you will need the development version of tarchetypes from GitHub:

remotes::install_github("ropensci/tarchetypes")

Then, with tar_map2_count(), it will be much easier to batch your list of 100 datasets for each scenario.

library(targets)
tar_script({
  library(broom)
  library(targets)
  library(tarchetypes)
  library(tibble)

  make_data <- function(n) {
    # All 100 datasets for one scenario; tar_map2_count() splits them into batches.
    datasets <- replicate(
      100,
      tibble(
        x = seq(n) + rnorm(n, 0, 5),
        y = seq(n) + rnorm(n, 20, 20)
      ),
      simplify = FALSE
    )
    tibble(dataset = datasets, rep = seq_along(datasets))
  }

  tar_map2_count(
    name = model,
    command1 = make_data(n = rows),
    command2 = tidy(lm(y ~ x, data = dataset)), # Need dataset[[1]] in tarchetypes 0.4.0
    values = tibble( # tibble() replaces the deprecated data_frame()
      scenario = LETTERS[seq_len(10)],
      rows = seq(10, 100, length.out = 10)
    ),
    columns2 = NULL,
    batches = 10
  )
})
tar_make(reporter = "silent")
tar_read(model)
#> # A tibble: 2,000 × 8
#>    term        estimate std.error statistic   p.value scenario  rows tar_group
#>    <chr>          <dbl>     <dbl>     <dbl>     <dbl> <chr>    <dbl>     <int>
#>  1 (Intercept)   17.1      12.8       1.34  0.218     A           10        10
#>  2 x              1.39      1.35      1.03  0.333     A           10        10
#>  3 (Intercept)    6.42     14.0       0.459 0.658     A           10        10
#>  4 x              1.75      1.28      1.37  0.209     A           10        10
#>  5 (Intercept)   32.8       7.14      4.60  0.00176   A           10        10
#>  6 x             -0.300     1.14     -0.263 0.799     A           10        10
#>  7 (Intercept)   29.7       3.24      9.18  0.0000160 A           10        10
#>  8 x              0.314     0.414     0.758 0.470     A           10        10
#>  9 (Intercept)   20.0      13.6       1.47  0.179     A           10        10
#> 10 x              1.23      1.77      0.698 0.505     A           10        10
#> # … with 1,990 more rows

Created on 2021-12-10 by the reprex package (v2.0.1)

There is also tar_map_rep(), which may be easier if all your datasets are randomly generated, but I am not sure if I am overfitting your use case.

library(targets)
tar_script({
  library(broom)
  library(targets)
  library(tarchetypes)
  library(tibble)

  make_one_dataset <- function(n) {
    tibble(
      x = seq(n) + rnorm(n, 0, 5),
      y = seq(n) + rnorm(n, 20, 20)
    )
  }

  tar_map_rep(
    name = model,
    command = tidy(lm(y ~ x, data = make_one_dataset(n = rows))),
    values = tibble( # tibble() replaces the deprecated data_frame()
      scenario = LETTERS[seq_len(10)],
      rows = seq(10, 100, length.out = 10)
    ),
    batches = 10,
    reps = 10
  )
})
tar_make(reporter = "silent")
tar_read(model)
#> # A tibble: 2,000 × 10
#>    term    estimate std.error statistic p.value scenario  rows tar_batch tar_rep
#>    <chr>      <dbl>     <dbl>     <dbl>   <dbl> <chr>    <dbl>     <int>   <int>
#>  1 (Inter…   37.5        7.50     5.00  0.00105 A           10         1       1
#>  2 x         -0.701      1.17    -0.601 0.564   A           10         1       1
#>  3 (Inter…   21.5        9.64     2.23  0.0567  A           10         1       2
#>  4 x         -0.213      1.55    -0.138 0.894   A           10         1       2
#>  5 (Inter…   20.6        9.51     2.17  0.0620  A           10         1       3
#>  6 x          1.40       1.79     0.783 0.456   A           10         1       3
#>  7 (Inter…   11.6       11.2      1.04  0.329   A           10         1       4
#>  8 x          2.34       1.39     1.68  0.131   A           10         1       4
#>  9 (Inter…   26.8        9.16     2.93  0.0191  A           10         1       5
#> 10 x          0.288      1.10     0.262 0.800   A           10         1       5
#> # … with 1,990 more rows, and 1 more variable: tar_group <int>

Created on 2021-12-10 by the reprex package (v2.0.1)

Unfortunately, futures do come with overhead. It may be faster in your case to try tar_make_clustermq() instead.
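A minimal sketch of that switch, assuming the clustermq package is installed and that its local "multiprocess" scheduler suits your machine. The worker count of 8 simply mirrors the question.

```r
# Put this near the top of _targets.R, before the targets are defined.
# "multiprocess" runs the workers as local R processes; this is an assumption
# about your setup, and cluster schedulers need their own configuration.
options(clustermq.scheduler = "multiprocess")

# Then, from the R console, call this in place of tar_make_future(workers = 8):
# targets::tar_make_clustermq(workers = 8)
```

clustermq workers are persistent, so they are started once and reused across targets, which may cut down on per-target launch cost. Whether it helps here would need to be timed in the same way as in the question.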