Note: Some results may differ from the hard copy book due to changes in the sampling procedure introduced in R 3.6.0. See http://bit.ly/35D1SW7 for more details. Access and run the source code for this notebook here: https://rstudio.cloud/project/801185.
Hidden chapter requirements used in the book to set the plotting theme and load packages used in hidden code chunks:
# Print numbers in fixed rather than scientific notation (heavy penalty on
# scientific notation)
options(scipen = 999)

# Every ggplot in this notebook uses the light theme
ggplot2::theme_set(new = ggplot2::theme_light())

# Suppress warnings and messages in the rendered output of every chunk
knitr::opts_chunk$set(warning = FALSE, message = FALSE)

# Packages needed by the hidden chunks
library(tidyverse)
library(rpart)
library(rpart.plot)

# Ames housing data used throughout the chapter
ames <- AmesHousing::make_ames()
Prerequisites
In this chapter we’ll use the following packages:
# Attach the packages used throughout this chapter.
# NOTE(review): attach order matters -- a package attached later masks
# same-named functions from packages attached earlier (including anything
# attached by library(tidyverse) in the setup code). pdp is attached last,
# presumably so that partial() in the feature-interpretation code resolves to
# pdp's version; confirm before reordering.
# Helper packages
library(dplyr) # data wrangling
library(ggplot2) # plotting
# Modeling packages
library(rpart) # direct engine for fitting decision trees
library(caret) # meta engine (resampling/tuning) for decision trees
# Model interpretability packages
library(rpart.plot) # plotting fitted rpart trees
library(vip) # variable importance plots
library(pdp) # partial dependence (feature effect) plots
We’ll continue to illustrate the main concepts using the Ames housing data:
# Stratified (on Sale_Price) 70/30 split of the Ames data; only the 70%
# training portion is kept. The seed makes the random split reproducible.
set.seed(123)
split <- rsample::initial_split(
  data = ames,
  prop = 0.7,
  strata = "Sale_Price"
)
ames_train <- rsample::training(x = split)
Structure
Figure 9.1:
# Static diagram of an example decision tree
knitr::include_graphics(path = "images/exemplar-decision-tree.png")
Figure 9.2:
# Static diagram labelling the parts of a decision tree
knitr::include_graphics(path = "images/decision-tree-terminology.png")
Partitioning
Figure 9.3:
# Simulated regression data: noisy observations of sin(x) on [0, 2*pi],
# plus the noise-free curve kept as `truth`
set.seed(1112)  # for reproducibility
df <- tibble::tibble(
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)

# Decision stump: a single split (maxdepth = 1), no complexity penalty
# (cp = 0) and at least 5 observations per leaf
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 1)
fit <- rpart(formula = y ~ x, data = df, control = ctrl)

# Draw the fitted tree with narrow plot margins
par(mar = c(1, 1, 1, 1))
rpart.plot(fit)

# Decision boundary: faint data points, true curve (blue), stump prediction
# (red), an arrow marking the split at x = 3.1 and a hand-made legend
mutate(df, pred = predict(fit, newdata = df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75) +
  # split marker
  geom_segment(
    x = 3.1, xend = 3.1, y = -Inf, yend = -.95,
    arrow = arrow(length = unit(0.25, "cm")), size = .25
  ) +
  annotate(
    "text", x = 3.1, y = -Inf, label = "split",
    hjust = 1.2, vjust = -1, size = 3
  ) +
  # legend: coloured line swatches with matching labels
  geom_segment(x = 5.5, xend = 6, y = 2, yend = 2, size = .75, color = "blue") +
  geom_segment(x = 5.5, xend = 6, y = 1.7, yend = 1.7, size = .75, color = "red") +
  annotate("text", x = 5.3, y = 2, label = "truth",
           hjust = 1, size = 3, color = "blue") +
  annotate("text", x = 5.3, y = 1.7, label = "decision boundary",
           hjust = 1, size = 3, color = "red")
Figure 9.4:
# Tree allowed to grow three levels deep (cp = 0, at least 5 obs per leaf)
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 3)
fit <- rpart(formula = y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Decision boundary: data, true curve (blue) and the depth-3 prediction (red)
mutate(df, pred = predict(fit, newdata = df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75)
Figure 9.5:
# Classification tree predicting species from the two sepal measurements
iris_fit <- rpart(Species ~ Sepal.Length + Sepal.Width, data = iris)
rpart.plot(iris_fit)

# Hand-drawn decision boundary: each shaded rectangle is one prediction region
# (orange = setosa, grey = versicolor, green = virginica).
# NOTE(review): the cut points are hard-coded to match the fitted tree above;
# update them if the tree changes.
# Text annotations take no `fill` (ggplot only warned and ignored it), and every
# rect now lists ymin below ymax.
ggplot(iris, aes(Sepal.Length, Sepal.Width, color = Species, shape = Species)) +
  geom_point(show.legend = FALSE) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = 2.8, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 4.0, y = 4.4, label = "setosa", hjust = 0, size = 3) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = -Inf, ymax = 2.79, alpha = .75, fill = "grey") +
  annotate("text", x = 4.0, y = 2, label = "versicolor", hjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = 3.1, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 6, y = 4.4, label = "setosa", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = -Inf, ymax = 3.09, alpha = .75, fill = "grey") +
  annotate("text", x = 6.15, y = 2, label = "versicolor", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 6.16, xmax = Inf, ymin = -Inf, ymax = Inf, alpha = .75, fill = "green") +
  annotate("text", x = 8, y = 2, label = "virginica", hjust = 1, vjust = 0, size = 3)
