Decision Trees: A Guide with Examples

A tutorial covering Decision Trees, complete with code and interactive visualizations. Made by Saurav Maheshkar using Weights & Biases
Saurav Maheshkar



🌲 What Are Decision Trees in Machine Learning?

A decision tree is a hierarchical (i.e. tree-like) data structure implementing a divide-and-conquer approach to machine learning. It's an efficient non-parametric supervised method that can be used for both classification and regression.
In parametric estimation, we define a model over the whole input space and learn its parameters from all of the training data. Then we use the same model and the same parameter set for any test input. In non-parametric estimation, we divide the input space into local regions, defined by a distance measure, and for each input we use the local model computed from the training data in that region. A decision tree identifies these local regions through a sequence of recursive splits.
A decision tree is a non-parametric model in the sense that we do not assume any parametric form for the class densities, and the tree structure is not fixed a priori: the tree grows, with branches and leaves added during learning, depending on the complexity of the problem inherent in the data.
A tree 🌲 is a set of one or more nodes. A node x_i, where the subscript i is either empty or a sequence of one or more non-negative integers, is joined to another node x_{ij} by an arc directed from x_i to x_{ij}.

Figure 1: What a Tree looks like.
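To make the node-and-arc definition concrete, here is a minimal sketch of one possible in-memory representation. It assumes binary splits that compare a single numeric feature against a threshold (the general definition allows any number of children), and the class name `Node`, the function `predict`, and the iris-style thresholds are hypothetical choices for illustration, not a library API.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Node:
    """One node x_i of a binary decision tree (hypothetical representation)."""
    feature: Optional[int] = None      # index of the feature tested at an internal node
    threshold: Optional[float] = None  # go left if x[feature] <= threshold
    left: Optional["Node"] = None      # child reached by the "<=" arc
    right: Optional["Node"] = None     # child reached by the ">" arc
    prediction: Any = None             # only meaningful on leaves

    def is_leaf(self) -> bool:
        return self.left is None and self.right is None


def predict(node: Node, x) -> Any:
    """Follow one arc per level, from the root down to a leaf."""
    while not node.is_leaf():
        node = node.left if x[node.feature] <= node.threshold else node.right
    return node.prediction


# Hand-built tree over (petal length, petal width); thresholds are illustrative.
tree = Node(
    feature=0, threshold=2.45,
    left=Node(prediction="setosa"),
    right=Node(
        feature=1, threshold=1.75,
        left=Node(prediction="versicolor"),
        right=Node(prediction="virginica"),
    ),
)

print(predict(tree, [1.4, 0.2]))  # setosa
print(predict(tree, [5.0, 2.0]))  # virginica
```

Each call to `predict` touches only one node per level, which is why prediction with a tree is fast even when the tree was expensive to grow.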
Today, we'll go over Decision Trees, one of the most classic machine learning models. We'll also learn how to log and visualize them using wandb.

✂️ Where to Split Your Decision Tree

A decision at a particular node is known as a split, since we are partitioning the data that reaches that node into multiple subsets. The root (top-most) node represents the entire training set, and each subsequent split partitions the data further to aid decision making. Where to split is usually decided by some rule, typically an impurity measure. A leaf is "pure" if all the training instances reaching it belong to a single class; an impure leaf contains instances from multiple classes.
Here, we go over some of the criteria typically used to choose splits in decision trees:

Information Gain

One such measure, from information theory, is entropy. It quantifies the randomness (impurity) of the class labels in a set of data points and is given by the following equation:
Entropy = -\sum_{i=1}^{n} p_i \log_2(p_i)
where n is the number of classes and p_i is the proportion of data points at the node that belong to class i. A pure node has an entropy of 0.
Our overall goal is to reduce entropy as we move down the tree. This is measured by another metric called Information Gain (IG), which, as the name suggests, tells us how informative a particular split is. It's calculated by subtracting the weighted average entropy of the subsets produced by splitting on an attribute from the entropy of the data before the split:
Gain(D, A) = Entropy(D) - Entropy(D, A)
Entropy(D, A) = \sum_{v \in values(A)} \frac{|D_v|}{|D|} Entropy(D_v)
where D_v is the subset of D for which attribute A takes the value v. In other words, IG tells us how much one would "gain" by branching on attribute A. Thus, the splitting attribute at any particular node is the one with the highest gain.
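The two formulas above translate directly into code. This is a minimal sketch assuming a categorical attribute A (one child per distinct value, as in ID3-style trees); the function names `entropy` and `information_gain` are hypothetical.

```python
import math
from collections import Counter


def entropy(labels):
    """Entropy = -sum_i p_i * log2(p_i), with p_i the fraction of labels in class i."""
    n = len(labels)
    return sum(-(c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(labels, attribute_values):
    """Gain(D, A) = Entropy(D) - sum_v |D_v|/|D| * Entropy(D_v)."""
    n = len(labels)
    # Group the labels by the value that attribute A takes on each example (D_v).
    subsets = {}
    for label, value in zip(labels, attribute_values):
        subsets.setdefault(value, []).append(label)
    weighted = sum(len(s) / n * entropy(s) for s in subsets.values())
    return entropy(labels) - weighted


labels = ["a", "a", "b", "b"]
print(entropy(labels))                                        # 1.0
print(information_gain(labels, ["x", "x", "y", "y"]))  # 1.0, perfectly separates the classes
print(information_gain(labels, ["x", "y", "x", "y"]))  # 0.0, tells us nothing
```

At each node, the tree would evaluate `information_gain` for every candidate attribute and split on the one with the largest value.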

Gini Index

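Gini \, Index = 1 - \sum_{i=1}^{n} p_i^2
where n is the number of classes and p_i is the proportion of data points at the node that belong to class i. Like entropy, the Gini index is 0 for a pure node, and the preferred split is the one that minimizes the size-weighted Gini index of the resulting child nodes.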
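A short sketch of the Gini index and of scoring a candidate split by the size-weighted Gini of its children. The helper names `gini_index` and `weighted_gini` are hypothetical, and a split is represented simply as a list of label lists, one per child.

```python
from collections import Counter


def gini_index(labels):
    """Gini = 1 - sum_i p_i**2, with p_i the fraction of labels in class i."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def weighted_gini(groups):
    """Size-weighted Gini of the child nodes produced by a split (lower is better)."""
    total = sum(len(g) for g in groups)
    return sum(len(g) / total * gini_index(g) for g in groups)


print(gini_index(["a", "a", "b", "b"]))              # 0.5, maximally mixed for 2 classes
print(weighted_gini([["a", "a"], ["b", "b"]]))       # 0.0, both children are pure
print(weighted_gini([["a", "b"], ["a", "b"]]))       # 0.5, the split did not help
```

Compared with entropy, the Gini index avoids computing logarithms, which is one reason it is a common default.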

Reduction in Variance

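Variance = \frac{\sum (X - \bar{X})^2}{n}
Reduction in variance is used for regression trees. Here X ranges over the target values of the n training instances at a node and \bar{X} is their mean. The preferred split is the one that most reduces the size-weighted variance of the child nodes relative to the variance of the parent.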
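The same idea carries over to regression with variance in place of an impurity measure. This sketch assumes a binary split into `left` and `right` target lists; the function names are hypothetical.

```python
def variance(values):
    """Variance = sum((x - mean)**2) / n, matching the formula above."""
    n = len(values)
    mean = sum(values) / n
    return sum((x - mean) ** 2 for x in values) / n


def variance_reduction(parent, left, right):
    """Parent variance minus the size-weighted variance of the two children."""
    n = len(parent)
    weighted = len(left) / n * variance(left) + len(right) / n * variance(right)
    return variance(parent) - weighted


targets = [1, 1, 5, 5]
print(variance(targets))                                  # 4.0
print(variance_reduction(targets, [1, 1], [5, 5]))        # 4.0, children are constant
print(variance_reduction(targets, [1, 5], [1, 5]))        # 0.0, no improvement
```

A regression tree picks the split with the largest reduction and typically predicts the mean target value at each leaf.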


Which measure to use depends on the problem. The Gini index and information gain are used for classification problems, whereas measures such as Mean Squared Error (MSE, equivalent to variance reduction) or Mean Absolute Error (MAE) are used for regression problems.

💇🏻 Pruning Your Decision Tree

Frequently, a node is not split further if the number of training instances reaching it is smaller than a certain percentage of the training set. The idea is that any decision based on too few instances causes variance and thus generalization error. Stopping tree construction early, before the tree is full, is called pre-pruning.
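scikit-learn supports this style of pre-pruning directly: when `min_samples_split` is a float, a node needs at least `ceil(min_samples_split * n_samples)` training instances to be split. The sketch below compares a fully grown tree with one pre-pruned at 5% of the training set; the dataset and the 5% figure are illustrative assumptions.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Fully grown tree: with the default min_samples_split=2, splitting continues
# until the leaves are pure (or cannot be split further).
full = DecisionTreeClassifier(random_state=0).fit(X, y)

# Pre-pruned tree: a node reaching fewer than 5% of the training instances
# is turned into a leaf instead of being split.
pre_pruned = DecisionTreeClassifier(min_samples_split=0.05, random_state=0).fit(X, y)

print("full:      ", full.get_n_leaves(), "leaves, depth", full.get_depth())
print("pre-pruned:", pre_pruned.get_n_leaves(), "leaves, depth", pre_pruned.get_depth())
```

Related stopping rules such as `max_depth`, `min_samples_leaf`, and `min_impurity_decrease` are also forms of pre-pruning.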
Another way to get simpler trees is post-pruning, which in practice works better than pre-pruning. Tree growing is greedy: at each step we make a decision and continue further on, never backtracking to try out an alternative. The only exception is post-pruning, where we go back to find and prune unnecessary subtrees.
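The greedy, never-backtracking growth procedure can be sketched from scratch in a few dozen lines. This is not scikit-learn's implementation; it assumes binary threshold splits on numeric features, Gini as the impurity measure, and a hypothetical `min_samples` parameter as a simple pre-pruning rule. All function names are illustrative.

```python
from collections import Counter


def gini(labels):
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def best_split(X, y):
    """Try every (feature, threshold) pair; return the one with the lowest
    size-weighted child Gini, or None if no split beats the parent's Gini."""
    best, best_score = None, gini(y)
    for f in range(len(X[0])):
        for t in sorted({row[f] for row in X}):
            left = [lab for row, lab in zip(X, y) if row[f] <= t]
            right = [lab for row, lab in zip(X, y) if row[f] > t]
            if not left or not right:
                continue
            score = (len(left) * gini(left) + len(right) * gini(right)) / len(y)
            if score < best_score:
                best, best_score = (f, t), score
    return best


def grow(X, y, min_samples=2):
    """Greedy top-down growth: commit to the best split at each node and
    recurse on the children, never revisiting an earlier decision."""
    majority = Counter(y).most_common(1)[0][0]
    if len(y) < min_samples or len(set(y)) == 1:
        return {"leaf": majority}
    split = best_split(X, y)
    if split is None:
        return {"leaf": majority}
    f, t = split
    left = [i for i, row in enumerate(X) if row[f] <= t]
    right = [i for i, row in enumerate(X) if row[f] > t]
    return {
        "feature": f,
        "threshold": t,
        "left": grow([X[i] for i in left], [y[i] for i in left], min_samples),
        "right": grow([X[i] for i in right], [y[i] for i in right], min_samples),
    }


def predict(tree, x):
    while "leaf" not in tree:
        tree = tree["left"] if x[tree["feature"]] <= tree["threshold"] else tree["right"]
    return tree["leaf"]


tree = grow([[1], [2], [3], [4]], ["a", "a", "b", "b"])
print(tree)  # splits on feature 0 at threshold 2 into two pure leaves
```

Because each split is chosen in isolation, a greedy learner can miss splits that only pay off in combination (an XOR-like pattern, for example), which is part of why pruning a fully grown tree afterwards can help.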

Decision Trees, Post-Pruning

In post-pruning, we grow the tree fully, until all the leaves are pure and we have no training error. We then find the subtrees that cause overfitting and prune them. From the initial labeled set, we set aside a pruning set that is unused during training. For each subtree, we replace it with a leaf node labeled using the training instances covered by the subtree (for example, their majority class for classification or their mean for regression). If the leaf node does not perform worse than the subtree on the pruning set, we prune the subtree and keep the leaf node, because the additional complexity of the subtree is not justified; otherwise, we keep the subtree.
Comparing the two, pre-pruning is faster, but post-pruning generally leads to more accurate trees.
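scikit-learn does not ship the reduced-error pruning procedure described above. It instead offers a different post-pruning technique, minimal cost-complexity pruning, through the `ccp_alpha` parameter and the `cost_complexity_pruning_path` method. The sketch below borrows the pruning-set idea: it grows a full tree, then keeps the smallest pruned version that does not perform worse on a held-out pruning set. The dataset and the 30% pruning-set size are illustrative assumptions.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
# Set aside a pruning set that is not used for growing the tree.
X_grow, X_prune, y_grow, y_prune = train_test_split(X, y, test_size=0.3, random_state=0)

# Grow the full tree with no depth or size limits.
full_tree = DecisionTreeClassifier(random_state=0).fit(X_grow, y_grow)

# Increasing alphas give progressively smaller cost-complexity-pruned trees.
path = full_tree.cost_complexity_pruning_path(X_grow, y_grow)

best_tree, best_score = full_tree, full_tree.score(X_prune, y_prune)
for alpha in path.ccp_alphas:
    tree = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_grow, y_grow)
    score = tree.score(X_prune, y_prune)
    # Keep the smaller tree whenever it does not perform worse on the pruning set.
    if score >= best_score and tree.get_n_leaves() <= best_tree.get_n_leaves():
        best_tree, best_score = tree, score

print(full_tree.get_n_leaves(), "->", best_tree.get_n_leaves(), "leaves")
```

Cost-complexity pruning penalizes each leaf by `ccp_alpha`, so it ranks subtrees differently from reduced-error pruning; only the use of a held-out pruning set to make the final choice mirrors the procedure described above.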

💪🏻 Strengths of Decision Trees

🤒 Weaknesses of Decision Trees

🧑🏼‍💻 Python Code Example for Decision Trees

In this section we'll go over some code snippets for creating and training Decision Trees and logging appropriate metrics and graphs to wandb. For more details, have a look at our docs.


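import wandb
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Illustrative hyperparameters; adjust for your own experiment.
config = dict(test_size=0.2, random_state=42, max_depth=3, min_samples_split=2,
              clf_criterion="gini", splitter="best",
              labels=["setosa", "versicolor", "virginica"])

run = wandb.init(project='...', entity='...', config=config)
config = run.config  # allows attribute access, e.g. config.max_depth

X, y = load_iris(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state)

clf = DecisionTreeClassifier(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.clf_criterion,
    splitter=config.splitter)
clf = clf.fit(x_train, y_train)
y_pred = clf.predict(x_test)

# Visualize Confusion Matrix
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)
run.finish()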


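import wandb
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

# Illustrative hyperparameters; "squared_error" requires scikit-learn >= 1.0.
config = dict(test_size=0.2, random_state=42, max_depth=3, min_samples_split=2,
              reg_criterion="squared_error", splitter="best")

run = wandb.init(project='...', entity='...', config=config)
config = run.config  # allows attribute access, e.g. config.max_depth

# A dataset with a continuous target (iris labels are classes, not quantities).
X, y = load_diabetes(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state)

reg = DecisionTreeRegressor(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.reg_criterion,
    splitter=config.splitter)
reg = reg.fit(x_train, y_train)

# All regression plots
wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')
run.finish()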


And that wraps up our short tutorial on Decision Trees. To see the full suite of wandb features, please check out this short 5-minute guide.
