Partial dependence plots with tidymodels and DALEX for #TidyTuesday Mario Kart world records

By Julia Silge in rstats tidymodels

May 28, 2021

This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just starting out to tuning more complex models with many hyperparameters. Today’s screencast walks through how to train and evalute a random forest model, with this week’s #TidyTuesday dataset on Mario Kart world records. 🍄


Here is the code I used in the video, for those who prefer reading instead of or in addition to video.

Explore data

Our modeling goal is to predict whether a Mario Kart world record was achieved using a shortcut or not.

library(tidyverse)

records <- read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-05-25/records.csv")

How are the world records distributed over time, for the various tracks?

records %>%
  ggplot(aes(date, time, color = track)) +
  geom_point(alpha = 0.5, show.legend = FALSE) +
  facet_grid(rows = vars(type), cols = vars(shortcut), scales = "free_y")

The record times decreased at first but then have been more stable. The record times are different for the different tracks, and for three lap vs. one lap times.

Build a model

Let’s start our modeling by setting up our “data budget.”

library(tidymodels)

set.seed(123)
mario_split <- records %>%
  select(shortcut, track, type, date, time) %>%
  mutate_if(is.character, factor) %>%
  initial_split(strata = shortcut)

mario_train <- training(mario_split)
mario_test <- testing(mario_split)

set.seed(234)
mario_folds <- bootstraps(mario_train, strata = shortcut)
mario_folds
## # Bootstrap sampling using stratification 
## # A tibble: 25 x 2
##    splits             id         
##    <list>             <chr>      
##  1 <split [1750/627]> Bootstrap01
##  2 <split [1750/639]> Bootstrap02
##  3 <split [1750/652]> Bootstrap03
##  4 <split [1750/644]> Bootstrap04
##  5 <split [1750/648]> Bootstrap05
##  6 <split [1750/670]> Bootstrap06
##  7 <split [1750/648]> Bootstrap07
##  8 <split [1750/660]> Bootstrap08
##  9 <split [1750/645]> Bootstrap09
## 10 <split [1750/629]> Bootstrap10
## # … with 15 more rows

For this analysis, I am tuning a decision tree model. Tree-based models are very low-maintenance when it comes to data preprocessing, but single decision trees can be pretty easy to overfit.

tree_spec <- decision_tree(
  cost_complexity = tune(),
  tree_depth = tune()
) %>%
  set_engine("rpart") %>%
  set_mode("classification")

tree_grid <- grid_regular(cost_complexity(), tree_depth(), levels = 7)

mario_wf <- workflow() %>%
  add_model(tree_spec) %>%
  add_formula(shortcut ~ .)

mario_wf
## ══ Workflow ════════════════════════════════════════════════════════════════════
## Preprocessor: Formula
## Model: decision_tree()
## 
## ── Preprocessor ────────────────────────────────────────────────────────────────
## shortcut ~ .
## 
## ── Model ───────────────────────────────────────────────────────────────────────
## Decision Tree Model Specification (classification)
## 
## Main Arguments:
##   cost_complexity = tune()
##   tree_depth = tune()
## 
## Computational engine: rpart

Let’s tune the tree parameters to find the best decision tree for this Mario Kart data set.

doParallel::registerDoParallel()

tree_res <- tune_grid(
  mario_wf,
  resamples = mario_folds,
  grid = tree_grid,
  control = control_grid(save_pred = TRUE)
)

tree_res
## # Tuning results
## # Bootstrap sampling using stratification 
## # A tibble: 25 x 5
##    splits           id         .metrics        .notes        .predictions       
##    <list>           <chr>      <list>          <list>        <list>             
##  1 <split [1750/62… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [30,723 × …
##  2 <split [1750/63… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,311 × …
##  3 <split [1750/65… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,948 × …
##  4 <split [1750/64… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,556 × …
##  5 <split [1750/64… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,752 × …
##  6 <split [1750/67… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [32,830 × …
##  7 <split [1750/64… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,752 × …
##  8 <split [1750/66… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [32,340 × …
##  9 <split [1750/64… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [31,605 × …
## 10 <split [1750/62… Bootstrap… <tibble [98 × … <tibble [0 ×… <tibble [30,821 × …
## # … with 15 more rows

All done! We tried all the possible combinations of tree parameters for each resample.

Choose and evaluate final model

Now we can explore our tuning results.

