BAIT 509 Class Meeting 05

Classification and Regression Trees (CART)

Outline

  • Classification and Regression Trees

Motivation

Some problems (or at least potential problems) with the local methods introduced last time:

  1. They lack interpretability.
    • It’s not easy to say how the predictors influence the response from the fitted model.
  2. They typically require a data-rich situation to keep the estimation variance acceptable without compromising the estimation bias.

We’ll look at classification and regression trees as another method (sometimes called CART). CARTs are not necessarily better, and in fact, tend to be poor competitors to other methods – but ensemble extensions of these tend to perform well (a topic for next Monday).

Our setting this time is the usual: we have a response \(Y\) (either categorical or numeric), and hope to predict this response using \(p\) predictors \(X_1,\ldots,X_p\).

  • When the response is categorical, we aim to estimate the mode and take that as our prediction.
  • When the response is numeric, we aim to estimate the mean, and take that as our prediction.

Stumps: A preliminary concept

What are they?

Let’s say I get an upset stomach once in a while, and I suspect certain foods might be responsible. My response and predictors are:

  • \(Y\): sick or not sick (categorical)
  • \(X_1\): number of eggs consumed in a day.
  • \(X_2\): amount of milk consumed in a day, in liters.

I keep a food diary, and record the following:

Eggs  Milk (L)  Sick?
0     0.7       Yes
1     0.6       Yes
0     0         No
1     0.7       Yes
1     0         Yes
0     0.4       No

(Example from Mark Schmidt’s CPSC 340)
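To make the examples below concrete, here is the food diary as a small data frame in R (a hypothetical diary tibble of my own, used only for illustration):

suppressPackageStartupMessages(library(tibble))

# The food diary above, entered as a small data frame (illustration only)
diary <- tibble(
  eggs = c(0, 1, 0, 1, 1, 0),
  milk = c(0.7, 0.6, 0, 0.7, 0, 0.4),
  sick = c("Yes", "Yes", "No", "Yes", "Yes", "No")
)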

A decision/classification stump is a decision on \(Y\) based on the value of one of the predictors. You can make one by doing the following steps:

  1. Choose a predictor (in this case, Milk or Eggs?)
  2. Choose a cutpoint on that predictor:
    • Subset the data having that predictor less than that cutpoint, and take the mode as your decision/classification.
    • Subset the data having that predictor greater than that cutpoint, and take the mode as your decision/classification.

This will get you a decision/classification stump, like so:

(Figure: an example decision stump. Image attribution: Hyeju Jang, DSCI 571)
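For concreteness, here is a minimal sketch of carrying out these two steps by hand on the diary data frame built above, using the illustrative stump “Milk < 0.5” (the cutpoint 0.5 is my own choice, not necessarily the one pictured):

# Evaluate the candidate stump "Milk < 0.5" (sketch)
mode_of <- function(y) names(which.max(table(y)))         # most common label
left_decision  <- mode_of(diary$sick[diary$milk < 0.5])   # mode of the left subset
right_decision <- mode_of(diary$sick[diary$milk >= 0.5])  # mode of the right subset
pred <- ifelse(diary$milk < 0.5, left_decision, right_decision)
mean(pred != diary$sick)  # training error of this stump: 1/6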

The same idea applies in the regression case, except we take the mean of the response in each subset as the prediction, and choose the best stump according to the mean squared error.

Some things to note:

  • A decision on a numeric variable is always made based on a single threshold (either less than or greater than it). Although it makes sense to talk about a stump based on a more complicated decision, allowing for more complicated decisions would greatly increase the search space and make fitting these models impractical. Besides, a more complicated decision can often be broken down into several single-threshold decisions (for example, the rule \(0.2 < X < 0.6\) can be written as a split at \(X < 0.6\) followed by a split at \(X > 0.2\)).
  • It follows that decisions are always binary. Again, although it makes sense to talk about a stump with more than two options, such a decision can usually be written as a sequence of binary splits, that is, as a tree.

Choosing the best stumps

Question: What is the error of the above decision stump? If we decide to make a prediction without a decision stump (or any predictors at all), what would the best prediction be, and what’s the error?

We have many stumps we can choose from – which one is best?

Classification:

  • The one that gives the least classification error.
    • This criterion tends to have a harder time disambiguating between candidate splits.
  • The Gini and entropy criteria are two of the more commonly used measures.
    • These are measures of node purity: they reward splits whose child nodes are dominated by a single class (see the sketch after the Regression bullet below).

Regression:

  • The one that gives the least mean squared error.
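As a reference point (a sketch of my own, not from the lecture), Gini impurity and entropy can be written as short R functions and applied to the children of a candidate split, here the illustrative “Milk < 0.5” stump from before:

# Purity measures for a vector of class labels (sketch)
gini    <- function(y) { p <- prop.table(table(y)); sum(p * (1 - p)) }
entropy <- function(y) { p <- prop.table(table(y)); -sum(p * log(p)) }

# The two children of the illustrative "Milk < 0.5" split:
gini(diary$sick[diary$milk < 0.5])   # about 0.44: a mixed node
gini(diary$sick[diary$milk >= 0.5])  # 0: a pure node (all "Yes")

Lower impurity (averaged over the two children, weighted by their sizes) indicates a better split.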

Partitioning of predictor space

How does a stump partition the predictor space?
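One way to see this (my own illustration, using the diary data and the “Milk < 0.5” stump from above): the stump carves the (Eggs, Milk) plane into two half-planes, one on each side of the cutpoint, and assigns one prediction to each half.

suppressPackageStartupMessages(library(ggplot2))

# Visualize the partition induced by the "Milk < 0.5" stump (sketch)
ggplot(diary, aes(eggs, milk, colour = sick)) +
  geom_point(size = 3) +
  geom_hline(yintercept = 0.5, linetype = "dashed") +  # the cutpoint
  labs(x = "Eggs", y = "Milk (litres)")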

Trees

Notice that a tree is just a bunch of stumps!

This means:

  1. For numeric predictors, decisions can only involve one threshold.
  2. Decisions are always binary.

Optimal tree – computationally infeasible

Ideally, we’d examine the error from every possible regression tree, and choose the best one. But there are far too many trees possible! We turn to recursive binary splitting as the next best thing.

Recursive Binary Splitting

The idea behind making a tree is to take a greedy approach: keep adding decision stumps, each time choosing the best stump.

This is called “greedy” because the method may not end up with the best tree – it’s not computationally feasible to look many steps ahead to see whether a different sequence of splits would ultimately give a better tree.

Adding “Eggs < 1” to the above stump would be the next greedy iteration – in this case, resulting in 100% prediction accuracy on the training data. Notice that every time we add a stump, the error on the training set decreases (or at least never increases).
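Writing that two-split tree out by hand for the diary data (a sketch; the exact cutpoints 0.5 and 1 are my illustrative choices):

# The greedily grown tree after two splits, as nested if/else (sketch)
predict_toy_tree <- function(eggs, milk) {
  ifelse(milk >= 0.5, "Yes",              # first stump: split on Milk
         ifelse(eggs >= 1, "Yes", "No"))  # second stump: split on Eggs, within Milk < 0.5
}
mean(predict_toy_tree(diary$eggs, diary$milk) == diary$sick)  # 1: perfect on the training data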

We keep iterating until we reach some stopping criterion, often based on things such as:

  • Tree depth
  • Number of observations in a leaf becomes too small
  • Error is not reduced significantly

Questions:

  • How does “tree size” affect the bias-variance tradeoff?
  • How can you choose these parameters?

Pruning

Sometimes, a stopping criterion may cause the tree to stop growing prematurely. For example, a split that barely reduces the error may need to happen before a much more useful split becomes available.

To avoid this, we take the approach of growing an overly large (overfit) tree, and then pruning it. We won’t get into the details of how the tree is pruned back, but one technique is called cost complexity pruning. The general idea is to use a tuning parameter \(\alpha\) to control how much pruning is done; \(\alpha\) can be chosen by cross-validation, or by the validation set approach.
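For reference, the usual formulation in the regression case (stated here for completeness, not derived) is: for each \(\alpha \ge 0\), find the subtree \(T\) of the full tree that minimizes

\[
\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left(y_i - \hat{y}_{R_m}\right)^2 + \alpha\,|T|,
\]

where \(|T|\) is the number of terminal nodes, \(R_m\) is the region of predictor space corresponding to the \(m\)th terminal node, and \(\hat{y}_{R_m}\) is the mean response in that region. Larger values of \(\alpha\) penalize tree size more heavily, yielding smaller subtrees.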

Exercise

  • Coding a decision tree with tree::tree() in R, using the titanic data.
  • predict() and plot(), and “S3 generics” in R.
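On that last point: predict() and plot() are S3 generics, plain functions that dispatch on the class of their first argument, so calling predict() on a “tree” object actually runs predict.tree(). A quick way to see this (a small sketch of my own; demo_fit is just a throwaway name):

suppressPackageStartupMessages(library(tree))
demo_fit <- tree(Species ~ ., data = iris)  # any "tree" object will do
class(demo_fit)          # "tree"
methods(class = "tree")  # the S3 methods registered for this class (predict, plot, text, ...)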

Classification and Regression Trees: in R

To fit classification and regression trees in R, we use the package tree and the function tree(), which works similarly to lm() and loess():

suppressPackageStartupMessages(library(tree))
suppressPackageStartupMessages(library(tidyverse))
fit <- tree(Sepal.Width ~ ., data=iris)
summary(fit)
## 
## Regression tree:
## tree(formula = Sepal.Width ~ ., data = iris)
## Variables actually used in tree construction:
## [1] "Petal.Length" "Sepal.Length" "Petal.Width" 
## Number of terminal nodes:  10 
## Residual mean deviance:  0.06268 = 8.776 / 140 
## Distribution of residuals:
##     Min.  1st Qu.   Median     Mean  3rd Qu.     Max. 
## -0.60560 -0.16780  0.03182  0.00000  0.16280  0.63180
plot(fit)  # Plot the tree, without labels
text(fit)  # Add labels to the tree

(The plot labels are messy – we would need to shorten the predictor names to tidy it up.) Make predictions with predict():

predict(fit, newdata=iris) %>% 
    head
##        1        2        3        4        5        6 
## 3.635294 3.273913 3.273913 3.273913 3.273913 3.635294
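Just to show the classification analogue (my addition, not part of the example above): when the response is a factor, tree() fits a classification tree, and predict() with type = "class" returns predicted classes rather than fitted means.

fit_class <- tree(Species ~ ., data = iris)  # Species is a factor, so this is a classification tree
predict(fit_class, newdata = iris, type = "class") %>% 
    head()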

If you want to control the size of the tree, use the tree.control() function in the control argument of tree(). The relevant arguments of tree.control() are:

  • mindev: for a node to be split, its within-node deviance must be at least this fraction of the root node’s deviance (smaller values allow the tree to grow larger).
  • minsize: the smallest allowed node size; a node with fewer observations than this will not be split further.

Let’s grow a tree as far as it will go, and check its training MSE:

fitfull <- tree(Sepal.Width ~ ., data=iris, 
                control=tree.control(nrow(iris), 
                                     mindev=0, minsize=2))
mean((predict(fitfull) - iris$Sepal.Width)^2)
## [1] 0.0008666667
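That training MSE is tiny because the fully grown tree is badly overfit. As a rough check (my own sketch, not part of the lecture code), we can compare against the error on a held-out portion of the data:

# Fit the fully grown tree on 100 training rows, evaluate on the remaining 50 (sketch)
set.seed(1)
train_rows <- sample(nrow(iris), 100)
fit_train <- tree(Sepal.Width ~ ., data = iris[train_rows, ],
                  control = tree.control(100, mindev = 0, minsize = 2))
mean((predict(fit_train, newdata = iris[-train_rows, ]) - iris$Sepal.Width[-train_rows])^2)

The held-out MSE will typically be much larger than the training MSE, which is one motivation for pruning.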

You can investigate how much to prune a tree via cross-validation using the cv.tree() function. For classification trees, specify FUN=prune.misclass if you want to prune based on the misclassification rate instead of the deviance (the default). The function returns a list, where the important components are named "size" (the number of terminal nodes) and "dev" (the cross-validation error). Let’s plot those:

set.seed(4)
fitfull_cv <- cv.tree(fitfull)
plot(fitfull_cv$size, log(fitfull_cv$dev))

The x-axis represents the number of terminal nodes present in the subtree. The y-axis is (the log of) the cross-validation error for that subtree.

From the plot, it looks like it’s best to prune the tree to have approximately 10 terminal nodes. Use prune.tree(fitfull, best=10) to prune the tree to 10 terminal nodes.

Note: if you encounter an error running prune.tree(fitfull, best=10), it’s not a true error (I believe it’s only an error in the print call, which is invoked by default). Wrap the code in try():

fit_pruned <- try(prune.tree(fitfull, best=10))
plot(fit_pruned)
text(fit_pruned)
