Tune an xgboost model with early stopping and #TidyTuesday childcare costs
By Julia Silge in rstats tidymodels
May 11, 2023
This is the latest in my series of
screencasts! This screencast focuses on how to use tidymodels to tune an xgboost model with early stopping, using this week's #TidyTuesday dataset on childcare costs in the United States. 👩‍👧‍👦
Here is the code I used in the video, for those who prefer reading instead of or in addition to video.
Explore data
Mother's Day is coming up this weekend, and our modeling goal in this case is to predict the cost of childcare in US counties based on other characteristics of each county, like the poverty rate and labor force participation. Let's start by reading in the data:
library(tidyverse)
childcare_costs <- read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2023/2023-05-09/childcare_costs.csv')
glimpse(childcare_costs)
## Rows: 34,567
## Columns: 61
## $ county_fips_code <dbl> 1001, 1001, 1001, 1001, 1001, 1001, 1001, 10…
## $ study_year <dbl> 2008, 2009, 2010, 2011, 2012, 2013, 2014, 20…
## $ unr_16 <dbl> 5.42, 5.93, 6.21, 7.55, 8.60, 9.39, 8.50, 7.…
## $ funr_16 <dbl> 4.41, 5.72, 5.57, 8.13, 8.88, 10.31, 9.18, 8…
## $ munr_16 <dbl> 6.32, 6.11, 6.78, 7.03, 8.29, 8.56, 7.95, 6.…
## $ unr_20to64 <dbl> 4.6, 4.8, 5.1, 6.2, 6.7, 7.3, 6.8, 5.9, 4.4,…
## $ funr_20to64 <dbl> 3.5, 4.6, 4.6, 6.3, 6.4, 7.6, 6.8, 6.1, 4.6,…
## $ munr_20to64 <dbl> 5.6, 5.0, 5.6, 6.1, 7.0, 7.0, 6.8, 5.9, 4.3,…
## $ flfpr_20to64 <dbl> 68.9, 70.8, 71.3, 70.2, 70.6, 70.7, 69.9, 68…
## $ flfpr_20to64_under6 <dbl> 66.9, 63.7, 67.0, 66.5, 67.1, 67.5, 65.2, 66…
## $ flfpr_20to64_6to17 <dbl> 79.59, 78.41, 78.15, 77.62, 76.31, 75.91, 75…
## $ flfpr_20to64_under6_6to17 <dbl> 60.81, 59.91, 59.71, 59.31, 58.30, 58.00, 57…
## $ mlfpr_20to64 <dbl> 84.0, 86.2, 85.8, 85.7, 85.7, 85.0, 84.2, 82…
## $ pr_f <dbl> 8.5, 7.5, 7.5, 7.4, 7.4, 8.3, 9.1, 9.3, 9.4,…
## $ pr_p <dbl> 11.5, 10.3, 10.6, 10.9, 11.6, 12.1, 12.8, 12…
## $ mhi_2018 <dbl> 58462.55, 60211.71, 61775.80, 60366.88, 5915…
## $ me_2018 <dbl> 32710.60, 34688.16, 34740.84, 34564.32, 3432…
## $ fme_2018 <dbl> 25156.25, 26852.67, 27391.08, 26727.68, 2796…
## $ mme_2018 <dbl> 41436.80, 43865.64, 46155.24, 45333.12, 4427…
## $ total_pop <dbl> 49744, 49584, 53155, 53944, 54590, 54907, 55…
## $ one_race <dbl> 98.1, 98.6, 98.5, 98.5, 98.5, 98.6, 98.7, 98…
## $ one_race_w <dbl> 78.9, 79.1, 79.1, 78.9, 78.9, 78.3, 78.0, 77…
## $ one_race_b <dbl> 17.7, 17.9, 17.9, 18.1, 18.1, 18.4, 18.6, 18…
## $ one_race_i <dbl> 0.4, 0.4, 0.3, 0.2, 0.3, 0.3, 0.4, 0.4, 0.4,…
## $ one_race_a <dbl> 0.4, 0.6, 0.7, 0.7, 0.8, 1.0, 0.9, 1.0, 0.8,…
## $ one_race_h <dbl> 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1,…
## $ one_race_other <dbl> 0.7, 0.7, 0.6, 0.5, 0.4, 0.7, 0.7, 0.9, 1.4,…
## $ two_races <dbl> 1.9, 1.4, 1.5, 1.5, 1.5, 1.4, 1.3, 1.6, 2.0,…
## $ hispanic <dbl> 1.8, 2.0, 2.3, 2.4, 2.4, 2.5, 2.5, 2.6, 2.6,…
## $ households <dbl> 18373, 18288, 19718, 19998, 19934, 20071, 20…
## $ h_under6_both_work <dbl> 1543, 1475, 1569, 1695, 1714, 1532, 1557, 13…
## $ h_under6_f_work <dbl> 970, 964, 1009, 1060, 938, 880, 1191, 1258, …
## $ h_under6_m_work <dbl> 22, 16, 16, 106, 120, 161, 159, 211, 109, 10…
## $ h_under6_single_m <dbl> 995, 1099, 1110, 1030, 1095, 1160, 954, 883,…
## $ h_6to17_both_work <dbl> 4900, 5028, 5472, 5065, 4608, 4238, 4056, 40…
## $ h_6to17_fwork <dbl> 1308, 1519, 1541, 1965, 1963, 1978, 2073, 20…
## $ h_6to17_mwork <dbl> 114, 92, 113, 246, 284, 354, 373, 551, 322, …
## $ h_6to17_single_m <dbl> 1966, 2305, 2377, 2299, 2644, 2522, 2269, 21…
## $ emp_m <dbl> 27.40, 29.54, 29.33, 31.17, 32.13, 31.74, 32…
## $ memp_m <dbl> 24.41, 26.07, 25.94, 26.97, 28.59, 27.44, 28…
## $ femp_m <dbl> 30.68, 33.40, 33.06, 35.96, 36.09, 36.61, 37…
## $ emp_service <dbl> 17.06, 15.81, 16.92, 16.18, 16.09, 16.72, 16…
## $ memp_service <dbl> 15.53, 14.16, 15.09, 14.21, 14.71, 13.92, 13…
## $ femp_service <dbl> 18.75, 17.64, 18.93, 18.42, 17.63, 19.89, 20…
## $ emp_sales <dbl> 29.11, 28.75, 29.07, 27.56, 28.39, 27.22, 25…
## $ memp_sales <dbl> 15.97, 17.51, 17.82, 17.74, 17.79, 17.38, 15…
## $ femp_sales <dbl> 43.52, 41.25, 41.43, 38.76, 40.26, 38.36, 36…
## $ emp_n <dbl> 13.21, 11.89, 11.57, 10.72, 9.02, 9.27, 9.38…
## $ memp_n <dbl> 22.54, 20.30, 19.86, 18.28, 16.03, 16.79, 17…
## $ femp_n <dbl> 2.99, 2.52, 2.45, 2.09, 1.19, 0.77, 0.58, 0.…
## $ emp_p <dbl> 13.22, 14.02, 13.11, 14.38, 14.37, 15.04, 16…
## $ memp_p <dbl> 21.55, 21.96, 21.28, 22.80, 22.88, 24.48, 24…
## $ femp_p <dbl> 4.07, 5.19, 4.13, 4.77, 4.84, 4.36, 6.07, 7.…
## $ mcsa <dbl> 80.92, 83.42, 85.92, 88.43, 90.93, 93.43, 95…
## $ mfccsa <dbl> 81.40, 85.68, 89.96, 94.25, 98.53, 102.82, 1…
## $ mc_infant <dbl> 104.95, 105.11, 105.28, 105.45, 105.61, 105.…
## $ mc_toddler <dbl> 104.95, 105.11, 105.28, 105.45, 105.61, 105.…
## $ mc_preschool <dbl> 85.92, 87.59, 89.26, 90.93, 92.60, 94.27, 95…
## $ mfcc_infant <dbl> 83.45, 87.39, 91.33, 95.28, 99.22, 103.16, 1…
## $ mfcc_toddler <dbl> 83.45, 87.39, 91.33, 95.28, 99.22, 103.16, 1…
## $ mfcc_preschool <dbl> 81.40, 85.68, 89.96, 94.25, 98.53, 102.82, 1…
There are a lot of variables we could use as predictors in this dataset. In a situation like this, with a big rectangular dataset full of variables that are highly correlated with each other, I will often turn to xgboost as a modeling algorithm. It often works great!
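If you want to see that correlation structure for yourself, here is a quick sketch using the corrr package (an extra package, not part of the original analysis); the 0.9 threshold is just an illustrative cutoff for "highly correlated":
library(corrr)
childcare_costs |>
  select(where(is.numeric), -county_fips_code) |>
  correlate() |>            ## pairwise Pearson correlations
  shave() |>                ## keep only the lower triangle
  stretch(na.rm = TRUE) |>  ## long format: one row per pair of variables
  filter(abs(r) > 0.9) |>   ## arbitrary cutoff for "highly correlated"
  arrange(desc(abs(r)))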
Before we get started with the modeling, let's do a bit of exploratory data analysis. How have childcare costs, as measured by mcsa (the median weekly price for school-aged kids in childcare centers), changed over time?
childcare_costs |>
ggplot(aes(study_year, mcsa, group = study_year, fill = study_year)) +
geom_boxplot(alpha = 0.8, show.legend = FALSE) +
scale_fill_distiller(palette = "RdPu")
How are childcare costs related to mhi_2018
(median household income) and flfpr_20to64
(labor force participation for women)?
childcare_costs |>
ggplot(aes(mhi_2018, mcsa, color = flfpr_20to64)) +
geom_point(alpha = 0.5) +
scale_x_log10() +
scale_color_viridis_c()
It looks like childcare costs are mostly flat for low-income counties but increase for high-income counties, and labor force participation for women is higher in high-income counties.
What about the racial makeup of counties?
childcare_costs |>
select(mcsa, starts_with("one_race"), mhi_2018) |>
select(-one_race) |>
pivot_longer(starts_with("one_race")) |>
ggplot(aes(value, mcsa, color = mhi_2018)) +
geom_point(alpha = 0.5) +
facet_wrap(vars(name), scales = "free_x") +
scale_color_viridis_c() +
labs(x = "% of population")
There's a lot going on in this one! When a county has a larger Black population (one_race_b), household income is lower and childcare costs are lower; the opposite is true for the white population (one_race_w). There also looks to be a trend for the Asian population (one_race_a), where a larger Asian population comes with higher childcare costs. None of these relationships are causal, of course; they reflect the complex relationships between race, class, and where people live in the US.
Build a model
Let's start our modeling by setting up our "data budget." For this example, let's predict mcsa (costs for school-age kids in childcare centers) and remove the other measures of childcare costs for babies, toddlers, family-based childcare, and so on. Let's also remove the FIPS code, which literally encodes location, and instead focus on characteristics of counties like household income, number of households with children, and similar. Since this dataset is quite big, let's use a single validation set.
library(tidymodels)
set.seed(123)
childcare_split <- childcare_costs |>
select(-matches("^mc_|^mfc")) |>
select(-county_fips_code) |>
na.omit() |>
initial_split(strata = mcsa)
childcare_train <- training(childcare_split)
childcare_test <- testing(childcare_split)
set.seed(234)
childcare_set <- validation_split(childcare_train)
childcare_set
## # Validation Set Split (0.75/0.25)
## # A tibble: 1 × 2
## splits id
## <list> <chr>
## 1 <split [13269/4424]> validation
All these predictors are already numeric, so we don't need any special feature engineering; we can use the formula mcsa ~ . directly. We do need to set up a tunable xgboost model specification with early stopping, like we planned. We will keep the number of trees as a constant (and not too terribly high), set stop_iter (the early stopping parameter) to tune(), and then tune a few other parameters. Notice that we need to set aside a validation set (in this case, a proportion of the training set) to use for deciding when to stop.
xgb_spec <-
boost_tree(
trees = 500,
min_n = tune(),
mtry = tune(),
stop_iter = tune(),
learn_rate = 0.01
) |>
set_engine("xgboost", validation = 0.2) |>
set_mode("regression")
xgb_wf <- workflow(mcsa ~ ., xgb_spec)
xgb_wf
## ── Workflow ────────────────────────────────────────────────────────────────────
## Preprocessor: Formula
## Model: boost_tree()
##
## ── Preprocessor ────────────────────────────────────────────────────────────────
## mcsa ~ .
##
## ── Model ───────────────────────────────────────────────────────────────────────
## Boosted Tree Model Specification (regression)
##
## Main Arguments:
## mtry = tune()
## trees = 500
## min_n = tune()
## learn_rate = 0.01
## stop_iter = tune()
##
## Engine-Specific Arguments:
## validation = 0.2
##
## Computational engine: xgboost
Our model is ready to go! Let's tune across possible hyperparameter configurations using our training set (with a subset that is held back for early stopping) plus our validation set.
doParallel::registerDoParallel()
set.seed(234)
xgb_rs <- tune_grid(xgb_wf, childcare_set, grid = 15)
xgb_rs
## # Tuning results
## # Validation Set Split (0.75/0.25)
## # A tibble: 1 × 4
## splits id .metrics .notes
## <list> <chr> <list> <list>
## 1 <split [13269/4424]> validation <tibble [30 × 7]> <tibble [0 × 3]>
All done!
Evaluate results
How did these results turn out? We can visualize them.
autoplot(xgb_rs)
Maybe we could consider going back to tune again with lower min_n
and/or higher mtry
to achieve better performance.
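One way to try that (not something I ran for this post) is to narrow the hyperparameter ranges before tuning again. In this sketch, xgb_params and xgb_rs2 are new names and the specific ranges are only illustrative:
xgb_params <-
  xgb_wf |>
  extract_parameter_set_dials() |>
  update(
    mtry = mtry(range = c(30L, 52L)),  ## push mtry higher (there are 52 predictors)
    min_n = min_n(range = c(2L, 10L))  ## and min_n lower
  )

set.seed(345)
xgb_rs2 <- tune_grid(
  xgb_wf,
  childcare_set,
  grid = 15,
  param_info = xgb_params
)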
We can look at the top results we got like this:
show_best(xgb_rs, "rmse")
## # A tibble: 5 × 9
## mtry min_n stop_iter .metric .estimator mean n std_err .config
## <int> <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 42 5 14 rmse standard 21.8 1 NA Preprocessor1_Mo…
## 2 38 7 20 rmse standard 22.0 1 NA Preprocessor1_Mo…
## 3 35 2 9 rmse standard 22.0 1 NA Preprocessor1_Mo…
## 4 25 17 8 rmse standard 22.0 1 NA Preprocessor1_Mo…
## 5 41 17 16 rmse standard 22.1 1 NA Preprocessor1_Mo…
The best RMSE is a little more than $20, which is an estimate of how precisely we can predict the median childcare cost in a US county (remember that the median in this dataset was about $100).
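If you want to double-check that scale yourself, a quick summary of the outcome in the training set does the trick (output not shown here):
childcare_train |>
  summarise(
    median_mcsa = median(mcsa),  ## about $100 per week, as in the boxplots above
    iqr_mcsa = IQR(mcsa)
  )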
Let's use last_fit() to fit one final time to the training data and evaluate one final time on the testing data, with the numerically optimal result from xgb_rs.
childcare_fit <- xgb_wf |>
finalize_workflow(select_best(xgb_rs, "rmse")) |>
last_fit(childcare_split)
childcare_fit
## # Resampling results
## # Manual resampling
## # A tibble: 1 × 6
## splits id .metrics .notes .predictions .workflow
## <list> <chr> <list> <list> <list> <list>
## 1 <split [17693/5900]> train/test spl… <tibble> <tibble> <tibble> <workflow>
How did this model perform on the testing data, that was not used in tuning or training?
collect_metrics(childcare_fit)
## # A tibble: 2 × 4
## .metric .estimator .estimate .config
## <chr> <chr> <dbl> <chr>
## 1 rmse standard 21.5 Preprocessor1_Model1
## 2 rsq standard 0.625 Preprocessor1_Model1
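Beyond the overall metrics, one way to look at the test set predictions themselves (not part of the original post) is to plot predicted versus observed costs:
collect_predictions(childcare_fit) |>
  ggplot(aes(mcsa, .pred)) +
  geom_abline(slope = 1, lty = 2, color = "gray50") +  ## perfect predictions fall on this line
  geom_point(alpha = 0.3) +
  coord_obs_pred() +  ## same scale on both axes (from the tune package)
  labs(x = "Observed cost", y = "Predicted cost")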
What features are most important for this xgboost model?
library(vip)
extract_workflow(childcare_fit) |>
extract_fit_parsnip() |>
vip(num_features = 15, geom = "point")
The proportion of county population that is Asian has a big impact in this model, as does median household income, median earnings for women, year, and number of households in the county.
BONUS: create a deployable model object!
If you wanted to deploy this model, the next step is to create a deployable model object with vetiver:
library(vetiver)
v <- extract_workflow(childcare_fit) |>
vetiver_model("childcare-costs-xgb")
v
##
## ── childcare-costs-xgb ─ <bundled_workflow> model for deployment
## A xgboost regression modeling workflow using 52 features
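From here, a typical next step (sketched below, not part of this post) would be to version the model with pins and serve it as an API with plumber; board_temp() is just a throwaway board for illustration:
library(pins)
library(plumber)

## store a versioned copy of the model on a pins board
board <- board_temp()
vetiver_pin_write(board, v)

## serve the model as a REST API with a /predict endpoint
pr() |>
  vetiver_api(v) |>
  pr_run(port = 8080)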
At posit::conf() this coming September in Chicago, I am teaching a workshop on how to deploy and maintain models with vetiver. Registration is open now if you are interested in learning more about this part of the modeling process, but you should also check out all the other workshops being organized!