Hyperparameter tuning and #TidyTuesday food consumption

By Julia Silge in rstats tidymodels

February 18, 2020

Last week I published a screencast demonstrating how to use the tidymodels framework and specifically the recipes package. Today, I’m using this week’s #TidyTuesday dataset on food consumption around the world to show hyperparameter tuning!


Here is the code I used in the video, for those who prefer reading instead of or in addition to video.

Explore the data

Our modeling goal here is to predict which countries are Asian countries and which countries are not, based on their patterns of food consumption in the eleven categories from the #TidyTuesday dataset. The original data is in a long, tidy format, and includes information on the carbon emission associated with each category of food consumption.

library(tidyverse)

food_consumption <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2020/2020-02-18/food_consumption.csv")

food_consumption
## # A tibble: 1,430 x 4
##    country   food_category            consumption co2_emmission
##    <chr>     <chr>                          <dbl>         <dbl>
##  1 Argentina Pork                           10.5          37.2
##  2 Argentina Poultry                        38.7          41.5
##  3 Argentina Beef                           55.5        1712
##  4 Argentina Lamb & Goat                     1.56         54.6
##  5 Argentina Fish                            4.36          6.96
##  6 Argentina Eggs                           11.4          10.5
##  7 Argentina Milk - inc. cheese            195.          278.
##  8 Argentina Wheat and Wheat Products      103.           19.7
##  9 Argentina Rice                            8.77         11.2
## 10 Argentina Soybeans                        0             0
## # … with 1,420 more rows

Let’s build a dataset for modeling that is wide instead of long using pivot_wider() from tidyr. We can use the countrycode package to find which continent each country is in, and create a new variable for prediction asia that tells us whether a country is in Asia or not.

library(countrycode)
library(janitor)

food <- food_consumption %>%
  select(-co2_emmission) %>%
  pivot_wider(
    names_from = food_category,
    values_from = consumption
  ) %>%
  clean_names() %>%
  mutate(continent = countrycode(
    country,
    origin = "country.name",
    destination = "continent"
  )) %>%
  mutate(asia = case_when(
    continent == "Asia" ~ "Asia",
    TRUE ~ "Other"
  )) %>%
  select(-country, -continent) %>%
  mutate_if(is.character, factor)

food
## # A tibble: 130 x 12
##     pork poultry  beef lamb_goat  fish  eggs milk_inc_cheese wheat_and_wheat…
##    <dbl>   <dbl> <dbl>     <dbl> <dbl> <dbl>           <dbl>            <dbl>
##  1  10.5    38.7  55.5      1.56  4.36 11.4             195.            103.
##  2  24.1    46.1  33.9      9.87 17.7   8.51            234.             70.5
##  3  10.9    13.2  22.5     15.3   3.85 12.5             304.            139.
##  4  21.7    26.9  13.4     21.1  74.4   8.24            226.             72.9
##  5  22.3    35.0  22.5     18.9  20.4   9.91            137.             76.9
##  6  27.6    50.0  36.2      0.43 12.4  14.6             255.             80.4
##  7  16.8    27.4  29.1      8.23  6.53 13.1             211.            109.
##  8  43.6    21.4  29.9      1.67 23.1  14.6             255.            103.
##  9  12.6    45    39.2      0.62 10.0   8.98            149.             53
## 10  10.4    18.4  23.4      9.56  5.21  8.29            288.             92.3
## # … with 120 more rows, and 4 more variables: rice <dbl>, soybeans <dbl>,
## #   nuts_inc_peanut_butter <dbl>, asia <fct>

This is not a big dataset, but it will be good for demonstrating how to tune hyperparameters. Before we get started on that, how are the categories of food consumption related? Since these are all numeric variables, we can use ggscatmat() for a quick visualization.

library(GGally)
ggscatmat(food, columns = 1:11, color = "asia", alpha = 0.7)

Notice how important rice is! Also see how the relationships between different food categories is different for Asian and non-Asian countries; a tree-based model like a random forest is good as learning interactions like this.

Tune hyperparameters

Now it’s time to tune the hyperparameters for a random forest model. First, let’s create a set of bootstrap resamples to use for tuning, and then let’s create a model specification for a random forest where we will tune mtry (the number of predictors to sample at each split) and min_n (the number of observations needed to keep splitting nodes). There are hyperparameters that can’t be learned from data when training the model.

library(tidymodels)

set.seed(1234)
food_boot <- bootstraps(food, times = 30)
food_boot
## # Bootstrap sampling
## # A tibble: 30 x 2
##    splits           id
##    <list>           <chr>
##  1 <split [130/48]> Bootstrap01
##  2 <split [130/49]> Bootstrap02
##  3 <split [130/49]> Bootstrap03
##  4 <split [130/51]> Bootstrap04
##  5 <split [130/47]> Bootstrap05
##  6 <split [130/51]> Bootstrap06
##  7 <split [130/57]> Bootstrap07
##  8 <split [130/51]> Bootstrap08
##  9 <split [130/44]> Bootstrap09
## 10 <split [130/53]> Bootstrap10
## # … with 20 more rows
rf_spec <- rand_forest(
  mode = "classification",
  mtry = tune(),
  trees = 1000,
  min_n = tune()
) %>%
  set_engine("ranger")