How deep?
Figure 9.6:
# Deliberately overfit tree: no complexity penalty (cp = 0), leaves may hold a
# single observation (minbucket = 1), and depth capped at 30.
# rpart.control() documents 30 as the maximum supported depth; a plain-list
# `control` is not range-checked, so a larger value (previously 50) is not
# rejected and should not be used.
# NOTE(review): keep `ctrl` a plain list -- rpart.control(minbucket = 1) would
# also derive a different default minsplit and change the tree.
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 30)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Decision boundary of the overfit tree (red) against the true curve (blue)
df %>%
  mutate(pred = predict(fit, df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = 0.75) +
  geom_line(aes(y = pred), color = "red", size = 0.75)
Early stopping
Figure 9.7:
# Grid of early-stopping settings: every combination of maximum tree depth and
# minimum terminal-node size
hyper_grid <- expand.grid(
  maxdepth = c(1, 5, 15),
  minbucket = c(1, 5, 15)
)

# Fit one tree per grid row and keep its predictions next to the data.
# Collecting a list and binding once avoids re-copying the accumulated result
# on every iteration (rbind() inside a loop).
results <- lapply(seq_len(nrow(hyper_grid)), function(i) {
  ctrl <- list(
    cp = 0,
    maxdepth = hyper_grid$maxdepth[i],
    minbucket = hyper_grid$minbucket[i]
  )
  fit <- rpart(y ~ x, data = df, control = ctrl)
  mutate(
    df,
    minbucket = hyper_grid$minbucket[i],
    maxdepth = hyper_grid$maxdepth[i],
    pred = predict(fit, df)
  )
}) %>%
  bind_rows() %>%
  # Facet labels with explicit levels so panels appear in numeric order
  # (1, 5, 15) rather than alphabetical order ("= 1", "= 15", "= 5")
  mutate(
    minbucket = factor(
      minbucket,
      levels = sort(unique(minbucket)),
      labels = paste("Min node size =", sort(unique(minbucket)))
    ),
    maxdepth = factor(
      maxdepth,
      levels = sort(unique(maxdepth)),
      labels = paste("Max tree depth =", sort(unique(maxdepth)))
    )
  )

# Rows: minimum node size; columns: maximum depth
ggplot(results, aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = 1) +
  facet_grid(minbucket ~ maxdepth)
Pruning
Figure 9.8:
# Overly complex tree: no complexity penalty, single-observation leaves allowed,
# depth capped at rpart's documented maximum of 30 (previously 50, which a
# plain-list `control` does not reject).
# NOTE(review): keep `ctrl` a plain list -- rpart.control(minbucket = 1) would
# also derive a different default minsplit and change the tree.
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 30)
fit <- rpart(y ~ x, data = df, control = ctrl)

# Tree grown with rpart's default control settings (including its default cp)
fit2 <- rpart(y ~ x, data = df)