collect_metrics(tree_res)
## # A tibble: 98 x 8
##    cost_complexity tree_depth .metric  .estimator  mean     n std_err .config   
##              <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>     
##  1   0.0000000001           1 accuracy binary     0.637    25 0.00371 Preproces…
##  2   0.0000000001           1 roc_auc  binary     0.637    25 0.0109  Preproces…
##  3   0.00000000316          1 accuracy binary     0.637    25 0.00371 Preproces…
##  4   0.00000000316          1 roc_auc  binary     0.637    25 0.0109  Preproces…
##  5   0.0000001              1 accuracy binary     0.637    25 0.00371 Preproces…
##  6   0.0000001              1 roc_auc  binary     0.637    25 0.0109  Preproces…
##  7   0.00000316             1 accuracy binary     0.637    25 0.00371 Preproces…
##  8   0.00000316             1 roc_auc  binary     0.637    25 0.0109  Preproces…
##  9   0.0001                 1 accuracy binary     0.637    25 0.00371 Preproces…
## 10   0.0001                 1 roc_auc  binary     0.637    25 0.0109  Preproces…
## # … with 88 more rows
show_best(tree_res, metric = "accuracy")
## # A tibble: 5 x 8
##   cost_complexity tree_depth .metric  .estimator  mean     n std_err .config    
##             <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>      
## 1   0.00316                8 accuracy binary     0.738    25 0.00248 Preprocess…
## 2   0.0000000001           8 accuracy binary     0.736    25 0.00249 Preprocess…
## 3   0.00000000316          8 accuracy binary     0.736    25 0.00249 Preprocess…
## 4   0.0000001              8 accuracy binary     0.736    25 0.00249 Preprocess…
## 5   0.00000316             8 accuracy binary     0.736    25 0.00249 Preprocess…
autoplot(tree_res)

Looks like a tree depth of 8 is best. How do the ROC curves look for the resampled training set?

collect_predictions(tree_res) %>%
  group_by(id) %>%
  roc_curve(shortcut, .pred_No) %>%
  autoplot() +
  theme(legend.position = "none")

Let’s choose the tree parameters we want to use, finalize our (tuneable) workflow with this choice, and then fit one last time to the training data and evaluate on the testing data. This is the first time we have used the test set.

choose_tree <- select_best(tree_res, metric = "accuracy")

final_res <- mario_wf %>%
  finalize_workflow(choose_tree) %>%
  last_fit(mario_split)

collect_metrics(final_res)
## # A tibble: 2 x 4
##   .metric  .estimator .estimate .config             
##   <chr>    <chr>          <dbl> <chr>               
## 1 accuracy binary         0.721 Preprocessor1_Model1
## 2 roc_auc  binary         0.847 Preprocessor1_Model1

One of the objects contained in final_res is a fitted workflow that we can save for future use or deployment (perhaps via readr::write_rds()) and use for prediction on new data.

final_fitted <- final_res$.workflow[[1]]
predict(final_fitted, mario_test[10:12, ])
## # A tibble: 3 x 1
##   .pred_class
##   <fct>      
## 1 No         
## 2 No         
## 3 Yes

We can use this fitted workflow to explore model explainability as well. Decision trees are pretty explainable already, but we might, for example, want to see a partial dependence plot for the shortcut probability and time. I like using the DALEX package for tasks like this, because it is very fully featured and has good support for tidymodels. To use DALEX with tidymodels, first you create an explainer and then you use that explainer for the task you want, like computing a PDP or Shapley explanations.

Let’s start by creating our “explainer.”

library(DALEXtra)

mario_explainer <- explain_tidymodels(
  final_fitted,
  data = dplyr::select(mario_train, -shortcut),
  y = as.integer(mario_train$shortcut),
  verbose = FALSE
)

Then let’s compute a partial dependence profile for time, grouped by type, which is three laps vs. one lap.

pdp_time <- model_profile(
  mario_explainer,
  variables = "time",
  N = NULL,
  groups = "type"
)

You can use the default plotting from DALEX by calling plot(pdp_time), but if you like to customize your plots, you can access the underlying data via pdp_time$agr_profiles and pdp_time$cp_profiles.

as_tibble(pdp_time$agr_profiles) %>%
  mutate(`_label_` = str_remove(`_label_`, "workflow_")) %>%
  ggplot(aes(`_x_`, `_yhat_`, color = `_label_`)) +
  geom_line(size = 1.2, alpha = 0.8) +
  labs(
    x = "Time to complete track",
    y = "Predicted probability of shortcut",
    color = NULL,
    title = "Partial dependence plot for Mario Kart world records",
    subtitle = "Predictions from a decision tree model"
  )

The shapes that we see here reflect how the decision tree model makes decisions along the time variable.

Posted on:
May 28, 2021
Length:
7 minute read, 1280 words
Categories:
rstats tidymodels
Tags:
rstats tidymodels
See Also:
Educational attainment in #TidyTuesday UK towns
Changes in #TidyTuesday US polling places
Empirical Bayes for #TidyTuesday Doctor Who episodes