Note: Some results may differ from the hard copy of the book because the default sampling procedure changed in R 3.6.0. See http://bit.ly/35D1SW7 for more details. Access and run the source code for this notebook here.

Hidden chapter requirements used in the book to set the plotting theme and load packages used in hidden code chunks:

# Print numbers in fixed notation: a large scipen penalises scientific format
options(scipen = 999)

# Make ggplot2's light theme the default for every plot in the chapter
ggplot2::theme_set(ggplot2::theme_light())

# Silence warnings and messages in all knitted chunks
knitr::opts_chunk$set(warning = FALSE, message = FALSE)

# Packages needed by the hidden chunks, plus the Ames housing data
library(tidyverse)
library(rpart)
library(rpart.plot)
ames <- AmesHousing::make_ames()

Prerequisites

In this chapter we’ll use the following packages:

# Helper packages
library(dplyr)       # for data wrangling
library(ggplot2)     # for awesome plotting

# Modeling packages
library(rpart)       # direct engine for decision tree application
library(caret)       # meta engine for decision tree application

# Model interpretability packages
library(rpart.plot)  # for plotting decision trees
library(vip)         # for feature importance
library(pdp)         # for feature effects

We’ll continue to illustrate the main concepts using the Ames housing data:

# Stratify on Sale_Price and keep the 70% training portion of the Ames data.
set.seed(123)
split <- rsample::initial_split(
  data   = ames,
  prop   = 0.7,
  strata = "Sale_Price"
)
ames_train <- rsample::training(split)

Structure

Figure 9.1:

# Embed the pre-rendered exemplar decision tree image in the knitted output
knitr::include_graphics("images/exemplar-decision-tree.png")

Figure 9.2:

# Embed the pre-rendered decision tree terminology image in the knitted output
knitr::include_graphics("images/decision-tree-terminology.png")

Partitioning

Figure 9.3:

# create data: 500 noisy observations of a sine wave on [0, 2 * pi]; `truth`
# keeps the noise-free signal so it can be drawn next to the fitted tree
set.seed(1112)  # for reproducibility
df <- tibble::tibble(
  # spell out `length.out` rather than relying on partial argument matching
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)

# run decision stump model (maxdepth = 1 permits a single split)
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 1)
fit <- rpart(y ~ x, data = df, control = ctrl)

# plot tree with narrow margins, then restore the previous graphics parameters
# so later base-graphics plots (e.g. plotcp) are not drawn with these margins
old_par <- par(mar = c(1, 1, 1, 1))
rpart.plot(fit)
par(old_par)


# plot decision boundary: raw points, true signal (blue), stump prediction
# (red), an arrow marking the split location, and a hand-drawn legend
df %>%
  mutate(pred = predict(fit, df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75) +
  geom_segment(x = 3.1, xend = 3.1, y = -Inf, yend = -.95,
               arrow = arrow(length = unit(0.25,"cm")), size = .25) +
  annotate("text", x = 3.1, y = -Inf, label = "split", hjust = 1.2, vjust = -1, size = 3) +
  geom_segment(x = 5.5, xend = 6, y = 2, yend = 2, size = .75, color = "blue") +
  geom_segment(x = 5.5, xend = 6, y = 1.7, yend = 1.7, size = .75, color = "red") +
  annotate("text", x = 5.3, y = 2, label = "truth", hjust = 1, size = 3, color = "blue") +
  annotate("text", x = 5.3, y = 1.7, label = "decision boundary", hjust = 1, size = 3, color = "red")

Figure 9.4:

# Fit a tree limited to three levels of splits and draw it
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 3)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)


# Overlay the tree's step-function prediction (red) on the true signal (blue)
ggplot(mutate(df, pred = predict(fit, df)), aes(x, y)) +
  geom_point(size = 1, alpha = .2) +
  geom_line(aes(y = truth), size = .75, color = "blue") +
  geom_line(aes(y = pred), size = .75, color = "red")

Figure 9.5:

# decision tree: classify iris species from the two sepal measurements
iris_fit <- rpart(Species ~ Sepal.Length + Sepal.Width, data = iris)
rpart.plot(iris_fit)