rf_spec
## Random Forest Model Specification (classification)
##
## Main Arguments:
##   mtry = tune()
##   trees = 1000
##   min_n = tune()
##
## Computational engine: ranger

We can’t learn the right values when training a single model, but we can train a whole bunch of models and see which ones turn out best. We can use parallel processing to make this go faster, since the different parts of the grid are independent.

doParallel::registerDoParallel()

rf_grid <- tune_grid(
  asia ~ .,
  model = rf_spec,
  resamples = food_boot
)

rf_grid
## # Bootstrap sampling
## # A tibble: 30 x 4
##    splits           id          .metrics          .notes
##    <list>           <chr>       <list>            <list>
##  1 <split [130/48]> Bootstrap01 <tibble [20 × 5]> <tibble [0 × 1]>
##  2 <split [130/49]> Bootstrap02 <tibble [20 × 5]> <tibble [0 × 1]>
##  3 <split [130/49]> Bootstrap03 <tibble [20 × 5]> <tibble [0 × 1]>
##  4 <split [130/51]> Bootstrap04 <tibble [20 × 5]> <tibble [0 × 1]>
##  5 <split [130/47]> Bootstrap05 <tibble [20 × 5]> <tibble [0 × 1]>
##  6 <split [130/51]> Bootstrap06 <tibble [20 × 5]> <tibble [0 × 1]>
##  7 <split [130/57]> Bootstrap07 <tibble [20 × 5]> <tibble [0 × 1]>
##  8 <split [130/51]> Bootstrap08 <tibble [20 × 5]> <tibble [0 × 1]>
##  9 <split [130/44]> Bootstrap09 <tibble [20 × 5]> <tibble [0 × 1]>
## 10 <split [130/53]> Bootstrap10 <tibble [20 × 5]> <tibble [0 × 1]>
## # … with 20 more rows

Once we have our tuning results, we can check them out.

rf_grid %>%
  collect_metrics()
## # A tibble: 20 x 7
##     mtry min_n .metric  .estimator  mean     n std_err
##    <int> <int> <chr>    <chr>      <dbl> <int>   <dbl>
##  1     2     4 accuracy binary     0.836    30 0.00801
##  2     2     4 roc_auc  binary     0.841    30 0.00942
##  3     2    12 accuracy binary     0.826    30 0.00779
##  4     2    12 roc_auc  binary     0.837    30 0.00902
##  5     4    33 accuracy binary     0.817    30 0.00895
##  6     4    33 roc_auc  binary     0.820    30 0.0105
##  7     4    37 accuracy binary     0.813    30 0.00851
##  8     4    37 roc_auc  binary     0.818    30 0.0102
##  9     5    31 accuracy binary     0.817    30 0.00875
## 10     5    31 roc_auc  binary     0.819    30 0.0103
## 11     6     9 accuracy binary     0.826    30 0.00954
## 12     6     9 roc_auc  binary     0.832    30 0.00956
## 13     7    21 accuracy binary     0.814    30 0.00926
## 14     7    21 roc_auc  binary     0.823    30 0.0103
## 15     8    18 accuracy binary     0.822    30 0.00887
## 16     8    18 roc_auc  binary     0.823    30 0.0104
## 17     9    26 accuracy binary     0.815    30 0.00992
## 18     9    26 roc_auc  binary     0.822    30 0.0109
## 19    11    15 accuracy binary     0.812    30 0.0114
## 20    11    15 roc_auc  binary     0.822    30 0.0102

And we can see which models performed the best, in terms of some given metric.

rf_grid %>%
  show_best("roc_auc")
## # A tibble: 5 x 7
##    mtry min_n .metric .estimator  mean     n std_err
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
## 1     2     4 roc_auc binary     0.841    30 0.00942
## 2     2    12 roc_auc binary     0.837    30 0.00902
## 3     6     9 roc_auc binary     0.832    30 0.00956
## 4     8    18 roc_auc binary     0.823    30 0.0104
## 5     7    21 roc_auc binary     0.823    30 0.0103

If you would like to specific the grid for tuning yourself, check out the dials package!

Posted on:
February 18, 2020
Length:
7 minute read, 1289 words
Categories:
rstats tidymodels
Tags:
rstats tidymodels
See Also:
Educational attainment in #TidyTuesday UK towns
Changes in #TidyTuesday US polling places
Empirical Bayes for #TidyTuesday Doctor Who episodes