# Shared plot for both models: data, true curve (blue), fitted step function (red)
plot_tree_fit <- function(model) {
  df %>%
    mutate(pred = predict(model, df)) %>%
    ggplot(aes(x, y)) +
    geom_point(alpha = .3, size = 2) +
    geom_line(aes(x, y = truth), color = "blue", size = 1) +
    geom_line(aes(y = pred), color = "red", size = 1)
}
p1 <- plot_tree_fit(fit)
p2 <- plot_tree_fit(fit2)

gridExtra::grid.arrange(p1, p2, nrow = 1)
Ames housing example
# Regression ("anova") tree for sale price using every other column as a
# predictor, with rpart's default control settings
ames_dt1 <- rpart(Sale_Price ~ ., data = ames_train, method = "anova")

# Show the fitted tree's splits
print(ames_dt1)
n= 2053
node), split, n, deviance, yval
* denotes terminal node
1) root 2053 13217940000000 180996.30
2) Overall_Qual=Very_Poor,Poor,Fair,Below_Average,Average,Above_Average,Good 1722 4107888000000 156954.70
4) Neighborhood=North_Ames,Old_Town,Edwards,Sawyer,Mitchell,Brookside,Iowa_DOT_and_Rail_Road,South_and_West_of_Iowa_State_University,Meadow_Village,Briardale,Northpark_Villa,Blueste 1022 1332227000000 132318.00
8) Overall_Qual=Very_Poor,Poor,Fair,Below_Average 199 179295400000 98856.51 *
9) Overall_Qual=Average,Above_Average,Good 823 876239900000 140409.00
18) First_Flr_SF< 1089 517 290531200000 129244.00 *
19) First_Flr_SF>=1089 306 412375700000 159272.60 *
5) Neighborhood=College_Creek,Somerset,Northridge_Heights,Gilbert,Northwest_Ames,Sawyer_West,Crawford,Timberland,Northridge,Stone_Brook,Clear_Creek,Bloomington_Heights,Veenker,Green_Hills 700 1249681000000 192924.20
10) Gr_Liv_Area< 1477 287 250826800000 165395.90 *
11) Gr_Liv_Area>=1477 413 630227700000 212054.00
22) Total_Bsmt_SF< 959.5 199 139087700000 192493.10 *
23) Total_Bsmt_SF>=959.5 214 344191200000 230243.70 *
3) Overall_Qual=Very_Good,Excellent,Very_Excellent 331 2936700000000 306070.70
6) Overall_Qual=Very_Good 231 946974600000 270626.10
12) Gr_Liv_Area< 1919 142 334978300000 244016.60 *
13) Gr_Liv_Area>=1919 89 351030800000 313081.60 *
7) Overall_Qual=Excellent,Very_Excellent 100 1029126000000 387948.00
14) Total_Bsmt_SF< 1907.5 72 314985900000 350532.40 *
15) Total_Bsmt_SF>=1907.5 28 354159900000 484159.40 *
# Diagram of the fitted tree
rpart.plot(x = ames_dt1)

# Cross-validated relative error for each cp value of the default tree
plotcp(x = ames_dt1)

# Fully grown tree (cp = 0) with 10-fold cross validation (xval = 10)
ames_dt2 <- rpart(
  Sale_Price ~ .,
  data = ames_train,
  method = "anova",
  control = list(cp = 0, xval = 10)
)
plotcp(ames_dt2)
abline(v = 11, lty = "dashed")  # hard-coded reference position

# rpart cross validation results: one row per candidate cp value
ames_dt1[["cptable"]]
CP nsplit rel error xerror xstd
1 0.46704344 0 1.0000000 1.0015334 0.06051267
2 0.11544770 1 0.5329566 0.5343697 0.03079312
3 0.07267387 2 0.4175089 0.4209603 0.03007122
4 0.02788834 3 0.3448350 0.3502963 0.02145751
5 0.02723422 4 0.3169466 0.3319341 0.02225037
6 0.02093301 5 0.2897124 0.3125117 0.02150290
7 0.01974328 6 0.2687794 0.2986956 0.02139660
8 0.01311346 7 0.2490361 0.2726862 0.01738257
9 0.01111737 8 0.2359227 0.2654669 0.01725615
10 0.01000000 9 0.2248053 0.2584346 0.01721996
# caret cross validation results: 10-fold CV over 20 candidate cp values
ames_dt3 <- train(
  Sale_Price ~ .,
  data = ames_train,
  method = "rpart",
  tuneLength = 20,
  trControl = trainControl(number = 10, method = "cv")
)