# decision boundary: shade each rectangular region with its predicted class.
# NOTE(review): the cut points (5.44/5.45, 6.15/6.16, 2.79/2.8, 3.09/3.1) are
# hard-coded, presumably read off the fitted tree above; confirm if the tree
# changes.
# Text annotations have no `fill` aesthetic, so it is not passed to them
# (ggplot2 would otherwise warn about an ignored unknown parameter).
ggplot(iris, aes(Sepal.Length, Sepal.Width, color = Species, shape = Species)) +
  geom_point(show.legend = FALSE) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = 2.8, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 4.0, y = 4.4, label = "setosa", hjust = 0, size = 3) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = 2.79, ymax = -Inf, alpha = .75, fill = "grey") +
  annotate("text", x = 4.0, y = 2, label = "versicolor", hjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = 3.1, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 6, y = 4.4, label = "setosa", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = 3.09, ymax = -Inf, alpha = .75, fill = "grey") +
  annotate("text", x = 6.15, y = 2, label = "versicolor", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 6.16, xmax = Inf, ymin = -Inf, ymax = Inf, alpha = .75, fill = "green") +
  annotate("text", x = 8, y = 2, label = "virginica", hjust = 1, vjust = 0, size = 3)

How deep?

Figure 9.6:

# Grow a very deep tree: no complexity penalty (cp = 0), single-observation
# terminal nodes allowed (minbucket = 1), and maxdepth set to 50
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 50)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)


# Compare the deep tree's prediction (red) with the true signal (blue)
ggplot(mutate(df, pred = predict(fit, df)), aes(x, y)) +
  geom_point(size = 1, alpha = .2) +
  geom_line(aes(y = truth), size = 0.75, color = "blue") +
  geom_line(aes(y = pred), size = 0.75, color = "red")

Early stopping

Figure 9.7:

# Every combination of maximum tree depth and minimum terminal node size
hyper_grid <- expand.grid(
  maxdepth = c(1, 5, 15),
  minbucket = c(1, 5, 15)
)

# Fit one tree per grid row and collect its predictions. The per-row data
# frames are built first and bound in a single rbind() call instead of growing
# `results` inside a loop. Base rbind() merges the single-level ordered factors
# in order of appearance, which keeps the facet panels sorted by grid order.
results <- do.call(
  rbind,
  Map(
    function(depth, bucket) {
      tree <- rpart(
        y ~ x,
        data = df,
        control = list(cp = 0, maxdepth = depth, minbucket = bucket)
      )
      out <- mutate(
        df,
        minbucket = factor(paste("Min node size =", bucket), ordered = TRUE),
        maxdepth = factor(paste("Max tree depth =", depth), ordered = TRUE)
      )
      out$pred <- predict(tree, df)
      out
    },
    hyper_grid$maxdepth,
    hyper_grid$minbucket
  )
)

# One panel per (min node size, max depth) pair: true signal in blue,
# fitted tree in red
ggplot(results, aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = 1) +
  facet_grid(minbucket ~ maxdepth)

Pruning

Figure 9.8:

# Left panel: a very deep tree grown with cp = 0 and minbucket = 1
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 50)
fit <- rpart(y ~ x, data = df, control = ctrl)

p1 <- ggplot(mutate(df, pred = predict(fit, df)), aes(x, y)) +
  geom_point(size = 2, alpha = .3) +
  geom_line(aes(y = truth), size = 1, color = "blue") +
  geom_line(aes(y = pred), size = 1, color = "red")

# Right panel: the same model fitted with rpart's default control settings
fit2 <- rpart(y ~ x, data = df)

p2 <- ggplot(mutate(df, pred2 = predict(fit2, df)), aes(x, y)) +
  geom_point(size = 2, alpha = .3) +
  geom_line(aes(y = truth), size = 1, color = "blue") +
  geom_line(aes(y = pred2), size = 1, color = "red")

# Show both fits side by side
gridExtra::grid.arrange(p1, p2, nrow = 1)

Ames housing example

