
How to create and optimize decision trees

Comprehensive guide on Decision Trees, including theory, model building, evaluation, and advanced algorithms like Random Forest and XGBoost.





Understanding decision trees

Decision Trees are simple yet powerful machine learning tools that can assist us in making decisions or predictions, like figuring out if an email is spam or not, or predicting house prices. They work by asking a series of yes/no questions, similar to how we make decisions in everyday life, which makes them easy to understand and use. These models are great because they can work with different types of data, such as numbers or categories, and are used in many fields like finance and healthcare.
In this article, we'll take you through everything you need to know about Decision Trees, from the basics to using them in real projects with Weights & Biases, a popular tool for machine learning projects. We'll start with what Decision Trees are and how they work. Then, we'll get into the math that makes them work, like how they decide where to split the data. After that, we'll put what we've learned into practice by working on a project to predict who survived on the Titanic. Along the way, we'll learn how to make Decision Trees perform better and even look at some more advanced models related to Decision Trees, like Random Forest and XGBoost.

Theoretical Foundations of Decision Trees

Decision trees are a fundamental component of many machine learning algorithms, known for their simplicity and interpretability. They are used both for classification tasks, where the goal is to categorize inputs into two or more classes, and for regression tasks, where the goal is to predict a continuous value. Let's delve into the theoretical foundations of decision trees, including their structure, criteria for splitting, and the types of decision trees.

Structure of Decision Trees

A decision tree is composed of nodes, edges, and leaves, arranged in a tree-like structure.

Root Node: This is the topmost node, where the decision-making process starts.
Internal Nodes: These nodes test a particular attribute and branch based on the outcome of the test. Each internal node corresponds to a feature or attribute in the dataset.
Edges/Branches: These represent the outcome of a test and connect to the next node or leaf.
Leaf Nodes: These nodes represent a decision or final outcome. In classification trees, they represent the class label. In regression trees, they represent a continuous value.

The main purpose of a decision tree is to help us reach a final conclusion based on a set of given criteria.
For example, the decision tree below would help us decide whether or not to go running on a specific day:

Navigating through each node's decision branches guides us to the answer to our query, leading to a precise and informed conclusion.
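To make this structure concrete, below is a minimal sketch of such a tree written as plain Python conditionals. The specific questions (rain, energy level) are hypothetical stand-ins for whatever criteria the figure uses: the root node asks the first question, an internal node asks a follow-up, and each leaf returns a final decision.

def should_go_running(is_raining: bool, feels_energetic: bool) -> str:
    # Root node: test the weather first
    if is_raining:
        return "Don't run"               # leaf node
    # Internal node: test energy level
    if feels_energetic:
        return "Go for a run"            # leaf node
    return "Take a short walk instead"   # leaf node

print(should_go_running(is_raining=False, feels_energetic=True))  # -> Go for a run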

Criteria for Splitting

For more complex decision trees, like those in machine learning, creating an accurate tree is tricky. Instead of simple yes or no decisions, these trees include multiple factors and a variety of numbers. The big challenge is figuring out the best way to split each branch of the tree.
That being said, there are several strategies for splitting each node to build a decision tree. The two primary methods employed are:

  • Information Gain: This is used in decision tree algorithms such as ID3 and C4.5. Information gain measures the reduction in entropy (impurity) after the dataset is split on an attribute. The goal is to maximize information gain: the higher the information gain, the more homogeneous the resulting groups are.

Information Gain = Entropy(parent) − Σ (Weight of child × Entropy(child))

  • Gini Impurity: Used in the CART (Classification and Regression Trees) algorithm, Gini impurity measures how often a randomly chosen element of the dataset would be incorrectly labeled if it were labeled randomly according to the distribution of labels in the subset. The tree chooses the split with the lowest Gini impurity.

Gini Impurity = 1 − Σ (pᵢ)²

where pᵢ is the probability of an item being classified into a specific class. A short code sketch of both criteria is given below.
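To make these criteria concrete, here is a minimal sketch in Python (using NumPy) that computes entropy, information gain, and Gini impurity from raw class counts. The parent and child counts at the bottom are hypothetical, chosen so that the parent matches the 9-Yes/5-No dataset worked through in the next example.

import numpy as np

def entropy(counts):
    # Entropy of a node, given the class counts in it
    p = np.array(counts, dtype=float)
    p = p[p > 0] / p.sum()
    return -np.sum(p * np.log2(p))

def gini(counts):
    # Gini impurity of a node, given the class counts in it
    p = np.array(counts, dtype=float) / sum(counts)
    return 1.0 - np.sum(p ** 2)

def information_gain(parent_counts, children_counts):
    # Entropy(parent) minus the weighted entropy of the children
    n = sum(parent_counts)
    weighted_child_entropy = sum(
        sum(child) / n * entropy(child) for child in children_counts
    )
    return entropy(parent_counts) - weighted_child_entropy

# Hypothetical split: 9 "Yes" / 5 "No" examples divided into two children
parent = [9, 5]
children = [[6, 1], [3, 4]]
print(round(information_gain(parent, children), 3))
print(round(gini(parent), 3))   # ≈ 0.459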

Decision Tree Node Split Example

Consider a small example dataset of 14 instances, each described by two attributes, Age (Young, Middle, Old) and Income (High, Medium, Low), plus a target column Buys that is either Yes or No. We want to decide whether to split first on Age or on Income.



Step 1: Calculate Overall Impurity (Gini for Classification)

First, we calculate the Gini impurity of the entire dataset. Assuming "Yes" and "No" are the only two classes for Buys:

  • Total instances: 14
  • Yes instances: 9
  • No instances: 5

Gini_overall = 1 − ((9/14)² + (5/14)²) ≈ 0.459


Step 2: Calculate Impurity for Each Split

Next, we consider how to split the data based on Age and Income.

Split on Age
  • Young: 5 instances (2 Yes, 3 No)
  • Middle: 4 instances (4 Yes, 0 No)
  • Old: 5 instances (3 Yes, 2 No)

Weighted average Gini for Age split ≈ 0.343

Next, we repeat the same procedure for the Income split: calculate the Gini for each income group, then take their weighted average.

Split on Income
  • High: 4 instances (2 Yes, 2 No)
  • Medium: 6 instances (3 Yes, 3 No)
  • Low: 4 instances (2 Yes, 2 No)

Weighted average Gini for Income split ≈ 0.500



Step 3: Determine the Best Split

The best split is the one that results in the lowest weighted average Gini impurity, indicating the greatest reduction in uncertainty about the Buys outcome. Thus, the best split is on Age, which has the lowest weighted average Gini impurity at 0.343.

Example Calculation for Split on Age
Let's calculate the Gini impurity for the Young group as an example:
Gini_Young = 1 − ((2/5)² + (3/5)²) = 0.48


Then, you would calculate the Gini for Middle and Old similarly, followed by the weighted average Gini for the entire Age split.
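To double-check these numbers, here is a short sketch in plain Python that reproduces the impurities above from the class counts listed in Step 2 (no assumptions beyond those counts):

def gini(counts):
    # Gini impurity from a list of class counts
    total = sum(counts)
    return 1 - sum((c / total) ** 2 for c in counts)

def weighted_gini(groups):
    # Weighted average Gini over the child groups of a split
    n = sum(sum(g) for g in groups)
    return sum(sum(g) / n * gini(g) for g in groups)

# Class counts are (Yes, No), taken from the example above
print(round(gini([9, 5]), 3))                             # overall -> 0.459
print(round(weighted_gini([[2, 3], [4, 0], [3, 2]]), 3))  # Age     -> 0.343
print(round(weighted_gini([[2, 2], [3, 3], [2, 2]]), 3))  # Income  -> 0.5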



The Dataset Used



For the practical portion of this exploration, we will dive into the Titanic dataset, a rich and widely studied dataset within the machine learning community.
The Titanic dataset comprises 12 distinct columns, each offering insights into the passengers aboard the ill-fated Titanic voyage. Among these columns are critical attributes such as passenger class (Pclass), name, sex, age, siblings/spouses aboard (SibSp), parents/children aboard (Parch), ticket number, fare, cabin number, and port of embarkation. Most importantly, it includes a 'Survived' column, which indicates whether a passenger survived the disaster, with 1 representing survival and 0 representing non-survival.
This dataset serves as a foundational benchmark for classification problems in machine learning. By analyzing these attributes, we aim to build a decision tree model that can predict the survival outcome of passengers. Such models help us understand the key factors that contributed to the likelihood of survival during this tragic event.
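Before any modeling, it helps to take a quick look at the raw columns. Here is a minimal sketch, assuming the dataset has been downloaded as a CSV (the same path is used again in Step 3 below):

import pandas as pd

titanic = pd.read_csv('/kaggle/input/test-file/tested.csv')

print(titanic.shape)                       # number of rows and the 12 columns
print(titanic.dtypes)                      # column names and data types
print(titanic['Survived'].value_counts())  # 1 = survived, 0 = did not survive
print(titanic.isna().sum())                # missing values per column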

Preparing Your Dataset

Step 1: Importing Necessary Libraries

First, we will begin by importing some essential Python libraries for the purposes of data manipulation, visualization, machine learning modeling, and evaluation. This step lays the foundation for all subsequent data analysis and model building activities.
import seaborn as sns
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score


Step 2: Importing and Initializing W&B

Next, we import the wandb library and initialize a new run so that the metrics we log later can be tracked and compared in the Weights & Biases dashboard.

import wandb

# Initialize a new run
wandb.init(project='titanic_decision_tree', entity='Enter your W&B user here')

Step 3: Loading the Dataset

Next, we will load the Titanic dataset from a CSV file into a Pandas DataFrame. This step is crucial for making the dataset available in a structured form for analysis and modeling.
titanic = pd.read_csv('/kaggle/input/test-file/tested.csv')


Step 4: Preprocessing the Data - Handling Missing Values

We fill the missing values in the 'Fare' and 'Age' columns with their median values, and drop the 'Cabin' column, which is missing for most passengers. Handling missing values is essential for the smooth functioning of most machine learning algorithms.
# Fill missing numeric values with the column median
titanic['Fare'] = titanic['Fare'].fillna(titanic['Fare'].median())
titanic['Age'] = titanic['Age'].fillna(titanic['Age'].median())

# Drop the mostly-empty 'Cabin' column
titanic.drop('Cabin', axis=1, inplace=True)


Step 5: Dropping Unnecessary Columns

We remove columns that are unlikely to be informative for predicting the target variable, namely 'Name', 'Ticket', and 'PassengerId' ('Cabin' was already dropped in the previous step). This simplifies the model by eliminating uninformative features.
titanic.drop(['Name', 'Ticket', 'PassengerId'], axis=1, inplace=True)


Step 6: Encoding Categorical Variables

Moving on, we will convert the categorical variables ('Embarked', 'Sex') into a numeric format through one-hot encoding and label encoding, respectively. Machine learning algorithms require numerical input, so this step converts categorical data into a suitable format.
embarked_dummies = pd.get_dummies(titanic['Embarked'], prefix='Embarked')
titanic = pd.concat([titanic, embarked_dummies], axis=1)
titanic.drop('Embarked', axis=1, inplace=True)

le = LabelEncoder()
titanic['Sex'] = le.fit_transform(titanic['Sex'])


Step 7: Feature Engineering

To complete our feature engineering, we create a new feature 'family_size' by adding 'SibSp' and 'Parch' plus one for the passenger themselves. Generating new features from existing ones can uncover additional signal that improves model performance.
titanic['family_size'] = titanic['SibSp'] + titanic['Parch'] + 1


Step 8: Preparing Features and Target

We will then separate the dataset into features (X) and the target variable (y). This foundational step organizes the data for the subsequent training and testing phases.
X = titanic.drop('Survived', axis=1)
y = titanic['Survived']


Step 9: Shuffling the Labels

Next, we shuffle the target variable 'Survived' to create a randomized baseline. Because shuffling destroys any real relationship between the features and the labels, a model trained on shuffled labels should perform near chance; accuracy well above that would suggest memorization or data leakage rather than genuine pattern learning.
y_shuffled = shuffle(y, random_state=42)


Step 10: Splitting Dataset into Training and Testing Sets

Lastly, for our data processing steps, we will divide the dataset into training and testing sets to evaluate the model's performance on unseen data. This critical step helps in assessing the generalization capability of the model.
X_train_shuffled, X_test_shuffled, y_train_shuffled, y_test_shuffled = train_test_split(X, y_shuffled, test_size=0.2, random_state=42)


Building The Decision Tree Model

Step 11: Model Training and Evaluation on Shuffled Data

In this step, we train a Decision Tree Classifier with varying 'max_depth' on the shuffled labels and evaluate its accuracy on the held-out test set as well as with 5-fold cross-validation. Iteratively adjusting model complexity lets us study its effect on performance and spot overfitting or underfitting tendencies.
max_depth_range = range(1, 10)
shuffled_accuracies = []
cv_shuffled_accuracies = []

for depth in max_depth_range:
    clf = DecisionTreeClassifier(max_depth=depth, random_state=42)
    # Train and evaluate on shuffled labels
    clf.fit(X_train_shuffled, y_train_shuffled)
    shuffled_accuracy = clf.score(X_test_shuffled, y_test_shuffled)
    shuffled_accuracies.append(shuffled_accuracy)
    # Cross-validation on shuffled labels
    cv_shuffled_accuracy = cross_val_score(clf, X, y_shuffled, cv=5).mean()
    cv_shuffled_accuracies.append(cv_shuffled_accuracy)

Logging and Visualizing Model Performance

Step 12: Logging Our Data Into W&B

Inside the training loop from Step 11, we log the tree depth together with the two accuracy scores, so each configuration appears as a point on the run's W&B charts.

# Log metrics to Weights & Biases (this call goes inside the loop from Step 11)
wandb.log({'max_depth': depth,
           'training_accuracy': shuffled_accuracy,
           'cv_accuracy': cv_shuffled_accuracy})


The accuracy metrics are now logged to Weights & Biases, where the training and cross-validation curves can be inspected on the run's dashboard.


Step 13: Visualizing Model Performance

In order to evaluate our model, we will plot the model's training and cross-validation accuracies against tree depth to visually assess how model complexity influences accuracy. This visualization aids in identifying the optimal model complexity for balancing bias and variance.
plt.figure(figsize=(10, 6))
plt.plot(max_depth_range, shuffled_accuracies, label='Shuffled Training Accuracy')
plt.plot(max_depth_range, cv_shuffled_accuracies, label='Shuffled CV Accuracy')
plt.xlabel('Tree Depth')
plt.ylabel('Accuracy')
plt.title('Model Complexity vs. Accuracy on Shuffled Data')
plt.legend()
plt.show()


Evaluating Model Performance

Step 14: Reporting Final Model Accuracy

As our last step, we will print the final training and cross-validation accuracy of the model trained on shuffled data. This step summarizes the model's performance, providing insights into its effectiveness in predicting outcomes based on the shuffled labels.
final_shuffled_accuracy = shuffled_accuracies[-1]
print(f'Final Training Accuracy on Shuffled Data (Depth {max_depth_range.stop - 1}): {final_shuffled_accuracy:.2f}')

final_cv_shuffled_accuracy = cv_shuffled_accuracies[-1]
print(f'Final CV Accuracy on Shuffled Data (Depth {max_depth_range.stop - 1}): {final_cv_shuffled_accuracy:.2f}')

Output:
Final Training Accuracy on Shuffled Data (Depth 9): 0.60
Final CV Accuracy on Shuffled Data (Depth 9): 0.57


Advanced Decision Tree Algorithms

Random Forest

First on our list is Random Forests. Random Forest is an advanced decision tree algorithm that creates a 'forest' of multiple decision trees during the training process. Instead of relying on a single decision tree, Random Forest generates many trees, each trained on random subsets of the data and features.
When it's time to make a prediction, each tree in the forest votes, and the most common outcome (for classification tasks) or the average outcome (for regression tasks) becomes the final prediction. This approach improves accuracy and reduces the risk of overfitting by averaging the results of multiple trees, making Random Forest a powerful and versatile machine-learning model.
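To make the voting idea concrete, here is a minimal hand-rolled sketch: each tree is trained on a bootstrap sample of the rows (and considers a random subset of features at each split), and the forest predicts by majority vote over 0/1 labels. The helper names and the bootstrap logic are illustrative assumptions; in practice you would simply use scikit-learn's RandomForestClassifier, as in the replacement code below.

import numpy as np
from sklearn.tree import DecisionTreeClassifier

def fit_tiny_forest(X, y, n_trees=10, random_state=42):
    # Train n_trees decision trees, each on a bootstrap sample of the data
    rng = np.random.default_rng(random_state)
    trees = []
    for _ in range(n_trees):
        idx = rng.integers(0, len(X), size=len(X))          # sample rows with replacement
        tree = DecisionTreeClassifier(max_features='sqrt')  # random feature subset per split
        tree.fit(X.iloc[idx], y.iloc[idx])
        trees.append(tree)
    return trees

def forest_predict(trees, X):
    # Majority vote over the individual trees' predictions (binary 0/1 labels)
    votes = np.stack([t.predict(X) for t in trees])  # shape: (n_trees, n_samples)
    return (votes.mean(axis=0) >= 0.5).astype(int)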
Below are the replacements for steps 11, 13, and 14, in which we use the more complex Random Forest in place of the single decision tree.

Step 11: Model Training and Evaluation on Shuffled Data

from sklearn.ensemble import RandomForestClassifier

n_estimators_range = range(1, 3, 1)
rf_shuffled_accuracies = []
rf_cv_shuffled_accuracies = []

for n_estimators in n_estimators_range:
    rf_clf = RandomForestClassifier(n_estimators=n_estimators, random_state=42)
    # Train and evaluate on shuffled labels
    rf_clf.fit(X_train_shuffled, y_train_shuffled)
    shuffled_accuracy = rf_clf.score(X_test_shuffled, y_test_shuffled)
    rf_shuffled_accuracies.append(shuffled_accuracy)
    # Cross-validation on shuffled labels
    cv_shuffled_accuracy = cross_val_score(rf_clf, X, y_shuffled, cv=5).mean()
    rf_cv_shuffled_accuracies.append(cv_shuffled_accuracy)


Step 13: Visualizing Model Performance

plt.figure(figsize=(10, 6))
plt.plot(list(n_estimators_range), rf_shuffled_accuracies, label='Shuffled Training Accuracy - RF')
plt.plot(list(n_estimators_range), rf_cv_shuffled_accuracies, label='Shuffled CV Accuracy - RF')
plt.xlabel('Number of Estimators')
plt.ylabel('Accuracy')
plt.title('Random Forest Complexity vs. Accuracy on Shuffled Data')
plt.legend()
plt.show()


Step 14: Reporting Final Model Accuracy

final_rf_shuffled_accuracy = rf_shuffled_accuracies[-1]
final_rf_cv_shuffled_accuracy = rf_cv_shuffled_accuracies[-1]
print(f'Final Training Accuracy on Shuffled Data (RF, Estimators {list(n_estimators_range)[-1]}): {final_rf_shuffled_accuracy:.2f}')
print(f'Final CV Accuracy on Shuffled Data (RF, Estimators {list(n_estimators_range)[-1]}): {final_rf_cv_shuffled_accuracy:.2f}')

Output:
Final Training Accuracy on Shuffled Data (RF, Estimators 2): 0.67
Final CV Accuracy on Shuffled Data (RF, Estimators 2): 0.63


XGBoost

Moving on, we have the XGBoost algorithm. XGBoost stands for eXtreme Gradient Boosting and is an efficient and scalable implementation of gradient boosting. It works by sequentially adding predictors (decision trees), where each new tree corrects the errors made by the previously trained trees. The "gradient" part refers to the algorithm's use of gradient descent to minimize the loss when adding new models.
XGBoost is known for its speed and performance, particularly in structured or tabular data competitions. It handles large datasets well, offers built-in regularization to avoid overfitting, and has been the winning algorithm in many machine-learning competitions. XGBoost is versatile, being used for classification, regression, ranking, and user-defined prediction problems.
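To illustrate the idea of sequentially correcting errors, here is a minimal gradient-boosting sketch with squared-error loss, built from plain scikit-learn regression trees. It is a conceptual illustration only, not XGBoost itself (which adds regularization, second-order gradients, and many engineering optimizations), and for simplicity it treats the 0/1 target as a regression problem.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_boosted_trees(X, y, n_rounds=50, learning_rate=0.1, max_depth=3):
    # Gradient boosting with squared-error loss: each tree fits the current residuals
    prediction = np.full(len(y), float(np.mean(y)))  # start from the mean prediction
    trees = []
    for _ in range(n_rounds):
        residuals = y - prediction                   # negative gradient of squared error
        tree = DecisionTreeRegressor(max_depth=max_depth)
        tree.fit(X, residuals)
        prediction = prediction + learning_rate * tree.predict(X)
        trees.append(tree)
    return float(np.mean(y)), trees

def boosted_predict(base, trees, X, learning_rate=0.1):
    # Sum of the base prediction and every tree's (scaled) correction
    return base + learning_rate * sum(t.predict(X) for t in trees)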
Below are the replacements for steps 11, 13, and 14, in which we use the more complex XGBoost algorithm in place of the single decision tree.


Step 11: Model Training and Evaluation on Shuffled Data

import xgboost as xgb

max_depth_range = range(1, 10)
xgb_shuffled_accuracies = []
xgb_cv_shuffled_accuracies = []

for max_depth in max_depth_range:
    xgb_clf = xgb.XGBClassifier(max_depth=max_depth, n_estimators=300,
                                use_label_encoder=False, eval_metric='logloss',
                                random_state=42)
    # Train and evaluate on shuffled labels
    xgb_clf.fit(X_train_shuffled, y_train_shuffled)
    shuffled_accuracy = xgb_clf.score(X_test_shuffled, y_test_shuffled)
    xgb_shuffled_accuracies.append(shuffled_accuracy)
    # Cross-validation on shuffled labels
    cv_shuffled_accuracy = cross_val_score(xgb_clf, X, y_shuffled, cv=5).mean()
    xgb_cv_shuffled_accuracies.append(cv_shuffled_accuracy)


Step 13: Visualizing Model Performance

plt.figure(figsize=(10, 6))
plt.plot(list(max_depth_range), xgb_shuffled_accuracies, label='Shuffled Training Accuracy - XGB')
plt.plot(list(max_depth_range), xgb_cv_shuffled_accuracies, label='Shuffled CV Accuracy - XGB')
plt.xlabel('Max Depth')
plt.ylabel('Accuracy')
plt.title('XGBoost Complexity vs. Accuracy on Shuffled Data')
plt.legend()
plt.show()


Step 14: Reporting Final Model Accuracy

final_xgb_shuffled_accuracy = xgb_shuffled_accuracies[-1]
final_xgb_cv_shuffled_accuracy = xgb_cv_shuffled_accuracies[-1]
print(f'Final Training Accuracy on Shuffled Data (XGB, Max Depth {list(max_depth_range)[-1]}): {final_xgb_shuffled_accuracy:.2f}')
print(f'Final CV Accuracy on Shuffled Data (XGB, Max Depth {list(max_depth_range)[-1]}): {final_xgb_cv_shuffled_accuracy:.2f}')
Output:
Final Training Accuracy on Shuffled Data (XGB, Max Depth 9): 0.69
Final CV Accuracy on Shuffled Data (XGB, Max Depth 9): 0.64

As we can see, both advanced algorithms outperform the plain decision tree. While the difference here is modest, on a larger and more complex dataset the gap between these ensemble methods and a single decision tree becomes much clearer.

Summary

To sum up, decision trees offer a solid foundation for anyone looking to delve into machine learning, providing a blend of simplicity and powerful predictive capabilities.
Throughout this article, we've covered the basics of how these models work, their application in real-world datasets like the Titanic, and ways to enhance their performance with advanced techniques and algorithms. Embracing Decision Trees not only equips you with a versatile tool for data analysis but also paves the way for tackling more complex models, ensuring a broad understanding of machine learning's potential and applications.