# Cross-validated performance across the cp grid
ggplot(ames_dt3)
Feature interpretation
# Variable importance for up to 40 features of the caret model; bar = FALSE
# asks vip for a non-bar display.
# NOTE(review): newer vip releases replace `bar` with a `geom` argument
# (e.g. geom = "point"); confirm which vip version this notebook targets.
vip(object = ames_dt3, num_features = 40, bar = FALSE)

# Partial dependence of predicted sale price on living area and on year built
# (ggplot2 line plots), then on both jointly (draped 3-D surface)
p1 <- autoplot(partial(ames_dt3, pred.var = "Gr_Liv_Area"))
p2 <- autoplot(partial(ames_dt3, pred.var = "Year_Built"))
p3 <- plotPartial(
  partial(ames_dt3, pred.var = c("Gr_Liv_Area", "Year_Built")),
  levelplot = FALSE, zlab = "yhat", drape = TRUE,
  colorkey = TRUE, screen = list(z = -20, x = -60)
)

# Three panels side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)
---
title: "Chapter 9: Decision Trees"
output: html_notebook
---

__Note__: Some results may differ from the hard copy book due to changes in the sampling procedure introduced in R 3.6.0. See http://bit.ly/35D1SW7 for more details. Access and run the source code for this notebook [here](https://rstudio.cloud/project/801185).

Hidden chapter requirements used in the book to set the plotting theme and load packages used in hidden code chunks:

```{r setup}
# Print numbers in fixed rather than scientific notation (heavy penalty on
# scientific notation)
options(scipen = 999)

# Every ggplot in this notebook uses the light theme
ggplot2::theme_set(new = ggplot2::theme_light())

# Suppress warnings and messages in the rendered output of every chunk
knitr::opts_chunk$set(warning = FALSE, message = FALSE)

# Packages needed by the hidden chunks
library(tidyverse)
library(rpart)
library(rpart.plot)

# Ames housing data used throughout the chapter
ames <- AmesHousing::make_ames()
```

## Prerequisites

In this chapter we'll use the following packages:

```{r dt-pkgs}
# Attach the packages used throughout this chapter.
# NOTE(review): attach order matters -- a package attached later masks
# same-named functions from packages attached earlier (including anything
# attached by library(tidyverse) in the setup chunk). pdp is attached last,
# presumably so that partial() in the dt-pdp chunk resolves to pdp's version;
# confirm before reordering.

# Helper packages
library(dplyr)       # data wrangling
library(ggplot2)     # plotting

# Modeling packages
library(rpart)       # direct engine for fitting decision trees
library(caret)       # meta engine (resampling/tuning) for decision trees

# Model interpretability packages
library(rpart.plot)  # plotting fitted rpart trees
library(vip)         # variable importance plots
library(pdp)         # partial dependence (feature effect) plots
```

We'll continue to illustrate the main concepts using the Ames housing data:

```{r dt-data-prereq, echo=TRUE}
# Stratified (on Sale_Price) 70/30 split of the Ames data; only the 70%
# training portion is kept. The seed makes the random split reproducible.
set.seed(123)
split <- rsample::initial_split(
  data = ames,
  prop = 0.7,
  strata = "Sale_Price"
)
ames_train <- rsample::training(x = split)
```


## Structure

Figure 9.1:

```{r exemplar-decision-tree, echo=TRUE, fig.cap="Exemplar decision tree predicting whether or not a customer will redeem a coupon (yes or no) based on the customer's loyalty, household income, last month's spend, coupon placement, and shopping mode.", out.height="100%", out.width="100%"}
# Static diagram of an example decision tree
knitr::include_graphics(path = "images/exemplar-decision-tree.png")
```

Figure 9.2:

```{r decision-tree-terminology, echo=TRUE, fig.cap="Terminology of a decision tree.", out.height="80%", out.width="80%"}
# Static diagram labelling the parts of a decision tree
knitr::include_graphics(path = "images/decision-tree-terminology.png")
```


## Partitioning

Figure 9.3:

