
Decision Trees: A Guide with Examples

A tutorial covering Decision Trees, complete with code and interactive visualizations
Created on September 14|Last edited on December 25
In this article, we provide a tutorial on Decision Trees, one of the most classic machine learning models. We'll also learn how to log and visualize them using Weights & Biases and provide code examples so you can follow along.
Get started here: Quick Start Colab

Let's get going!

What Are Decision Trees in Machine Learning?

A decision tree is a hierarchical (i.e., tree-like) data structure implementing a divide-and-conquer approach to machine learning. It's an efficient non-parametric supervised method, which can be used for both classification and regression.
In parametric estimation, we define a model over the whole input space and learn its parameters from all of the training data. Then we use the same model and the same parameter set for any test input. In non-parametric estimation, we divide the input space into local regions, defined by distance measures.
A decision tree is a non-parametric model in the sense that we do not assume any parametric form for the class densities, and the tree structure is not fixed a priori: as the tree grows during learning, branches and leaves are added depending on the complexity of the problem inherent in the data.
A tree 🌲 is a set of one or more nodes. A node $x_i$, with subscript $i$ being either empty or a sequence of one or more non-negative integers, is joined to another node $x_{ij}$ by an arc directed from $x_i$ to $x_{ij}$.


Figure 1: What a Tree looks like.
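To make the node-and-arc definition concrete, here is a minimal sketch of how a binary decision tree could be represented in Python. The `Node` class, its fields (`feature`, `threshold`, `left`, `right`, `prediction`) and the `predict` helper are hypothetical names chosen for illustration, not part of any library: each internal node tests one feature against a threshold, and each leaf stores a class label.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Node:
    """One node of a binary decision tree (hypothetical structure for illustration)."""
    feature: Optional[int] = None       # index of the feature tested at an internal node
    threshold: Optional[float] = None   # go left if x[feature] <= threshold
    left: Optional[Node] = None         # child reached when the test is true
    right: Optional[Node] = None        # child reached when the test is false
    prediction: Optional[str] = None    # class label stored at a leaf

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


def predict(node: Node, x: Sequence[float]) -> Optional[str]:
    """Follow the arcs from the root down to a leaf and return its prediction."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.prediction


# A two-level tree: the root tests feature 0, its right child tests feature 1.
tree = Node(
    feature=0, threshold=2.5,
    left=Node(prediction="A"),
    right=Node(
        feature=1, threshold=1.0,
        left=Node(prediction="B"),
        right=Node(prediction="C"),
    ),
)

print(predict(tree, [1.0, 5.0]))  # A
print(predict(tree, [3.0, 0.5]))  # B
print(predict(tree, [3.0, 2.0]))  # C
```

Every prediction is a single root-to-leaf walk, so its cost grows with the depth of the tree rather than with the size of the training set.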

Where to Split Your Decision Tree

A decision at a particular node is known as a split, since we're essentially splitting our input data distribution into multiple subsets. The root (top-most) node represents the entire training data. Each subsequent split partitions the data reaching that node further, so every path from the root to a leaf corresponds to a sequence of tests on the input features.
Where to split is usually decided by a rule or impurity measure. A leaf is "pure" if all the training instances reaching it belong to a single class; an impure leaf contains instances from more than one class.
Here, we go over some of the rules/criteria typically used in decision trees:

Information Gain

One such measure, from information theory, is known as entropy. It measures the randomness (impurity) of the class labels of the data points at a node and is given by the following equation:
$$\text{Entropy} = -\sum_{i=1}^{n} p_i \log_2(p_i)$$

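where $p_i$ is the proportion of the data points at the node that belong to class $i$, and $n$ is the number of classes. Entropy is 0 when all points at the node belong to one class and is largest when the classes are equally mixed.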
Our overall goal is to reduce the entropy of the system. This is done using another metric called Information Gain (IG). As the name suggests, Information Gain tells us how informative a particular split is. It's calculated by subtracting the weighted entropy of the subsets produced by splitting on an attribute $A$ from the entropy of the whole data set $D$:

$$\text{Gain}(D, A) = \text{Entropy}(D) - \text{Entropy}(D, A), \qquad \text{Entropy}(D, A) = \sum_{v} \frac{|D_v|}{|D|}\,\text{Entropy}(D_v)$$

where $D_v$ is the subset of $D$ in which attribute $A$ takes the value $v$. In other words, IG tells us how much one would "gain" by branching on attribute $A$. Thus, the splitting attribute at any particular node is the one with the highest gain.
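As a worked example of the two formulas, the sketch below computes entropy and information gain from lists of class labels using only the Python standard library. The function names and the toy split on a hypothetical attribute "A" are illustrative assumptions; each subset is weighted by its share of the parent's data points, as in $\text{Entropy}(D, A)$.

```python
import math
from collections import Counter
from typing import Hashable, Sequence


def entropy(labels: Sequence[Hashable]) -> float:
    """Entropy (in bits) of the class labels at a node: -sum p_i * log2(p_i)."""
    total = len(labels)
    counts = Counter(labels)
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def information_gain(parent: Sequence[Hashable], subsets: Sequence[Sequence[Hashable]]) -> float:
    """Entropy of the parent minus the size-weighted entropy of the subsets."""
    total = len(parent)
    weighted = sum(len(s) / total * entropy(s) for s in subsets)
    return entropy(parent) - weighted


# Parent node: 5 "yes" and 5 "no", the most mixed two-class case.
parent = ["yes"] * 5 + ["no"] * 5
# Splitting on a hypothetical attribute A yields two child nodes.
split_a = [["yes"] * 4 + ["no"] * 1, ["yes"] * 1 + ["no"] * 4]

print(entropy(parent))                              # 1.0
print(round(information_gain(parent, split_a), 3))  # 0.278
```

A split that separated the two classes perfectly would gain the full 1.0 bit, since both children would have zero entropy.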

Gini Index

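  • Gini Index, also known as Gini impurity, measures the probability that a randomly chosen data point at a node would be misclassified if it were labeled at random according to the node's class distribution. If all the elements belong to a single class, the node is called pure.
  • It ranges from 0 for a pure node up to a maximum of $1 - 1/n$ when all $n$ classes are equally frequent, so it always stays below 1 (at most 0.5 for two classes).
  • It's calculated by subtracting the sum of the squared probabilities of each class from one: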
$$\text{Gini Index} = 1 - \sum_{i=1}^{n} p_i^2$$
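The formula translates directly into a few lines of Python. This is a minimal sketch with a hypothetical `gini_index` helper; the example calls show the pure case, the two-class maximum of 0.5, and the three-class maximum of $1 - 1/3$. When choosing a split, a tree would typically prefer the one whose child nodes have the lowest size-weighted Gini index.

```python
from collections import Counter
from typing import Hashable, Sequence


def gini_index(labels: Sequence[Hashable]) -> float:
    """Gini impurity of the labels at a node: 1 - sum of squared class proportions."""
    total = len(labels)
    counts = Counter(labels)
    return 1.0 - sum((c / total) ** 2 for c in counts.values())


print(gini_index(["a"] * 6))                  # 0.0 (pure node)
print(gini_index(["a"] * 3 + ["b"] * 3))      # 0.5 (two balanced classes)
print(round(gini_index(["a", "b", "c"]), 4))  # 0.6667 (three balanced classes)
```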


Reduction in Variance in Decision Trees

Reduction in variance is a splitting criterion used for continuous target variables (regression). It uses the standard formula of variance to choose the best split: the split whose child nodes have the lowest size-weighted variance, and hence the largest reduction from the parent's variance, is selected. Here's the formula:
$$\text{Variance} = \frac{\sum (X - \bar{X})^2}{n}$$

where $X$ ranges over the $n$ target values at a node and $\bar{X}$ is their mean.
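Here is a minimal sketch of how reduction in variance could pick a threshold on a single numeric feature. It assumes the candidate thresholds are midpoints between consecutive sorted feature values (a common choice, not something the text prescribes), and `variance_reduction` and `best_threshold` are hypothetical names. `statistics.pvariance` divides by $n$, matching the formula above.

```python
from statistics import pvariance
from typing import Optional, Sequence


def variance_reduction(parent: Sequence[float], left: Sequence[float], right: Sequence[float]) -> float:
    """Parent variance minus the size-weighted variance of the two child nodes."""
    n = len(parent)
    weighted = len(left) / n * pvariance(left) + len(right) / n * pvariance(right)
    return pvariance(parent) - weighted


def best_threshold(x: Sequence[float], y: Sequence[float]) -> Optional[float]:
    """Return the threshold on feature x whose split of y gives the largest variance reduction."""
    pairs = sorted(zip(x, y))
    ys = [yv for _, yv in pairs]
    best_t, best_gain = None, float("-inf")
    for i in range(1, len(pairs)):
        if pairs[i - 1][0] == pairs[i][0]:
            continue  # no threshold separates equal feature values
        t = (pairs[i - 1][0] + pairs[i][0]) / 2
        left = [yv for xv, yv in pairs if xv <= t]
        right = [yv for xv, yv in pairs if xv > t]
        gain = variance_reduction(ys, left, right)
        if gain > best_gain:
            best_t, best_gain = t, gain
    return best_t


x = [1, 2, 3, 4, 5, 6]
y = [1.0, 1.2, 0.9, 5.0, 5.1, 4.8]
print(best_threshold(x, y))  # 3.5
```

The chosen threshold separates the low targets (around 1) from the high ones (around 5), which leaves each child with a very small variance.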


Chi-Square

  • It is a method for testing the statistical significance of the differences between the sub-nodes and the parent node.
  • We measure it by the sum of the squared standardized differences between the observed ($O$) and expected ($E$) frequencies of the target variable: $\chi^2 = \sum \frac{(O - E)^2}{E}$.
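The sketch below computes the chi-square statistic of a split. It assumes that the expected count of a class in a child equals the child's size times that class's proportion in the parent, which is what the counts would be if the split carried no information. The `chi_square` helper is hypothetical, and every child is assumed to be non-empty so that no expected count is zero.

```python
from collections import Counter
from typing import Hashable, Sequence


def chi_square(parent: Sequence[Hashable], children: Sequence[Sequence[Hashable]]) -> float:
    """Sum over child nodes and classes of (observed - expected)^2 / expected."""
    n = len(parent)
    parent_counts = Counter(parent)
    stat = 0.0
    for child in children:
        observed = Counter(child)
        for cls, count in parent_counts.items():
            # Count we would expect in this child if it mirrored the parent's class mix.
            expected = len(child) * count / n
            stat += (observed[cls] - expected) ** 2 / expected
    return stat


parent = ["yes"] * 5 + ["no"] * 5
informative = [["yes"] * 5, ["no"] * 5]
uninformative = [["yes"] * 2 + ["no"] * 2, ["yes"] * 3 + ["no"] * 3]
print(chi_square(parent, informative))    # 10.0
print(chi_square(parent, uninformative))  # 0.0
```

A larger statistic means the sub-nodes differ more from the parent node, so the split is more significant.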
Which measure to use depends on the problem. For example, Gini Index or Information Gain is used for classification problems, whereas measures such as Mean Squared Error (MSE), which is equivalent to reduction in variance, or Mean Absolute Error (MAE) are used for regression problems.



Pruning Your Decision Tree

Frequently, a node is not split further if the number of training instances reaching it is smaller than a certain percentage of the training set. The idea is that any decision based on too few instances causes variance and, thus, generalization error. Stopping tree construction early, before the tree is full, is called pre-pruning.
Another way to get simpler trees is post-pruning, which in practice works better than pre-pruning. Tree growing is greedy: at each step, we make a decision and continue, never backtracking to try an alternative. The only exception is post-pruning, where we try to find and prune unnecessary subtrees.
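To illustrate both the greedy growth and the pre-pruning rule, here is a minimal sketch that grows a classification tree on a single numeric feature. The representation (a leaf is a label, an internal node is a `(threshold, left, right)` tuple), the helper names `gini` and `build`, and the use of the Gini index as the split criterion are all assumptions for illustration. The stop rule `len(points) < min_fraction * n_train` is the "smaller than a certain percentage of the training set" test described above.

```python
from collections import Counter
from typing import List, Tuple, Union

# Hypothetical representation: a leaf is a class label (str); an internal node is
# (threshold, left_subtree, right_subtree) testing a single numeric feature x <= threshold.
Tree = Union[str, Tuple[float, "Tree", "Tree"]]
Point = Tuple[float, str]  # (feature value, class label)


def gini(labels: List[str]) -> float:
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def build(points: List[Point], n_train: int, min_fraction: float) -> Tree:
    """Greedily grow a tree on one numeric feature, with a pre-pruning stop rule."""
    labels = [label for _, label in points]
    majority = Counter(labels).most_common(1)[0][0]
    # Pre-pruning: make a leaf if the node is pure or holds too small a share of the training set.
    if len(set(labels)) == 1 or len(points) < min_fraction * n_train:
        return majority
    points = sorted(points)
    best = None  # (weighted Gini of the children, threshold)
    for i in range(1, len(points)):
        if points[i - 1][0] == points[i][0]:
            continue  # no threshold separates equal feature values
        t = (points[i - 1][0] + points[i][0]) / 2
        left = [lab for x, lab in points if x <= t]
        right = [lab for x, lab in points if x > t]
        score = (len(left) * gini(left) + len(right) * gini(right)) / len(points)
        if best is None or score < best[0]:
            best = (score, t)
    if best is None:  # all feature values are identical, so no split is possible
        return majority
    t = best[1]
    # Greedy: commit to this split and recurse, never revisiting it.
    return (t,
            build([p for p in points if p[0] <= t], n_train, min_fraction),
            build([p for p in points if p[0] > t], n_train, min_fraction))


data = [(1, "a"), (2, "a"), (3, "b"), (4, "a"), (5, "a"),
        (6, "b"), (7, "b"), (8, "b"), (9, "b"), (10, "b")]

print(build(data, len(data), min_fraction=0.0))  # (5.5, (2.5, 'a', (3.5, 'b', 'a')), 'b')
print(build(data, len(data), min_fraction=0.6))  # (5.5, 'a', 'b')
```

Without pre-pruning, the tree spends two extra splits isolating the single point at x = 3; with `min_fraction=0.6`, the five-point left node is too small to split and becomes a leaf.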

Decision Trees, Post-Pruning

In post-pruning, we grow the tree fully, until all the leaves are pure and we have no training error. We then find subtrees that cause overfitting, and we prune them. From the initial labeled set, we set aside a pruning set that is not used during training. For each subtree, we consider replacing it with a leaf node labeled according to the training instances covered by the subtree (for classification, their majority class; for regression, their mean).
If the leaf node does not perform worse than the subtree on the pruning set, we prune the subtree and keep the leaf node because the additional complexity of the subtree is not justified; otherwise, we keep the subtree.
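The procedure above is often called reduced-error pruning. Below is a minimal bottom-up sketch of it for a classification tree on one numeric feature. The tree representation, the helper names (`predict`, `errors`, `post_prune`) and the small training and pruning sets are hypothetical; the rule "prune if the leaf does not perform worse on the pruning set" is the one from the text.

```python
from collections import Counter
from typing import List, Tuple, Union

# Hypothetical representation: a leaf is a class label (str); an internal node is
# (threshold, left_subtree, right_subtree) testing a single numeric feature x <= threshold.
Tree = Union[str, Tuple[float, "Tree", "Tree"]]
Point = Tuple[float, str]  # (feature value, class label)


def predict(tree: Tree, x: float) -> str:
    while not isinstance(tree, str):
        threshold, left, right = tree
        tree = left if x <= threshold else right
    return tree


def errors(tree: Tree, points: List[Point]) -> int:
    return sum(predict(tree, x) != y for x, y in points)


def post_prune(tree: Tree, train: List[Point], prune: List[Point]) -> Tree:
    """Bottom-up reduced-error pruning against a held-out pruning set."""
    if isinstance(tree, str) or not train:
        return tree
    t, left, right = tree
    # Prune the children first, passing down only the instances that reach each child.
    left = post_prune(left, [p for p in train if p[0] <= t], [p for p in prune if p[0] <= t])
    right = post_prune(right, [p for p in train if p[0] > t], [p for p in prune if p[0] > t])
    subtree: Tree = (t, left, right)
    # Candidate leaf: majority class of the training instances covered by this subtree.
    leaf = Counter(y for _, y in train).most_common(1)[0][0]
    # Keep the leaf unless it does worse than the subtree on the pruning set.
    return leaf if errors(leaf, prune) <= errors(subtree, prune) else subtree


# A fully grown tree that fits the training data perfectly, including one noisy point (3, "b").
train = [(1, "a"), (2, "a"), (3, "b"), (4, "a"), (5, "a"),
         (6, "b"), (7, "b"), (8, "b"), (9, "b"), (10, "b")]
full_tree = (5.5, (2.5, "a", (3.5, "b", "a")), "b")
prune_set = [(1.5, "a"), (3.2, "a"), (4.5, "a"), (7.0, "b"), (9.0, "b")]

print(post_prune(full_tree, train, prune_set))  # (5.5, 'a', 'b')
```

The subtree that only exists to fit the noisy point is replaced by a leaf, while the root split survives because replacing it with a single "b" leaf would misclassify three pruning points.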
Comparing pre-pruning and post-pruning, we can say that pre-pruning is faster, but post-pruning generally leads to more accurate trees.



Strengths of Decision Trees

  • They produce very simple, understandable rules. For smaller trees, not much mathematical and computational knowledge is required to understand this model.
  • They perform reasonably well on a wide range of problems with little data preparation; for example, no feature scaling is needed.
  • They can handle both numerical and categorical variables.
  • They can work well both with small and large training data sets.
  • Decision trees give a clear indication of which features are most useful for classification, since the most informative features tend to be chosen for splits near the root.



Weaknesses of Decision Trees

  • Decision tree models are often biased towards features having more possible values.
  • These models overfit or underfit quite easily.
  • Decision trees are prone to errors in classification problems with many classes and a relatively small number of training examples.
  • A decision tree can be computationally expensive to train.
  • Large trees are complex to understand.



Python Code Example for Decision Trees

In this section, we'll go over some code snippets for creating and training Decision Trees and logging appropriate metrics and graphs to W&B. For more details, have a look at our docs.

Classification

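```python
import wandb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Example hyperparameters (illustrative values); adjust them to your experiment
config = dict(test_size=0.2, random_state=42, max_depth=3, min_samples_split=2,
              clf_criterion="gini", splitter="best",
              labels=["setosa", "versicolor", "virginica"])

run = wandb.init(project='...', entity='...', config=config)
config = run.config  # supports attribute access, e.g. config.max_depth

# Load Iris and hold out a test split
X, y = load_iris(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state)

clf = DecisionTreeClassifier(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.clf_criterion,
    splitter=config.splitter,
)
clf = clf.fit(x_train, y_train)

y_pred = clf.predict(x_test)

# Visualize Confusion Matrix
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)

run.finish()
```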



Regression

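```python
import wandb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Example hyperparameters (illustrative values); "squared_error" needs scikit-learn >= 1.0
config = dict(test_size=0.2, random_state=42, max_depth=3, min_samples_split=2,
              reg_criterion="squared_error", splitter="best")

run = wandb.init(project='...', entity='...', config=config)
config = run.config  # supports attribute access, e.g. config.max_depth

# The Iris class labels (0, 1, 2) serve as a numeric target here
X, y = load_iris(return_X_y=True)

x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state)

reg = DecisionTreeRegressor(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.reg_criterion,
    splitter=config.splitter,
)

reg = reg.fit(x_train, y_train)

# All regression plots
wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')

run.finish()
```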



Conclusion

And that wraps up our short tutorial on Decision Trees. To see the full suite of W&B features, please check out this short 5-minute guide.

📚 Resources

📄 Research Papers

  • L. Breiman, J. Friedman, R. Olshen, and C. Stone. Classification and Regression Trees. Wadsworth, Belmont, CA, 1984.