Fit and predict with tidymodels for #TidyTuesday bird baths in Australia

By Julia Silge in rstats tidymodels

September 1, 2021

This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just getting started to tuning more complex models. Today’s screencast is good for folks who are newer to modeling or tidymodels; it focuses on how to use feature engineering together with a model algorithm and how to fit and predict, with this week’s #TidyTuesday dataset on bird baths in Australia. 🐦


Here is the code I used in the video, for those who prefer reading instead of or in addition to video.

Explore data

Our modeling goal is to predict whether we’ll see a bird at a bird bath in Australia, given info like what kind of bird we’re looking for and whether the bird bath is in an urban or rural location.

library(tidyverse)

bird_baths <- readr::read_csv("https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-08-31/bird_baths.csv")

bird_baths %>%
  count(urban_rural)
## # A tibble: 3 × 2
##   urban_rural      n
##   <chr>        <int>
## 1 Rural        49686
## 2 Urban       111202
## 3 <NA>           169

Notice that the dataset includes some summary rows with NA values for urban_rural, survey_year, and other columns. We can use those rows to choose the top bird types to focus on, instead of all of the many bird types in this dataset.

top_birds <-
  bird_baths %>%
  filter(is.na(urban_rural)) %>%
  slice_max(bird_count, n = 15) %>%
  pull(bird_type)

top_birds
##  [1] "Noisy Miner"        "Australian Magpie"  "Rainbow Lorikeet"  
##  [4] "Red Wattlebird"     "Superb Fairy-wren"  "Magpie-lark"       
##  [7] "Pied Currawong"     "Crimson Rosella"    "Eastern Spinebill" 
## [10] "Spotted Dove"       "Lewin's Honeyeater" "Satin Bowerbird"   
## [13] "Crested Pigeon"     "Grey Fantail"       "Red-browed Finch"

How likely were the citizen scientists who collected this data to see birds of different types, in different locations?

bird_parsed <-
  bird_baths %>%
  filter(
    !is.na(urban_rural),
    bird_type %in% top_birds
  ) %>%
  group_by(urban_rural, bird_type) %>%
  summarise(bird_count = mean(bird_count), .groups = "drop")

p1 <-
  bird_parsed %>%
  ggplot(aes(bird_count, bird_type)) +
  geom_segment(
    data = bird_parsed %>%
      pivot_wider(
        names_from = urban_rural,
        values_from = bird_count
      ),
    aes(x = Rural, xend = Urban, y = bird_type, yend = bird_type),
    alpha = 0.7, color = "gray70", size = 1.5
  ) +
  geom_point(aes(color = urban_rural), size = 3) +
  scale_x_continuous(labels = scales::percent) +
  labs(x = "Probability of seeing bird", y = NULL, color = NULL)

p1

Superb fairy-wrens are more likely to be seen at rural bird baths, while noisy miners are more likely to be seen at urban ones.
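The x-axis in that plot reads as a probability because the mean of a 0/1 variable is the proportion of 1s. Here is a minimal base-R sketch of the summarise() step above, assuming bird_count in the non-summary rows is a 0/1 presence indicator (which the percent scale implies); the `obs` data frame is made up for illustration:

```r
# Hypothetical survey records: bird_count is 1 if the bird was seen, 0 if not
obs <- data.frame(
  urban_rural = rep(c("Urban", "Rural"), each = 4),
  bird_type = rep(c("Noisy Miner", "Noisy Miner",
                    "Superb Fairy-wren", "Superb Fairy-wren"), times = 2),
  bird_count = c(1, 1, 0, 0, 0, 1, 1, 1)
)

# The mean of a 0/1 variable within each group is the share of 1s,
# i.e. the observed probability of seeing that bird in that location
props <- aggregate(bird_count ~ urban_rural + bird_type, data = obs, FUN = mean)
props
```

If bird_count could be larger than 1, this mean would be an average count rather than a probability, which is one reason the modeling data below recodes it to "bird" / "no bird".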

Let’s build a model to predict this probability of seeing a bird using just these two predictors.

bird_df <-
  bird_baths %>%
  filter(
    !is.na(urban_rural),
    bird_type %in% top_birds
  ) %>%
  mutate(bird_count = if_else(bird_count > 0, "bird", "no bird")) %>%
  mutate(across(where(is.character), as.factor))
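The level order of this new outcome matters later on. Converting characters to factors sorts the levels alphabetically, so "bird" becomes the first level, and yardstick metrics such as roc_auc() and roc_curve() treat the first level as the event by default (their event_level argument defaults to "first"). That is why we can pass .pred_bird to roc_curve() without extra arguments. A tiny base-R sketch of the same recoding, with hypothetical counts:

```r
# Hypothetical raw counts, including one count above 1
bird_count <- c(0, 1, 0, 3, 0)

# Same recoding as above: any sighting at all becomes "bird"
outcome <- as.factor(ifelse(bird_count > 0, "bird", "no bird"))

# Levels are sorted alphabetically, so "bird" is the first level
levels(outcome)
table(outcome)
```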

Build a first model

Let’s start our modeling by setting up our “data budget.” We are going to use a simple logistic regression model that is unlikely to overfit, but let’s still split our data into training and testing, and then create resampling folds.

library(tidymodels)

set.seed(123)
bird_split <- initial_split(bird_df, strata = bird_count)
bird_train <- training(bird_split)
bird_test <- testing(bird_split)

set.seed(234)
bird_folds <- vfold_cv(bird_train, strata = bird_count)
bird_folds
## #  10-fold cross-validation using stratification 
## # A tibble: 10 × 2
##    splits              id    
##    <list>              <chr> 
##  1 <split [9637/1072]> Fold01
##  2 <split [9638/1071]> Fold02
##  3 <split [9638/1071]> Fold03
##  4 <split [9638/1071]> Fold04
##  5 <split [9638/1071]> Fold05
##  6 <split [9638/1071]> Fold06
##  7 <split [9638/1071]> Fold07
##  8 <split [9638/1071]> Fold08
##  9 <split [9639/1070]> Fold09
## 10 <split [9639/1070]> Fold10
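Stratifying by bird_count keeps the share of "bird" and "no bird" observations about the same in every split. Here is a minimal base-R sketch of a stratified 75/25 split (initial_split() also defaults to a 3/4 training proportion) on a simulated, imbalanced outcome; the variable names are hypothetical, and rsample's real implementation handles more cases, such as binning numeric strata:

```r
# Hypothetical imbalanced outcome, roughly 18% "bird" and 82% "no bird"
set.seed(123)
y <- factor(sample(c("bird", "no bird"), 1000, replace = TRUE, prob = c(0.18, 0.82)))

# Stratified 75/25 split: sample 75% of the row indices within each class
train_idx <- unlist(lapply(split(seq_along(y), y), function(idx) {
  idx[sample.int(length(idx), size = floor(0.75 * length(idx)))]
}))

# Class proportions overall, in training, and in testing stay nearly identical
prop.table(table(y))
prop.table(table(y[train_idx]))
prop.table(table(y[-train_idx]))
```

Without stratification, a random split could by chance put noticeably more or fewer "bird" rows in the test set, which matters when the event is this rare.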

We’ll make a couple of attempts at fitting models here, but they will all use straightforward logistic regression.

glm_spec <- logistic_reg()

For this first model, let’s set up a feature engineering recipe with our outcome and two predictors, starting with just one preprocessing step: converting our nominal predictors (factors or characters, like urban_rural and bird_type) to dummy or indicator variables. Then let’s combine the preprocessing recipe with our model specification in a workflow.

rec_basic <-
  recipe(bird_count ~ urban_rural + bird_type, data = bird_train) %>%
  step_dummy(all_nominal_predictors())

wf_basic <- workflow(rec_basic, glm_spec)
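By default, step_dummy() uses reference-cell (treatment) coding: one 0/1 column for every level except the first. Base R's model.matrix() uses the same coding, so here is a small sketch of what the encoded predictors look like for a toy data frame (the rows are hypothetical, and step_dummy()'s column names differ from model.matrix()'s):

```r
# Toy data with the same two nominal predictors (values are illustrative)
toy <- data.frame(
  urban_rural = factor(c("Rural", "Urban", "Urban", "Rural")),
  bird_type = factor(c("Noisy Miner", "Noisy Miner", "Grey Fantail", "Crested Pigeon"))
)

# Treatment coding: "Rural" and "Crested Pigeon" (first levels) are the
# reference levels and get no column of their own
X <- model.matrix(~ urban_rural + bird_type, data = toy)
X
```

Dropping one level per predictor keeps the columns from being perfectly collinear with the intercept, which logistic regression via glm needs.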

We could fit this once to the training data, but to get better estimates of performance, let’s fit it 10 times, once to each of our 10 resampling folds.

doParallel::registerDoParallel()
ctrl_preds <- control_resamples(save_pred = TRUE)
rs_basic <- fit_resamples(wf_basic, bird_folds, control = ctrl_preds)
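Conceptually, fit_resamples() fits the workflow once per fold on that fold's analysis set and computes metrics on its assessment set. Here is a minimal base-R sketch of that loop for plain (unstratified) 10-fold cross-validation, using glm() directly instead of a workflow; the simulated data and probabilities are made up:

```r
# Hypothetical simulated data: a 0/1 outcome and one two-level predictor
set.seed(234)
n <- 500
dat <- data.frame(
  urban_rural = factor(sample(c("Rural", "Urban"), n, replace = TRUE))
)
dat$seen <- rbinom(n, 1, ifelse(dat$urban_rural == "Urban", 0.25, 0.12))

# Randomly assign each row to one of 10 folds (50 rows per fold)
k <- 10
fold_id <- sample(rep(seq_len(k), length.out = n))

# For each fold: fit on the other 9 folds, then score the held-out fold
fold_accuracy <- vapply(seq_len(k), function(i) {
  analysis_set <- dat[fold_id != i, ]
  assessment_set <- dat[fold_id == i, ]
  fit <- glm(seen ~ urban_rural, family = binomial, data = analysis_set)
  prob <- predict(fit, newdata = assessment_set, type = "response")
  mean((prob > 0.5) == (assessment_set$seen == 1))
}, numeric(1))

# Summaries roughly analogous to the mean and std_err columns of collect_metrics()
mean(fold_accuracy)
sd(fold_accuracy) / sqrt(k)
```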

How did this turn out? If we look at some overall metrics, accuracy does not look so bad:

collect_metrics(rs_basic)
## # A tibble: 2 × 6
##   .metric  .estimator  mean     n   std_err .config             
##   <chr>    <chr>      <dbl> <int>     <dbl> <chr>               
## 1 accuracy binary     0.822    10 0.0000762 Preprocessor1_Model1
## 2 roc_auc  binary     0.601    10 0.00783   Preprocessor1_Model1

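This is because there were not many birds overall, though! The model is just saying “no bird” everywhere, and since about 82% of the observations are “no bird”, that alone earns an accuracy of 0.822. The ROC AUC of 0.601 is only a little better than the 0.5 of random guessing, and the ROC curve doesn’t look so great either.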

augment(rs_basic) %>%
  roc_curve(bird_count, .pred_bird) %>%
  autoplot()
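To see why accuracy can mislead here, consider a base-R sketch with a hypothetical outcome that is 82.2% “no bird”. A model that always predicts “no bird” scores 0.822 accuracy, while a constant predicted probability gives a ROC AUC of exactly 0.5, because it cannot rank any “bird” observation above a “no bird” one. The roc_auc_rank() helper is a hypothetical name; it uses the rank (Mann-Whitney) form of the AUC, not yardstick's code:

```r
# Hypothetical outcome: 178 "bird" and 822 "no bird" observations
truth <- factor(rep(c("bird", "no bird"), times = c(178, 822)))

# A useless "model" that predicts "no bird" every time
pred_class <- factor(rep("no bird", length(truth)), levels = levels(truth))
acc_constant <- mean(pred_class == truth)
acc_constant  # 0.822

# ROC AUC via ranks: the probability that a random event outscores a random non-event
roc_auc_rank <- function(truth, score, event) {
  pos <- truth == event
  r <- rank(score)  # tied scores get averaged ranks
  n_pos <- sum(pos)
  n_neg <- sum(!pos)
  (sum(r[pos]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

# Constant scores rank nothing, so AUC is exactly 0.5
roc_auc_rank(truth, rep(0.18, length(truth)), event = "bird")  # 0.5

# Scores that perfectly separate the classes give AUC 1
roc_auc_rank(truth, as.numeric(truth == "bird"), event = "bird")  # 1
```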

Add interactions

We know from the plot we made during EDA that there is an interaction between whether a bird bath is urban or rural and what kinds of birds we see there: the urban/rural difference goes in different directions for different bird types. We could model this interaction either with a model type that can handle interactions natively (like trees) or with explicit interaction terms, like this:

rec_interact <-
  rec_basic %>%
  step_interact(~ starts_with("urban_rural"):starts_with("bird_type"))

wf_interact <- workflow(rec_interact, glm_spec)
rs_interact <- fit_resamples(wf_interact, bird_folds, control = ctrl_preds)
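An interaction term between two dummy variables is just their elementwise product, which lets the urban/rural effect on the log-odds differ by bird type. Here is a minimal base-R sketch with model.matrix() on a toy data frame (hypothetical rows; step_interact() builds the same kind of product columns but names them with a "_x_" separator):

```r
toy <- data.frame(
  urban_rural = factor(c("Rural", "Urban", "Urban", "Rural")),
  bird_type = factor(c("Noisy Miner", "Noisy Miner", "Grey Fantail", "Grey Fantail"))
)

# Main effects plus interaction; reference levels are "Rural" and "Grey Fantail".
# log-odds = b0 + b1 * Urban + b2 * NoisyMiner + b3 * Urban * NoisyMiner,
# so the urban effect is b1 for Grey Fantail but b1 + b3 for Noisy Miner
X <- model.matrix(~ urban_rural + bird_type + urban_rural:bird_type, data = toy)
X

# The interaction column equals the product of the two dummy columns
all(X[, "urban_ruralUrban:bird_typeNoisy Miner"] ==
      X[, "urban_ruralUrban"] * X[, "bird_typeNoisy Miner"])
```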

How did our same logistic regression model specification do, now with interactions?

collect_metrics(rs_interact)
## # A tibble: 2 × 6
##   .metric  .estimator  mean     n   std_err .config             
##   <chr>    <chr>      <dbl> <int>     <dbl> <chr>               
## 1 accuracy binary     0.822    10 0.0000762 Preprocessor1_Model1
## 2 roc_auc  binary     0.669    10 0.00660   Preprocessor1_Model1

The accuracy is unchanged (the model is still predicting “no bird” every time), but the probabilities look better: ROC AUC improved from 0.601 to 0.669.

augment(rs_interact) %>%
  roc_curve(bird_count, .pred_bird) %>%
  autoplot()

Evaluate model on new data

Let’s stick with this model, logistic regression together with interactions between urban/rural and bird type. We can fit the model one time to the entire training set.

bird_fit <- fit(wf_interact, bird_train)

Now this trained model is ready to be applied to new data. For example, we can predict on the test set, here to get class probabilities.

predict(bird_fit, bird_test, type = "prob")
## # A tibble: 3,571 × 2
##    .pred_bird `.pred_no bird`
##         <dbl>           <dbl>
##  1     0.213            0.787
##  2     0.123            0.877
##  3     0.141            0.859
##  4     0.283            0.717
##  5     0.119            0.881
##  6     0.252            0.748
##  7     0.0380           0.962
##  8     0.123            0.877
##  9     0.129            0.871
## 10     0.119            0.881
## # … with 3,561 more rows
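Class predictions for a two-class logistic regression come from these probabilities with a 50% cutoff, which is why a model whose “bird” probabilities all sit below 0.5 predicts “no bird” every time. Here is a base-R sketch of the same probability-then-class step with glm() on simulated data; the names and true probabilities are hypothetical:

```r
# Hypothetical simulated data with low true probabilities of a sighting
set.seed(42)
n <- 1000
dat <- data.frame(
  urban_rural = factor(sample(c("Rural", "Urban"), n, replace = TRUE))
)
dat$seen <- rbinom(n, 1, ifelse(dat$urban_rural == "Urban", 0.28, 0.12))

fit <- glm(seen ~ urban_rural, family = binomial, data = dat)

# New data must contain the predictor, with the same factor levels
new_sites <- data.frame(
  urban_rural = factor(c("Rural", "Urban"), levels = levels(dat$urban_rural))
)

# Step 1: predicted probabilities of a sighting
prob_seen <- predict(fit, newdata = new_sites, type = "response")

# Step 2: class predictions from a 50% cutoff
pred_class <- ifelse(prob_seen >= 0.5, "bird", "no bird")
data.frame(new_sites, prob_seen, pred_class)
```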

In fact, we can predict on any kind of new data that has the right input variables. Let’s make some ourselves.

new_bird_data <-
  tibble(bird_type = top_birds) %>%
  crossing(urban_rural = c("Urban", "Rural"))

new_bird_data
## # A tibble: 30 × 2
##    bird_type         urban_rural
##    <chr>             <chr>      
##  1 Australian Magpie Rural      
##  2 Australian Magpie Urban      
##  3 Crested Pigeon    Rural      
##  4 Crested Pigeon    Urban      
##  5 Crimson Rosella   Rural      
##  6 Crimson Rosella   Urban      
##  7 Eastern Spinebill Rural      
##  8 Eastern Spinebill Urban      
##  9 Grey Fantail      Rural      
## 10 Grey Fantail      Urban      
## # … with 20 more rows

We can use a helpful function like augment() to take this new data and “augment” it with predicted probabilities and class predictions, and we can use predict() with a specific type argument (here, type = "conf_int") to return specialized predictions like confidence intervals. Let’s bind these together.

bird_preds <-
  augment(bird_fit, new_bird_data) %>%
  bind_cols(
    predict(bird_fit, new_bird_data, type = "conf_int")
  )

bird_preds
## # A tibble: 30 × 9
##    bird_type urban_rural .pred_class .pred_bird `.pred_no bird` .pred_lower_bird
##    <chr>     <chr>       <fct>            <dbl>           <dbl>            <dbl>
##  1 Australi… Rural       no bird         0.245            0.755           0.193 
##  2 Australi… Urban       no bird         0.287            0.713           0.249 
##  3 Crested … Rural       no bird         0.0826           0.917           0.0526
##  4 Crested … Urban       no bird         0.141            0.859           0.113 
##  5 Crimson … Rural       no bird         0.215            0.785           0.166 
##  6 Crimson … Urban       no bird         0.123            0.877           0.0969
##  7 Eastern … Rural       no bird         0.283            0.717           0.227 
##  8 Eastern … Urban       no bird         0.0973           0.903           0.0736
##  9 Grey Fan… Rural       no bird         0.254            0.746           0.200 
## 10 Grey Fan… Urban       no bird         0.0614           0.939           0.0435
## # … with 20 more rows, and 3 more variables: .pred_upper_bird <dbl>,
## #   .pred_lower_no bird <dbl>, .pred_upper_no bird <dbl>
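One standard way to get confidence intervals for predicted probabilities from a logistic regression is to build the interval on the log-odds (link) scale and then map both ends through the inverse logit, which keeps them between 0 and 1. Here is a base-R sketch of that approach with glm() on simulated data; the column names mimic the output above but are hypothetical, and this is not parsnip's own code:

```r
# Hypothetical simulated data
set.seed(42)
n <- 1000
dat <- data.frame(
  urban_rural = factor(sample(c("Rural", "Urban"), n, replace = TRUE))
)
dat$seen <- rbinom(n, 1, ifelse(dat$urban_rural == "Urban", 0.28, 0.12))
fit <- glm(seen ~ urban_rural, family = binomial, data = dat)

new_sites <- data.frame(
  urban_rural = factor(c("Rural", "Urban"), levels = levels(dat$urban_rural))
)

# Predict on the link (log-odds) scale, with standard errors
lp <- predict(fit, newdata = new_sites, type = "link", se.fit = TRUE)
z <- qnorm(0.975)  # for a 95% interval

# Build the interval on the log-odds scale, then back-transform with plogis()
ci <- data.frame(
  urban_rural = new_sites$urban_rural,
  .pred_seen = plogis(lp$fit),
  .pred_lower = plogis(lp$fit - z * lp$se.fit),
  .pred_upper = plogis(lp$fit + z * lp$se.fit)
)
ci
```

The back-transformed intervals are asymmetric around the point estimate, which is expected for probabilities near 0 or 1.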

Now let’s visualize these predictions.

p2 <-
  bird_preds %>%
  ggplot(aes(.pred_bird, bird_type, color = urban_rural)) +
  geom_errorbar(
    aes(xmin = .pred_lower_bird, xmax = .pred_upper_bird),
    width = 0.2, size = 1.2, alpha = 0.5
  ) +
  geom_point(size = 2.5) +
  scale_x_continuous(labels = scales::percent) +
  labs(x = "Predicted probability of seeing bird", y = NULL, color = NULL)

p2

Actually, let’s put this together with our earlier plot!

library(patchwork)

p1 + p2