```{r decision-stump, echo=TRUE, fig.width=4, fig.height=3, fig.show='hold', fig.cap="Decision tree illustrating the single split on feature x (left). The resulting decision boundary illustrates the predicted value when x < 3.1 (0.64), and when x > 3.1 (-0.67) (right).", out.width="48%"}
# Simulated regression data: noisy observations of sin(x) on [0, 2*pi],
# plus the noise-free curve kept as `truth`
set.seed(1112)  # for reproducibility
df <- tibble::tibble(
  x = seq(from = 0, to = 2 * pi, length.out = 500),
  y = sin(x) + rnorm(length(x), sd = 0.5),
  truth = sin(x)
)

# Decision stump: a single split (maxdepth = 1), no complexity penalty
# (cp = 0) and at least 5 observations per leaf
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 1)
fit <- rpart(formula = y ~ x, data = df, control = ctrl)

# Draw the fitted tree with narrow plot margins
par(mar = c(1, 1, 1, 1))
rpart.plot(fit)

# Decision boundary: faint data points, true curve (blue), stump prediction
# (red), an arrow marking the split at x = 3.1 and a hand-made legend
mutate(df, pred = predict(fit, newdata = df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75) +
  # split marker
  geom_segment(
    x = 3.1, xend = 3.1, y = -Inf, yend = -.95,
    arrow = arrow(length = unit(0.25, "cm")), size = .25
  ) +
  annotate(
    "text", x = 3.1, y = -Inf, label = "split",
    hjust = 1.2, vjust = -1, size = 3
  ) +
  # legend: coloured line swatches with matching labels
  geom_segment(x = 5.5, xend = 6, y = 2, yend = 2, size = .75, color = "blue") +
  geom_segment(x = 5.5, xend = 6, y = 1.7, yend = 1.7, size = .75, color = "red") +
  annotate("text", x = 5.3, y = 2, label = "truth",
           hjust = 1, size = 3, color = "blue") +
  annotate("text", x = 5.3, y = 1.7, label = "decision boundary",
           hjust = 1, size = 3, color = "red")
```

Figure 9.4:

```{r depth-3-decision-tree, echo=TRUE, fig.width=4, fig.height=3, fig.show='hold', fig.cap="Decision tree with depth = 3, resulting in 7 decision splits along values of feature x and 8 prediction regions (left), and the resulting decision boundary (right).", out.width="48%"}
# Tree allowed to grow three levels deep (cp = 0, at least 5 obs per leaf)
ctrl <- list(cp = 0, minbucket = 5, maxdepth = 3)
fit <- rpart(formula = y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Decision boundary: data, true curve (blue) and the depth-3 prediction (red)
mutate(df, pred = predict(fit, newdata = df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = .75)
```

Figure 9.5:

```{r iris-decision-tree, echo=TRUE, fig.width=4, fig.height=3, fig.show='hold', fig.cap="Decision tree for the iris classification problem (left). The decision boundary results in rectangular regions that enclose the observations. The class with the highest proportion in each region is the predicted value (right).", out.width="48%"}
# Classification tree predicting species from the two sepal measurements
iris_fit <- rpart(Species ~ Sepal.Length + Sepal.Width, data = iris)
rpart.plot(iris_fit)

# Hand-drawn decision boundary: each shaded rectangle is one prediction region
# (orange = setosa, grey = versicolor, green = virginica).
# NOTE(review): the cut points are hard-coded to match the fitted tree above;
# update them if the tree changes.
# Text annotations take no `fill` (ggplot only warned and ignored it), and every
# rect now lists ymin below ymax.
ggplot(iris, aes(Sepal.Length, Sepal.Width, color = Species, shape = Species)) +
  geom_point(show.legend = FALSE) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = 2.8, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 4.0, y = 4.4, label = "setosa", hjust = 0, size = 3) +
  annotate("rect", xmin = -Inf, xmax = 5.44, ymin = -Inf, ymax = 2.79, alpha = .75, fill = "grey") +
  annotate("text", x = 4.0, y = 2, label = "versicolor", hjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = 3.1, ymax = Inf, alpha = .75, fill = "orange") +
  annotate("text", x = 6, y = 4.4, label = "setosa", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 5.45, xmax = 6.15, ymin = -Inf, ymax = 3.09, alpha = .75, fill = "grey") +
  annotate("text", x = 6.15, y = 2, label = "versicolor", hjust = 1, vjust = 0, size = 3) +
  annotate("rect", xmin = 6.16, xmax = Inf, ymin = -Inf, ymax = Inf, alpha = .75, fill = "green") +
  annotate("text", x = 8, y = 2, label = "virginica", hjust = 1, vjust = 0, size = 3)

```


