Note: Some results may differ from the hard-copy book due to changes in the sampling procedures introduced in R 3.6.0. See http://bit.ly/35D1SW7 for more details. Access and run the source code for this notebook here.
Hidden chapter requirements, used in the book to set the plotting theme and to load the packages needed by hidden code chunks:
# Strongly prefer fixed over scientific notation when printing numbers
options(scipen = 999)

# Every ggplot2 figure in the chapter uses the light theme
ggplot2::theme_set(new = ggplot2::theme_light())

# Hide warnings and messages emitted by any knitr chunk
knitr::opts_chunk$set(message = FALSE, warning = FALSE)

# Packages needed by the hidden code chunks
library(tidyverse)
library(rpart)
library(rpart.plot)

# Ames housing data, built by AmesHousing::make_ames()
ames <- AmesHousing::make_ames()
Prerequisites
In this chapter we’ll use the following packages:
# Helper packages
library(dplyr) # for data wrangling
library(ggplot2) # for awesome plotting
# Modeling packages
library(rpart) # direct engine for decision tree application
library(caret) # meta engine for decision tree application
# Model interpretability packages
library(rpart.plot) # for plotting decision trees
library(vip) # for feature importance
library(pdp) # for feature effects
We’ll continue to illustrate the main concepts using the Ames housing data:
# Reproducible 70/30 split of the Ames data, stratified on Sale_Price;
# only the 70% training portion is kept for this chapter.
set.seed(123)
split <- rsample::initial_split(data = ames, prop = 0.7, strata = "Sale_Price")
ames_train <- rsample::training(x = split)
Structure
Figure 9.1:
knitr::include_graphics(path = "images/exemplar-decision-tree.png") # static image for Figure 9.1
Figure 9.2:
knitr::include_graphics(path = "images/decision-tree-terminology.png") # static image for Figure 9.2
Partitioning
Figure 9.3:
# Simulated data: a noisy sine wave on [0, 2*pi]. `truth` holds the
# noise-free signal the tree is trying to recover.
set.seed(1112) # for reproducibility
df <- tibble::tibble(
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)

# Decision stump: a single split (maxdepth = 1), no complexity threshold.
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 1)
fit <- rpart(y ~ x, data = df, control = ctrl)

# Plot the tree with tight margins, then restore the previous par() settings
# so later base-graphics plots are not affected.
old_par <- par(mar = c(1, 1, 1, 1))
rpart.plot(fit)
par(old_par)

# Plot the decision boundary against the truth.
# annotate() draws each marker exactly once; geom_segment() with constant
# positions would draw one copy per row of `df` (500 overlapping segments).
df %>%
  mutate(pred = predict(fit, df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75) +
  # Arrow marking the (hard-coded) location of the stump's split
  annotate("segment", x = 3.1, xend = 3.1, y = -Inf, yend = -.95,
           arrow = arrow(length = unit(0.25, "cm")), size = .25) +
  annotate("text", x = 3.1, y = -Inf, label = "split", hjust = 1.2, vjust = -1, size = 3) +
  # Hand-drawn legend
  annotate("segment", x = 5.5, xend = 6, y = 2, yend = 2, size = .75, color = "blue") +
  annotate("segment", x = 5.5, xend = 6, y = 1.7, yend = 1.7, size = .75, color = "red") +
  annotate("text", x = 5.3, y = 2, label = "truth", hjust = 1, size = 3, color = "blue") +
  annotate("text", x = 5.3, y = 1.7, label = "decision boundary", hjust = 1, size = 3, color = "red")
Figure 9.4:
# Depth-3 tree on the simulated sine data (minimum leaf size 5, cp = 0)
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 3)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Blue: noise-free truth; red: the tree's step-function predictions
ggplot(data = mutate(df, pred = predict(fit, df)), mapping = aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75)
Figure 9.5:
# Classification tree for iris species using only the two sepal measurements
iris_fit <- rpart(Species ~ Sepal.Length + Sepal.Width, data = iris)
rpart.plot(iris_fit)