# Regression tree for Sale_Price on all other columns, default control settings
ames_dt1 <- rpart(Sale_Price ~ ., data = ames_train, method = "anova")
print(ames_dt1)
n= 2053 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 2053 13217940000000 180996.30  
   2) Overall_Qual=Very_Poor,Poor,Fair,Below_Average,Average,Above_Average,Good 1722  4107888000000 156954.70  
     4) Neighborhood=North_Ames,Old_Town,Edwards,Sawyer,Mitchell,Brookside,Iowa_DOT_and_Rail_Road,South_and_West_of_Iowa_State_University,Meadow_Village,Briardale,Northpark_Villa,Blueste 1022  1332227000000 132318.00  
       8) Overall_Qual=Very_Poor,Poor,Fair,Below_Average 199   179295400000  98856.51 *
       9) Overall_Qual=Average,Above_Average,Good 823   876239900000 140409.00  
        18) First_Flr_SF< 1089 517   290531200000 129244.00 *
        19) First_Flr_SF>=1089 306   412375700000 159272.60 *
     5) Neighborhood=College_Creek,Somerset,Northridge_Heights,Gilbert,Northwest_Ames,Sawyer_West,Crawford,Timberland,Northridge,Stone_Brook,Clear_Creek,Bloomington_Heights,Veenker,Green_Hills 700  1249681000000 192924.20  
      10) Gr_Liv_Area< 1477 287   250826800000 165395.90 *
      11) Gr_Liv_Area>=1477 413   630227700000 212054.00  
        22) Total_Bsmt_SF< 959.5 199   139087700000 192493.10 *
        23) Total_Bsmt_SF>=959.5 214   344191200000 230243.70 *
   3) Overall_Qual=Very_Good,Excellent,Very_Excellent 331  2936700000000 306070.70  
     6) Overall_Qual=Very_Good 231   946974600000 270626.10  
      12) Gr_Liv_Area< 1919 142   334978300000 244016.60 *
      13) Gr_Liv_Area>=1919 89   351030800000 313081.60 *
     7) Overall_Qual=Excellent,Very_Excellent 100  1029126000000 387948.00  
      14) Total_Bsmt_SF< 1907.5 72   314985900000 350532.40 *
      15) Total_Bsmt_SF>=1907.5 28   354159900000 484159.40 *
# Draw the default tree and its cross-validated error by complexity parameter
rpart.plot(ames_dt1)

plotcp(ames_dt1)

# Refit with no complexity penalty (cp = 0) and 10-fold cross validation
ames_dt2 <- rpart(
  Sale_Price ~ .,
  data    = ames_train,
  method  = "anova",
  control = rpart.control(cp = 0, xval = 10)
)

# Error curve for the cp = 0 tree, with a dashed vertical reference at x = 11
plotcp(ames_dt2)
abline(v = 11, lty = 2)

# rpart cross validation results: the complexity-parameter table stored on the
# default-settings fit (columns CP, nsplit, rel error, xerror, xstd)
ames_dt1$cptable
           CP nsplit rel error    xerror       xstd
1  0.46704344      0 1.0000000 1.0015334 0.06051267
2  0.11544770      1 0.5329566 0.5343697 0.03079312
3  0.07267387      2 0.4175089 0.4209603 0.03007122
4  0.02788834      3 0.3448350 0.3502963 0.02145751
5  0.02723422      4 0.3169466 0.3319341 0.02225037
6  0.02093301      5 0.2897124 0.3125117 0.02150290
7  0.01974328      6 0.2687794 0.2986956 0.02139660
8  0.01311346      7 0.2490361 0.2726862 0.01738257
9  0.01111737      8 0.2359227 0.2654669 0.01725615
10 0.01000000      9 0.2248053 0.2584346 0.01721996
# Tune the rpart model through caret: 10-fold cross validation with
# tuneLength = 20
ames_dt3 <- train(
  Sale_Price ~ .,
  data       = ames_train,
  method     = "rpart",
  tuneLength = 20,
  trControl  = trainControl(method = "cv", number = 10)
)

# Plot the resampling results of the tuning run
print(ggplot(ames_dt3))

Feature interpretation

# Variable importance dot plot for the top 40 features. `geom = "point"`
# replaces the `bar = FALSE` argument, which current vip releases no longer
# accept.
vip(ames_dt3, num_features = 40, geom = "point")

# One-variable partial dependence of the tuned tree on living area and on
# year built
p1 <- autoplot(partial(ames_dt3, pred.var = "Gr_Liv_Area"))
p2 <- autoplot(partial(ames_dt3, pred.var = "Year_Built"))

# Two-variable partial dependence, drawn as a 3-D surface
p3 <- plotPartial(
  partial(ames_dt3, pred.var = c("Gr_Liv_Area", "Year_Built")),
  levelplot = FALSE,
  zlab = "yhat",
  drape = TRUE,
  colorkey = TRUE,
  screen = list(z = -20, x = -60)
)

# Arrange the three plots in a single row
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)