## How deep?

Figure 9.6:

```{r deep-overfit-tree, echo=TRUE, fig.width=4, fig.height=3, fig.show='hold', fig.cap="Overfit decision tree with 56 splits.", out.width="48%"}
# Deliberately overfit tree: no complexity penalty (cp = 0), leaves may hold a
# single observation (minbucket = 1), and depth capped at 30.
# rpart.control() documents 30 as the maximum supported depth; a plain-list
# `control` is not range-checked, so a larger value (previously 50) is not
# rejected and should not be used.
# NOTE(review): keep `ctrl` a plain list -- rpart.control(minbucket = 1) would
# also derive a different default minsplit and change the tree.
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 30)
fit <- rpart(y ~ x, data = df, control = ctrl)
rpart.plot(fit)

# Decision boundary of the overfit tree (red) against the true curve (blue)
df %>%
  mutate(pred = predict(fit, df)) %>%
  ggplot(aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = 0.75) +
  geom_line(aes(y = pred), color = "red", size = 0.75)
```

### Early stopping

Figure 9.7:

```{r dt-early-stopping, fig.width=10, fig.height=8, fig.cap="Illustration of how early stopping affects the decision boundary of a regression decision tree. The columns illustrate how tree depth impacts the decision boundary and the rows illustrate how the minimum number of observations in the terminal node influences the decision boundary.", echo=TRUE}
# Grid of early-stopping settings: every combination of maximum tree depth and
# minimum terminal-node size
hyper_grid <- expand.grid(
  maxdepth = c(1, 5, 15),
  minbucket = c(1, 5, 15)
)

# Fit one tree per grid row and keep its predictions next to the data.
# Collecting a list and binding once avoids re-copying the accumulated result
# on every iteration (rbind() inside a loop).
results <- lapply(seq_len(nrow(hyper_grid)), function(i) {
  ctrl <- list(
    cp = 0,
    maxdepth = hyper_grid$maxdepth[i],
    minbucket = hyper_grid$minbucket[i]
  )
  fit <- rpart(y ~ x, data = df, control = ctrl)
  mutate(
    df,
    minbucket = hyper_grid$minbucket[i],
    maxdepth = hyper_grid$maxdepth[i],
    pred = predict(fit, df)
  )
}) %>%
  bind_rows() %>%
  # Facet labels with explicit levels so panels appear in numeric order
  # (1, 5, 15) rather than alphabetical order ("= 1", "= 15", "= 5")
  mutate(
    minbucket = factor(
      minbucket,
      levels = sort(unique(minbucket)),
      labels = paste("Min node size =", sort(unique(minbucket)))
    ),
    maxdepth = factor(
      maxdepth,
      levels = sort(unique(maxdepth)),
      labels = paste("Max tree depth =", sort(unique(maxdepth)))
    )
  )

# Rows: minimum node size; columns: maximum depth
ggplot(results, aes(x, y)) +
  geom_point(alpha = .2, size = 1) +
  geom_line(aes(x, y = truth), color = "blue", size = .75) +
  geom_line(aes(y = pred), color = "red", size = 1) +
  facet_grid(minbucket ~ maxdepth)
```

### Pruning

Figure 9.8:

```{r pruned-tree, fig.width=10, fig.height = 4, fig.cap="To prune a tree, we grow an overly complex tree (left) and then use a cost complexity parameter to identify the optimal subtree (right).", echo=TRUE}
# Overly complex tree: no complexity penalty, single-observation leaves allowed,
# depth capped at rpart's documented maximum of 30 (previously 50, which a
# plain-list `control` does not reject).
# NOTE(review): keep `ctrl` a plain list -- rpart.control(minbucket = 1) would
# also derive a different default minsplit and change the tree.
ctrl <- list(cp = 0, minbucket = 1, maxdepth = 30)
fit <- rpart(y ~ x, data = df, control = ctrl)

# Tree grown with rpart's default control settings (including its default cp)
fit2 <- rpart(y ~ x, data = df)

# Shared plot for both models: data, true curve (blue), fitted step function (red)
plot_tree_fit <- function(model) {
  df %>%
    mutate(pred = predict(model, df)) %>%
    ggplot(aes(x, y)) +
    geom_point(alpha = .3, size = 2) +
    geom_line(aes(x, y = truth), color = "blue", size = 1) +
    geom_line(aes(y = pred), color = "red", size = 1)
}
p1 <- plot_tree_fit(fit)
p2 <- plot_tree_fit(fit2)

gridExtra::grid.arrange(p1, p2, nrow = 1)
```


