Class imbalance and classification metrics with aircraft wildlife strikes
By Julia Silge in rstats tidymodels
June 21, 2021
This is the latest in my series of screencasts demonstrating how to use the tidymodels packages, from just starting out to tuning more complex models with many hyperparameters. I recently participated in SLICED, a competitive data science prediction challenge. I did not necessarily cover myself in glory but in today’s screencast, I walk through the data set on aircraft wildlife strikes we used and how different choices around handling class imbalance affect different classification metrics. ✈️
Here is the code I used in the video, for those who prefer reading instead of or in addition to video.
Explore data
Our modeling goal is to predict whether an
aircraft strike with wildlife resulted in damage to the aircraft. There are two data sets provided, training (which has the label damaged
) and testing (which does not).
library(tidyverse)
train_raw <- read_csv("train.csv", guess_max = 1e5) %>%
mutate(damaged = case_when(
damaged > 0 ~ "damage",
TRUE ~ "no damage"
))
test_raw <- read_csv("test.csv", guess_max = 1e5)
There is lots available in the data!
skimr::skim(train_raw)
Name | train_raw |
Number of rows | 21000 |
Number of columns | 34 |
_______________________ | |
Column type frequency: | |
character | 20 |
numeric | 14 |
________________________ | |
Group variables | None |
Table 1: Data summary
Variable type: character
skim_variable | n_missing | complete_rate | min | max | empty | n_unique | whitespace |
---|---|---|---|---|---|---|---|
operator_id | 0 | 1.00 | 3 | 5 | 0 | 276 | 0 |
operator | 0 | 1.00 | 3 | 33 | 0 | 275 | 0 |
aircraft | 0 | 1.00 | 3 | 20 | 0 | 424 | 0 |
aircraft_type | 4992 | 0.76 | 1 | 1 | 0 | 2 | 0 |
aircraft_make | 5231 | 0.75 | 2 | 3 | 0 | 62 | 0 |
engine_model | 6334 | 0.70 | 1 | 2 | 0 | 39 | 0 |
engine_type | 5703 | 0.73 | 1 | 3 | 0 | 8 | 0 |
engine3_position | 19671 | 0.06 | 1 | 11 | 0 | 4 | 0 |
airport_id | 0 | 1.00 | 3 | 5 | 0 | 1039 | 0 |
airport | 34 | 1.00 | 4 | 53 | 0 | 1038 | 0 |
state | 2664 | 0.87 | 2 | 2 | 0 | 60 | 0 |
faa_region | 2266 | 0.89 | 3 | 3 | 0 | 14 | 0 |
flight_phase | 6728 | 0.68 | 4 | 12 | 0 | 12 | 0 |
visibility | 7699 | 0.63 | 3 | 7 | 0 | 5 | 0 |
precipitation | 10327 | 0.51 | 3 | 15 | 0 | 8 | 0 |
species_id | 0 | 1.00 | 1 | 6 | 0 | 447 | 0 |
species_name | 7 | 1.00 | 4 | 50 | 0 | 445 | 0 |
species_quantity | 532 | 0.97 | 1 | 8 | 0 | 4 | 0 |
flight_impact | 8944 | 0.57 | 4 | 21 | 0 | 6 | 0 |
damaged | 0 | 1.00 | 6 | 9 | 0 | 2 | 0 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
id | 0 | 1.00 | 14980.94 | 8663.24 | 1 | 7458.75 | 14978.5 | 22472.25 | 30000 | ▇▇▇▇▇ |
incident_year | 0 | 1.00 | 2006.06 | 6.72 | 1990 | 2001.00 | 2007.0 | 2012.00 | 2015 | ▂▃▅▆▇ |
incident_month | 0 | 1.00 | 7.19 | 2.79 | 1 | 5.00 | 8.0 | 9.00 | 12 | ▃▅▆▇▆ |
incident_day | 0 | 1.00 | 15.63 | 8.82 | 1 | 8.00 | 15.0 | 23.00 | 31 | ▇▇▇▇▆ |
aircraft_model | 6259 | 0.70 | 24.65 | 21.70 | 0 | 10.00 | 22.0 | 37.00 | 98 | ▇▆▂▁▁ |
aircraft_mass | 5694 | 0.73 | 3.50 | 0.89 | 1 | 3.00 | 4.0 | 4.00 | 5 | ▁▁▂▇▁ |
engine_make | 6155 | 0.71 | 21.22 | 11.04 | 1 | 10.00 | 22.0 | 34.00 | 47 | ▇▂▆▇▁ |
engines | 5696 | 0.73 | 2.05 | 0.46 | 1 | 2.00 | 2.0 | 2.00 | 4 | ▁▇▁▁▁ |
engine1_position | 5838 | 0.72 | 2.99 | 2.09 | 1 | 1.00 | 1.0 | 5.00 | 7 | ▇▁▂▅▁ |
engine2_position | 6776 | 0.68 | 2.91 | 2.01 | 1 | 1.00 | 1.0 | 5.00 | 7 | ▇▁▂▅▁ |
engine4_position | 20650 | 0.02 | 2.02 | 1.43 | 1 | 1.00 | 1.0 | 4.00 | 5 | ▇▁▁▃▁ |
height | 8469 | 0.60 | 819.24 | 1772.53 | 0 | 0.00 | 50.0 | 800.00 | 24000 | ▇▁▁▁▁ |
speed | 12358 | 0.41 | 141.39 | 52.25 | 0 | 120.00 | 137.0 | 160.00 | 2500 | ▇▁▁▁▁ |
distance | 8913 | 0.58 | 0.66 | 3.33 | 0 | 0.00 | 0.0 | 0.00 | 100 | ▇▁▁▁▁ |
The data is imbalanced, with not many incidents resulting in damage.
train_raw %>%
count(damaged)
## # A tibble: 2 x 2
## damaged n
## <chr> <int>
## 1 damage 1799
## 2 no damage 19201
For numeric predictors, I often like to make a pairs plot for EDA.
library(GGally)
train_raw %>%
select(damaged, incident_year, height, speed, distance) %>%
ggpairs(columns = 2:5, aes(color = damaged, alpha = 0.5))
For categorical predictors, plots like these can be useful. Notice especially that NA
values look like they may be informative so we likely don’t want to throw them out.
train_raw %>%
select(
damaged, precipitation, visibility, engine_type,
flight_impact, flight_phase, species_quantity
) %>%
pivot_longer(precipitation:species_quantity) %>%
ggplot(aes(y = value, fill = damaged)) +
geom_bar(position = "fill") +
facet_wrap(vars(name), scales = "free", ncol = 2) +
labs(x = NULL, y = NULL, fill = NULL)
Let’s use the following variables for this post.
bird_df <- train_raw %>%
select(
damaged, flight_impact, precipitation,
visibility, flight_phase, engines, incident_year,
incident_month, species_id, engine_type,
aircraft_model, species_quantity, height, speed
)
Build a model
If I had enough time to try many models, I would
split the provided training data via initial_split()
, but I learned that two hours isn’t really enough time for me to try that many models. Let’s just create resampling folds from the provided training data.
library(tidymodels)
set.seed(123)
bird_folds <- vfold_cv(train_raw, v = 5, strata = damaged)
bird_folds
## # 5-fold cross-validation using stratification
## # A tibble: 5 x 2
## splits id
## <list> <chr>
## 1 <split [16800/4200]> Fold1
## 2 <split [16800/4200]> Fold2
## 3 <split [16800/4200]> Fold3
## 4 <split [16800/4200]> Fold4
## 5 <split [16800/4200]> Fold5
The SLICED prediction problem was evaluate on a single metric, log loss, so let’s create a metric set for that metric plus a few others for demonstration purposes.
bird_metrics <- metric_set(mn_log_loss, accuracy, sensitivity, specificity)
This data requires lots of preprocessing, such as handling new levels in the test set, pooling infrequent factor levels, and imputing or replacing the NA
values.
bird_rec <- recipe(damaged ~ ., data = bird_df) %>%
step_novel(all_nominal_predictors()) %>%
step_other(all_nominal_predictors(), threshold = 0.01) %>%
step_unknown(all_nominal_predictors()) %>%
step_impute_median(all_numeric_predictors()) %>%
step_zv(all_predictors())
bird_rec
## Data Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 13
##
## Operations:
##
## Novel factor level assignment for all_nominal_predictors()
## Collapsing factor levels for all_nominal_predictors()
## Unknown factor level assignment for all_nominal_predictors()
## Median Imputation for all_numeric_predictors()
## Zero variance filter on all_predictors()
For this post, let’s use a model I didn’t try out during the stream, a bagged tree model. It’s similar to the kinds of models that perform well in SLICED-like situations but it is easy to set up and very fast to fit.
library(baguette)
bag_spec <-
bag_tree(min_n = 10) %>%
set_engine("rpart", times = 25) %>%
set_mode("classification")
imb_wf <-
workflow() %>%
add_recipe(bird_rec) %>%
add_model(bag_spec)
imb_fit <- fit(imb_wf, data = bird_df)
imb_fit
## ══ Workflow [trained] ══════════════════════════════════════════════════════════
## Preprocessor: Recipe
## Model: bag_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## 5 Recipe Steps
##
## • step_novel()
## • step_other()
## • step_unknown()
## • step_impute_median()
## • step_zv()
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Bagged CART (classification with 25 members)
##
## Variable importance scores include:
##
## # A tibble: 13 x 4
## term value std.error used
## <chr> <dbl> <dbl> <int>
## 1 flight_impact 480. 6.81 25
## 2 aircraft_model 363. 4.97 25
## 3 incident_year 354. 5.51 25
## 4 species_id 337. 4.62 25
## 5 height 332. 5.45 25
## 6 speed 297. 4.82 25
## 7 incident_month 285. 6.18 25
## 8 flight_phase 246. 4.41 25
## 9 engine_type 213. 3.31 25
## 10 visibility 196. 3.82 25
## 11 precipitation 136. 3.23 25
## 12 engines 117. 2.67 25
## 13 species_quantity 83.7 3.12 25
We automatically get out some variable importance too, which is nice! We see that flight_impact
and aircraft_model
are very important for this model.
Resample and compare models
Now let’s evaluate how this model performs using resampling.
doParallel::registerDoParallel()
set.seed(123)
imb_rs <-
fit_resamples(
imb_wf,
resamples = bird_folds,
metrics = bird_metrics
)
collect_metrics(imb_rs)
## # A tibble: 4 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.925 5 0.00221 Preprocessor1_Model1
## 2 mn_log_loss binary 0.212 5 0.00511 Preprocessor1_Model1
## 3 sens binary 0.278 5 0.00941 Preprocessor1_Model1
## 4 spec binary 0.986 5 0.000843 Preprocessor1_Model1
This is quite good compared to how other folks did with this data, especially for such a simple model. We could take this as a starting point and move to a similar but better performing model like xgboost.
What happens, though, if we change the preprocessing recipe to account for the class imbalance?
library(themis)
bal_rec <- bird_rec %>%
step_dummy(all_nominal_predictors()) %>%
step_smote(damaged)
bal_wf <-
workflow() %>%
add_recipe(bal_rec) %>%
add_model(bag_spec)
set.seed(234)
bal_rs <-
fit_resamples(
bal_wf,
resamples = bird_folds,
metrics = bird_metrics
)
collect_metrics(bal_rs)
## # A tibble: 4 x 6
## .metric .estimator mean n std_err .config
## <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 accuracy binary 0.919 5 0.00215 Preprocessor1_Model1
## 2 mn_log_loss binary 0.224 5 0.00559 Preprocessor1_Model1
## 3 sens binary 0.322 5 0.00967 Preprocessor1_Model1
## 4 spec binary 0.975 5 0.00103 Preprocessor1_Model1
Notice that the log loss and accuracy got worse, while the sensitivity got better. This is very common and expected, and frankly I wish I hadn’t been so laser focused on needing to get subsampling to work during the SLICED stream! In most real-world situations, a single metric is not adequate to measure how useful a model will be practically, and also unfortunately we often are most interested in detecting the minority class. This means that learning how to account for class imbalance is important in many real modeling scenarios. However, if you are ever in a situation where you are being evaluated on a single metric like log loss, you may want to stick with an imbalanced fit.
test_df <- test_raw %>%
select(
id, flight_impact, precipitation,
visibility, flight_phase, engines, incident_year,
incident_month, species_id, engine_type,
aircraft_model, species_quantity, height, speed
)
augment(imb_fit, test_df) %>%
select(id, .pred_damage)
## # A tibble: 9,000 x 2
## id .pred_damage
## <dbl> <dbl>
## 1 11254 0.346
## 2 27716 0.00606
## 3 29066 0.000544
## 4 3373 0.0406
## 5 1996 0.153
## 6 18061 0.000654
## 7 22237 0.00489
## 8 25346 0.274
## 9 21554 0.348
## 10 4273 0.00390
## # … with 8,990 more rows
- Posted on:
- June 21, 2021
- Length:
- 8 minute read, 1598 words
- Categories:
- rstats tidymodels
- Tags:
- rstats tidymodels