Lasso regression with tidymodels does not reproduce Julia Silge's results


I was trying to replicate Julia Silge's great lesson on lasso regression using tidymodels. Even though the data cleaning gives exactly the same result, the same regression code gives a completely different result, which then breaks the whole prediction phase.

A reproducible example is the following:

# Load libraries
library(tidymodels)
library(tidyverse)
library(schrute)
library(janitor)
tidymodels_prefer()

# Load data
ratings_raw <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-03-17/office_ratings.csv")
#> Rows: 188 Columns: 6
#> ── Column specification ────────────────────────────────────────────────────────
#> Delimiter: ","
#> chr  (1): title
#> dbl  (4): season, episode, imdb_rating, total_votes
#> date (1): air_date
#> 
#> ℹ Use `spec()` to retrieve the full column specification for this data.
#> ℹ Specify the column types or set `show_col_types = FALSE` to quiet this message.
schrute::theoffice
#> # A tibble: 55,130 × 12
#>    index season episode episode_name director   writer           character text 
#>    <int>  <int>   <int> <chr>        <chr>      <chr>            <chr>     <chr>
#>  1     1      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   All …
#>  2     2      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Jim       Oh, …
#>  3     3      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   So y…
#>  4     4      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Jim       Actu…
#>  5     5      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   All …
#>  6     6      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   Yes,…
#>  7     7      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   I've…
#>  8     8      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Pam       Well…
#>  9     9      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Michael   If y…
#> 10    10      1       1 Pilot        Ken Kwapis Ricky Gervais;S… Pam       What?
#> # ℹ 55,120 more rows
#> # ℹ 4 more variables: text_w_direction <chr>, imdb_rating <dbl>,
#> #   total_votes <int>, air_date <chr>

# We wish to join up the two data sets, but the season/episode numbering and
# the titles differ between them, so we clean up the episode names to join on
remove_regex <- "[:punct:]|[:digit:]|parts |part |the |and"
office_ratings <- ratings_raw |> 
    transmute(episode_name = str_to_lower(title),
              episode_name = str_remove_all(episode_name, remove_regex),
              episode_name = str_trim(episode_name),
              imdb_rating)

office_info <- schrute::theoffice |> 
    mutate(season = as.numeric(season),
           episode = as.numeric(episode),
           episode_name = str_to_lower(episode_name),
           episode_name = str_remove_all(episode_name, remove_regex),
           episode_name = str_trim(episode_name)) |> 
    select(season, episode, episode_name, director, writer, character)

# Let us count how many lines each character has per episode
characters <- office_info |> 
    count(episode_name, character) |> 
    add_count(character, wt = n, name = "character_count") |> 
    filter(character_count > 800) |> 
    select(-character_count) |> 
    pivot_wider(names_from = character, 
                values_from = n,
                values_fill = list(n = 0))

# We want to do the same for writers and directors
creators <- office_info |> 
    distinct(episode_name, director, writer) |> 
    pivot_longer(director:writer, names_to = "role", values_to = "person") |> 
    separate_rows(person, sep = ";") |> 
    add_count(person) |> 
    filter(n > 10) |> 
    distinct(episode_name, person) |> 
    mutate(person_value = 1) |> 
    pivot_wider(names_from = person,
                values_from = person_value,
                values_fill = list(person_value = 0))

office <- office_info |> 
    distinct(season, episode, episode_name) |> 
    inner_join(characters) |> 
    inner_join(creators) |> 
    inner_join(office_ratings) |> 
    janitor::clean_names()
#> Joining with `by = join_by(episode_name)`
#> Joining with `by = join_by(episode_name)`
#> Joining with `by = join_by(episode_name)`
#> Warning in inner_join(inner_join(inner_join(distinct(office_info, season, : Detected an unexpected many-to-many relationship between `x` and `y`.
#> ℹ Row 71 of `x` matches multiple rows in `y`.
#> ℹ Row 79 of `y` matches multiple rows in `x`.
#> ℹ If a many-to-many relationship is expected, set `relationship =
#>   "many-to-many"` to silence this warning.

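The many-to-many warning above means there are duplicated `episode_name` values on both sides of the last join, so each duplicate in one table is paired with each duplicate in the other and rows are multiplied. A minimal sketch on hypothetical toy data (made-up names and values, not the Office data), assuming dplyr 1.1.0 or later for the `relationship` argument, shows how `count()` can locate duplicated keys before joining:

```r
library(dplyr)

# Hypothetical toy tables (made-up values, not the Office data): the key "a"
# is duplicated on both sides, as can happen after names are normalised
episodes <- tibble(season = c(1, 1, 2),
                   episode = c(1, 2, 1),
                   episode_name = c("a", "a", "b"))
ratings <- tibble(episode_name = c("a", "a", "b"),
                  imdb_rating = c(1, 2, 3))

# Keys that occur more than once in each table
dup_x <- episodes |> count(episode_name) |> filter(n > 1)
dup_y <- ratings |> count(episode_name) |> filter(n > 1)

# Each "a" row in episodes pairs with each "a" row in ratings (2 x 2 = 4 rows),
# plus one row for "b"; relationship = "many-to-many" only silences the warning
joined <- inner_join(episodes, ratings, by = "episode_name",
                     relationship = "many-to-many")
nrow(joined)  # 5
```

Setting `relationship = "many-to-many"` does not remove the extra rows; the duplicated keys themselves would need to be resolved before the join.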

# Training the model
# We are going to use a lasso regression model
set.seed(1234)
office_split <- initial_split(office, strata = season)
office_train <- training(office_split)
office_test <- testing(office_split)

# Let us use a recipe
office_rec <- recipe(imdb_rating ~ ., data = office_train) |> 
    update_role(episode_name, new_role = "ID") |> # It is no longer a predictor
    step_zv(all_numeric(), -all_outcomes()) |> 
    step_normalize(all_numeric(), -all_outcomes())

office_prep <- office_rec |> 
    prep(strings_as_factors = FALSE)

# We can now train the model
lasso_spec <- linear_reg(penalty = 0.1, mixture = 1) |> 
    set_engine("glmnet")

wf <- workflow() |> 
    add_recipe(office_rec)

lasso_fit <- wf |> 
    add_model(lasso_spec) |> 
    fit(data = office_train)

lasso_fit |>
    extract_fit_parsnip() |>
    tidy()
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> 
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand, pack, unpack
#> 
#> Loaded glmnet 4.1-8
#> # A tibble: 31 × 3
#>    term        estimate penalty
#>    <chr>          <dbl>   <dbl>
#>  1 (Intercept)  8.37        0.1
#>  2 season       0           0.1
#>  3 episode      0           0.1
#>  4 andy         0           0.1
#>  5 angela       0.00234     0.1
#>  6 darryl       0           0.1
#>  7 dwight       0           0.1
#>  8 jim          0.00150     0.1
#>  9 kelly        0           0.1
#> 10 kevin        0           0.1
#> # ℹ 21 more rows
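One possible source of the difference (an assumption, not verified against the tutorial): in current versions of parsnip, `tidy()` on a glmnet model fit returns the coefficients at the single `penalty` value from the model specification, which is why every row in the output above has `penalty` equal to 0.1. The underlying glmnet object still stores the full regularisation path, and broom's `tidy()` method for glmnet returns it with `step`, `lambda` and `dev.ratio` columns. A self-contained sketch on simulated data (not the Office data), assuming tidymodels and glmnet are installed:

```r
library(tidymodels)  # attaches parsnip, broom, dplyr and tibble
library(glmnet)      # engine package used by parsnip below

# Simulated data (not the Office data)
set.seed(1234)
sim <- tibble(x1 = rnorm(100), x2 = rnorm(100), x3 = rnorm(100)) |>
  mutate(y = 2 * x1 - x2 + rnorm(100))

lasso_sim <- linear_reg(penalty = 0.1, mixture = 1) |>
  set_engine("glmnet") |>
  fit(y ~ ., data = sim)

# One row per term, evaluated at the penalty in the specification (0.1)
coefs_at_penalty <- tidy(lasso_sim)

# Coefficients at a different penalty value can be requested explicitly
coefs_smaller_penalty <- tidy(lasso_sim, penalty = 0.01)

# The glmnet object fitted by parsnip stores the whole regularisation path;
# broom's tidy() for glmnet returns one row per non-zero term per lambda step
full_path <- lasso_sim |>
  extract_fit_engine() |>
  tidy()

coefs_at_penalty
full_path
```

Comparing the shape of `full_path` with the table in the tutorial may help tell an API difference apart from a genuine difference in the fitted model.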

The output she gets in the tutorial is different:

[Image: coefficient output from Julia Silge's post]

Can someone please explain what is happening here? I don't know what to do to make this regression work.