## Ames housing example

```{r basic-ames-tree}
# Regression ("anova") tree for sale price using every other column as a
# predictor, with rpart's default control settings
ames_dt1 <- rpart(Sale_Price ~ ., data = ames_train, method = "anova")
```

```{r basic-ames-dt-tree-results, linewidth = 70}
# Show the fitted tree's splits
print(ames_dt1)
```

```{r basic-ames-tree-plot, fig.width=10, fig.height=6, fig.cap="Diagram displaying the pruned decision tree for the Ames Housing data."}
# Diagram of the fitted tree
rpart.plot(x = ames_dt1)
```

```{r plot-cp, fig.width = 5, fig.height=3.5, fig.cap="Pruning complexity parameter (cp) plot illustrating the relative cross validation error (y-axis) for various cp values (lower x-axis). Smaller cp values lead to larger trees (upper x-axis). Using the 1-SE rule, a tree size of 10-12 provides optimal cross validation results."}
# Cross-validated relative error for each cp value of the default tree
plotcp(x = ames_dt1)
```

```{r no-cp-tree, fig.cap="Pruning complexity parameter plot for a fully grown tree. Significant reduction in the cross validation error is achieved with tree sizes 6-20 and then the cross validation error levels off with minimal or no additional improvements."}
# Fully grown tree (cp = 0) with 10-fold cross validation (xval = 10)
ames_dt2 <- rpart(
  Sale_Price ~ .,
  data = ames_train,
  method = "anova",
  control = list(cp = 0, xval = 10)
)

# Error-by-cp plot with a dashed reference line at a hard-coded position
plotcp(ames_dt2)
abline(v = 11, lty = "dashed")
```

```{r cp-table, fig.cap="Cross-validated RMSE for the 20 different $\\alpha$ (complexity parameter) values in our grid search. Lower $\\alpha$ values (deeper trees) help to minimize errors.", fig.height=3}
# rpart cross validation results: one row per candidate cp value
ames_dt1[["cptable"]]

# caret cross validation results: 10-fold CV over 20 candidate cp values
ames_dt3 <- train(
  Sale_Price ~ .,
  data = ames_train,
  method = "rpart",
  tuneLength = 20,
  trControl = trainControl(number = 10, method = "cv")
)

# Cross-validated performance across the cp grid
ggplot(ames_dt3)
```

## Feature interpretation 

```{r dt-vip, fig.height=5.5, fig.cap="Variable importance based on the total reduction in MSE for the Ames Housing decision tree."}
# Variable importance for up to 40 features of the caret model; bar = FALSE
# asks vip for a non-bar display.
# NOTE(review): newer vip releases replace `bar` with a `geom` argument
# (e.g. geom = "point"); confirm which vip version this notebook targets.
vip(object = ames_dt3, num_features = 40, bar = FALSE)
```

```{r dt-pdp, fig.width=10, fig.height= 3.5, fig.cap="Partial dependence plots showing the relationship between sale price and the living area and year built features, individually (left, center) and jointly (right)."}
# Partial dependence of predicted sale price on living area and on year built
# (ggplot2 line plots), then on both jointly (draped 3-D surface)
p1 <- autoplot(partial(ames_dt3, pred.var = "Gr_Liv_Area"))
p2 <- autoplot(partial(ames_dt3, pred.var = "Year_Built"))
p3 <- plotPartial(
  partial(ames_dt3, pred.var = c("Gr_Liv_Area", "Year_Built")),
  levelplot = FALSE, zlab = "yhat", drape = TRUE,
  colorkey = TRUE, screen = list(z = -20, x = -60)
)

# Three panels side by side
gridExtra::grid.arrange(p1, p2, p3, ncol = 3)
```