# Hand-drawn decision regions. Adjacent rectangles share the exact same edge
# values, so no unfilled slivers appear between regions.
# NOTE(review): the boundaries (Sepal.Length 5.45 / 6.15, Sepal.Width 2.8 / 3.1)
# are hard-coded and assumed to match iris_fit's splits -- verify against the
# rpart.plot() output if the fit changes.
# Text annotations take no `fill`; it was previously passed and ignored.
ggplot(iris, aes(Sepal.Length, Sepal.Width, color = Species, shape = Species)) +
  geom_point(show.legend = FALSE) +
  # Sepal.Length < 5.45
  annotate("rect", xmin = -Inf, xmax = 5.45, ymin = 2.8, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 4.0, y = 4.4, label = "setosa", hjust = 0, size = 3) +
  annotate("rect", xmin = -Inf, xmax = 5.45, ymin = -Inf, ymax = 2.8, alpha = .75, fill = "grey") +
  annotate("text", x = 4.0, y = 2, label = "versicolor", hjust = 0, size = 3) +
  # 5.45 <= Sepal.Length < 6.15
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = 3.1, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 6, y = 4.4, label = "setosa", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = -Inf, ymax = 3.1, alpha = .75, fill = "grey") +
  annotate("text", x = 6.15, y = 2, label = "versicolor", hjust = 1, vjust = 0, size = 3) +
  # Sepal.Length >= 6.15
  annotate("rect", xmin = 6.15, xmax = Inf, ymin = -Inf, ymax = Inf, alpha = .75, fill = "green") +
  annotate("text", x = 8, y = 2, label = "virginica", hjust = 1, vjust = 0, size = 3)
How deep?
Figure 9.6:
# A very deep tree: cp = 0, leaves of a single observation, depth limit 50
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 50)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Overlay the (overfit) predictions in red on the blue truth curve
ggplot(data = mutate(df, pred = predict(fit, df)), mapping = aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = 0.75) +
  geom_line(aes(y = pred), color = "red", size = 0.75)
Early stopping
Figure 9.7:
# Every combination of maximum tree depth and minimum terminal-node size
hyper_grid <- expand.grid(
  maxdepth = c(1, 5, 15),
  minbucket = c(1, 5, 15)
)

# Fit one tree per grid row and tag its predictions with the settings used.
# Collecting a list and binding once avoids the quadratic cost of rbind()-ing
# inside a loop, and keeps the per-iteration temporaries out of the workspace.
results <- lapply(seq_len(nrow(hyper_grid)), function(i) {
  ctrl <- list(
    cp = 0,
    maxdepth = hyper_grid$maxdepth[i],
    minbucket = hyper_grid$minbucket[i]
  )
  fit <- rpart(y ~ x, data = df, control = ctrl)
  mutate(
    df,
    minbucket = hyper_grid$minbucket[i],
    maxdepth = hyper_grid$maxdepth[i],
    pred = predict(fit, df)
  )
}) %>%
  bind_rows()

# The facet variables stay numeric, so panels sort 1, 5, 15. The labellers
# only change the strip text.
ggplot(results, aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = 1) +
  facet_grid(
    minbucket ~ maxdepth,
    labeller = labeller(
      minbucket = function(v) paste("Min node size =", v),
      maxdepth = function(v) paste("Max tree depth =", v)
    )
  )
Pruning
Figure 9.8:
# Left panel: a very deep tree (cp = 0, minbucket = 1, maxdepth = 50)
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 50)
fit <- rpart(y ~ x, data = df, control = ctrl)
p1 <- ggplot(mutate(df, pred = predict(fit, df)), aes(x, y)) +
  geom_point(alpha = .3, size = 2) +
  geom_line(aes(x, y = truth), color = "blue", size = 1) +
  geom_line(aes(y = pred), color = "red", size = 1)

# Right panel: the same data fit with rpart's default control settings
fit2 <- rpart(y ~ x, data = df)
p2 <- ggplot(mutate(df, pred2 = predict(fit2, df)), aes(x, y)) +
  geom_point(alpha = .3, size = 2) +
  geom_line(aes(x, y = truth), color = "blue", size = 1) +
  geom_line(aes(y = pred2), color = "red", size = 1)

# Show both fits side by side
gridExtra::grid.arrange(p1, p2, nrow = 1)
Ames housing example
# Regression tree (method = "anova") on every predictor, rpart default controls
ames_dt1 <- rpart(Sale_Price ~ ., data = ames_train, method = "anova")
# Auto-print the fitted splits
ames_dt1
n= 2053
node), split, n, deviance, yval
* denotes terminal node
1) root 2053 13217940000000 180996.30
2) Overall_Qual=Very_Poor,Poor,Fair,Below_Average,Average,Above_Average,Good 1722 4107888000000 156954.70
4) Neighborhood=North_Ames,Old_Town,Edwards,Sawyer,Mitchell,Brookside,Iowa_DOT_and_Rail_Road,South_and_West_of_Iowa_State_University,Meadow_Village,Briardale,Northpark_Villa,Blueste 1022 1332227000000 132318.00
8) Overall_Qual=Very_Poor,Poor,Fair,Below_Average 199 179295400000 98856.51 *
9) Overall_Qual=Average,Above_Average,Good 823 876239900000 140409.00
18) First_Flr_SF< 1089 517 290531200000 129244.00 *
19) First_Flr_SF>=1089 306 412375700000 159272.60 *
5) Neighborhood=College_Creek,Somerset,Northridge_Heights,Gilbert,Northwest_Ames,Sawyer_West,Crawford,Timberland,Northridge,Stone_Brook,Clear_Creek,Bloomington_Heights,Veenker,Green_Hills 700 1249681000000 192924.20
10) Gr_Liv_Area< 1477 287 250826800000 165395.90 *
11) Gr_Liv_Area>=1477 413 630227700000 212054.00
22) Total_Bsmt_SF< 959.5 199 139087700000 192493.10 *
23) Total_Bsmt_SF>=959.5 214 344191200000 230243.70 *
3) Overall_Qual=Very_Good,Excellent,Very_Excellent 331 2936700000000 306070.70
6) Overall_Qual=Very_Good 231 946974600000 270626.10
12) Gr_Liv_Area< 1919 142 334978300000 244016.60 *
13) Gr_Liv_Area>=1919 89 351030800000 313081.60 *
7) Overall_Qual=Excellent,Very_Excellent 100 1029126000000 387948.00
14) Total_Bsmt_SF< 1907.5 72 314985900000 350532.40 *
15) Total_Bsmt_SF>=1907.5 28 354159900000 484159.40 *
rpart.plot(x = ames_dt1) # tree diagram of the default-settings fit
plotcp(x = ames_dt1)     # cross-validation error across the cp table
# Grow the tree with no complexity threshold (cp = 0) and 10-fold internal
# cross-validation, so the whole cost-complexity path can be inspected.
# The CV fold assignment is random; seed it (as elsewhere in this file) so the
# plotted error curve is reproducible.
set.seed(123)
ames_dt2 <- rpart(
  formula = Sale_Price ~ .,
  data = ames_train,
  method = "anova",
  control = list(cp = 0, xval = 10)
)
plotcp(ames_dt2)
# NOTE(review): hard-coded reference line position; confirm it still marks
# the intended tree size if the data or seed change.
abline(v = 11, lty = "dashed")

# rpart cross validation results
ames_dt1$cptable
CP nsplit rel error xerror xstd
1 0.46704344 0 1.0000000 1.0015334 0.06051267
2 0.11544770 1 0.5329566 0.5343697 0.03079312
3 0.07267387 2 0.4175089 0.4209603 0.03007122
4 0.02788834 3 0.3448350 0.3502963 0.02145751
5 0.02723422 4 0.3169466 0.3319341 0.02225037
6 0.02093301 5 0.2897124 0.3125117 0.02150290
7 0.01974328 6 0.2687794 0.2986956 0.02139660
8 0.01311346 7 0.2490361 0.2726862 0.01738257
9 0.01111737 8 0.2359227 0.2654669 0.01725615
10 0.01000000 9 0.2248053 0.2584346 0.01721996
# caret cross validation results
# Tune caret's "rpart" model with 10-fold CV over tuneLength = 20 candidates.
# Seeded so the fold assignment, and hence the selected model, is reproducible.
set.seed(123)
ames_dt3 <- train(
  Sale_Price ~ .,
  data = ames_train,
  method = "rpart",
  trControl = trainControl(method = "cv", number = 10),
  tuneLength = 20
)
# Resampled performance across the tuning grid
ggplot(ames_dt3)
Feature interpretation
# Importance of the top 40 features as a dot plot. `geom = "point"` replaces
# the deprecated `bar = FALSE` argument of vip().
vip(ames_dt3, num_features = 40, geom = "point")

# Construct partial dependence plots
p1 <- partial(ames_dt3, pred.var = "Gr_Liv_Area") %>% autoplot()
p2 <- partial(ames_dt3, pred.var = "Year_Built") %>% autoplot()
# Two-way dependence drawn as a 3-D draped surface (a lattice plot)
p3 <- partial(ames_dt3, pred.var = c("Gr_Liv_Area", "Year_Built")) %>%
  plotPartial(levelplot = FALSE, zlab = "yhat", drape = TRUE,
              colorkey = TRUE, screen = list(z = -20, x = -60))

# Display plots side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)
