
Time Series Classification on Weights & Biases With TSAI

In this article, we take a look at how to use the TSAI library with Weights & Biases to effectively tackle a multi-label time series classification problem.
Created on April 14|Last edited on January 27

This introduction will be a short one. Today, we're going to look at how you can use tsai to train a state-of-the-art classification model to detect abnormal electrocardiograms (ECG or EKG).




Let's get started:

Time Series

Time series data is simply data that is recorded over some period of time. That can mean anything from stock value over time to rainfall in California to business data like subscriber count or sales. In our case today, we'll be looking at EKGs.
If you aren't familiar, an ECG records the heart's electrical signals to monitor its health or detect arrhythmias. ECGs are recorded by a machine that captures 12 signals (leads) at each moment, at a sampling frequency defined by the machine's capabilities, typically a few hundred Hz. This means that each recording is a series of data points with 12 signals, where the number of points per signal equals the duration of the procedure times the sampling frequency.
If the ECG duration is 10 seconds and the sampling frequency is 500Hz, we will have 10 x 500 = 5,000 points per signal, or 10 x 500 x 12 = 60,000 data points in total.
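To make the arithmetic concrete, here is a minimal NumPy sketch. The 10-second, 500Hz recording is an illustrative assumption, and `ecg_value_count` is a hypothetical helper, not part of any library. It also shows the "image" view: a 12-lead recording is a 12-row array, and a single lead is a 1-row array.

```python
import numpy as np


def ecg_value_count(duration_s: float, fs_hz: float, n_leads: int = 12) -> int:
    """Total number of recorded values: samples per lead times number of leads."""
    samples_per_lead = int(round(duration_s * fs_hz))
    return samples_per_lead * n_leads


# Hypothetical 10-second recording sampled at 500Hz (illustrative values)
duration_s, fs_hz = 10, 500
samples_per_lead = int(duration_s * fs_hz)

# One row per lead, one column per time step: a 12-row "image" of the recording
ecg = np.zeros((12, samples_per_lead))
print(ecg.shape)
print(ecg_value_count(duration_s, fs_hz))

# A single lead is an image with a height equal to one
single_lead = ecg[0:1]
print(single_lead.shape)
```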
💡
A time series is just a vector, an image with a height equal to one.
For this example, we use 12-lead ECG signals from the 2020 PhysioNet Challenge; the 12 leads are derived from electrodes attached to a patient's body. Here's how that challenge was framed:
Objective: The goal of the 2020 Challenge is to identify clinical diagnoses from 12-lead ECG recordings. We ask participants to design and implement a working, open-source algorithm that can, based only on the clinical data provided, automatically identify the cardiac abnormality or abnormalities present in each 12-lead ECG recording.
This is a real-world, multi-label classification problem: physicians use these signals to detect whether a patient has one or more of 9 heart conditions.
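Multi-label means each recording gets a 0/1 flag per condition rather than a single class. The sketch below shows that encoding in plain NumPy. The nine class names are placeholders chosen for illustration (the dataset defines its own label set), and `multi_hot`/`decode` are hypothetical helpers.

```python
import numpy as np

# Hypothetical label vocabulary; the real dataset defines its own 9 condition names
CLASSES = ["AF", "I-AVB", "LBBB", "Normal", "PAC", "PVC", "RBBB", "STD", "STE"]


def multi_hot(diagnoses, classes=CLASSES):
    """Encode a list of diagnoses as a 0/1 vector with one slot per class."""
    vec = np.zeros(len(classes), dtype=np.float32)
    for name in diagnoses:
        vec[classes.index(name)] = 1.0
    return vec


def decode(vec, classes=CLASSES, threshold=0.5):
    """Turn a 0/1 vector (or per-class probabilities) back into class names."""
    return [c for c, v in zip(classes, vec) if v > threshold]


# One recording can carry several diagnoses at once
y = multi_hot(["AF", "RBBB"])
print(y)
print(decode(y))
```

Because every slot is an independent yes/no decision, this encoding pairs naturally with a per-label sigmoid and binary cross-entropy later in the pipeline.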

Visualizing the Dataset 🧐

We can visualize the data using a wandb.Table in which we logged the time-series input (rendered as a plotly figure), the metadata of the procedure, and the corresponding diagnosis.



Data Analysis

Using wandb.Table, we can explore the data to understand the distribution of diseases. Here we grouped by:
✅ Sex: the distribution of diagnoses differs between male and female patients.
✅ Diagnosis: how often each condition appears in the dataset.
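If you want to reproduce this kind of grouping outside the W&B UI, a plain-Python count is enough. This is a minimal sketch over made-up metadata rows (the records below are hypothetical, not taken from the dataset); note that a multi-label recording contributes once to each of its diagnoses.

```python
from collections import Counter

# Hypothetical metadata rows: (sex, list of diagnoses) for each recording
records = [
    ("Male", ["AF"]),
    ("Female", ["Normal"]),
    ("Male", ["AF", "RBBB"]),
    ("Female", ["RBBB"]),
    ("Male", ["Normal"]),
]

# Diagnosis counts per sex; a multi-label recording counts once per label
by_sex = Counter((sex, dx) for sex, dxs in records for dx in dxs)

# Overall diagnosis distribution
overall = Counter(dx for _, dxs in records for dx in dxs)

for (sex, dx), n in sorted(by_sex.items()):
    print(f"{sex:<6} {dx:<6} {n}")
print(overall.most_common())
```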


As you can see, the time series have different lengths. This is probably due to differences in how long each procedure took, the machines used to record the ECG, and the practitioners. To train a model on this data, we first need to preprocess it into equal-length samples.
💡
There are multiple ways of homogenizing time series data, and resampling is just one of them. Resampling every recording to the same length distorts important features such as the heart rate, because recordings of different durations get compressed or stretched into the same number of samples. The purpose of this article is to demonstrate the use of neural networks with time series data, so beware of using this methodology for any medical analyses.
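The heart-rate distortion is easy to quantify. In this sketch (the 75 bpm patient and the 10/20/60-second durations are illustrative assumptions; the helpers are hypothetical), every recording is squeezed into the same 5,000-sample window, so longer recordings pack more beats into it and the spacing between beats, measured in samples, shrinks.

```python
TARGET_LEN = 5000  # fixed length every recording is resampled to


def beats_in_window(true_bpm: float, duration_s: float) -> float:
    """Heartbeats in the recording; after resampling they all share one fixed window."""
    return true_bpm * duration_s / 60.0


def samples_per_beat(true_bpm: float, duration_s: float, target_len: int = TARGET_LEN) -> float:
    """Average spacing between beats (in samples) after resampling to target_len."""
    return target_len / beats_in_window(true_bpm, duration_s)


# Same patient (75 bpm), different recording durations
for duration_s in (10, 20, 60):
    print(f"{duration_s:>2} s: {beats_in_window(75, duration_s):5.1f} beats, "
          f"{samples_per_beat(75, duration_s):6.1f} samples per beat")
```

A model that only sees the 5,000 resampled points cannot tell a 60-second recording at 75 bpm from a much faster rhythm recorded for 10 seconds, which is why this preprocessing is fine for a deep learning demo but not for clinical use.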


Preprocessing

As seen in the previous section, the ECG signal lengths vary from 3,000 to 72,000 samples, with the most common length being 5,000 samples.
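Finding the range and the most common length is a one-liner with `np.unique`. The lengths below are hypothetical stand-ins; in practice you would collect `signal.shape[1]` for every loaded recording.

```python
import numpy as np

# Hypothetical recording lengths (samples per lead); in practice, collect
# signal.shape[1] for every loaded .mat file
lengths = np.array([5000, 5000, 3000, 7500, 5000, 72000, 10000, 5000])

values, counts = np.unique(lengths, return_counts=True)
most_common = values[np.argmax(counts)]  # mode of the length distribution

print("min:", lengths.min(), "max:", lengths.max(), "mode:", most_common)
```

Picking the mode as the target length means the largest group of recordings is left unchanged by resampling.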


Since simply padding signals with any numerical values would add a lot of unwanted noise to the signal, we decided to resample the signals instead. This was done using the scipy library:
```python
import glob

import numpy as np
import wandb
from scipy.io import loadmat
from scipy.signal import resample

TARGET_LEN = 5000  # most common recording length in the dataset

# Assumption: the challenge's .mat files were extracted to a local "data/" folder
files = sorted(glob.glob("data/*.mat"))

resampled_signals = []
for file_path in files:
    # Since this is a 12-lead ECG, this is a 12 x (length of signal) matrix
    signal = loadmat(file_path)["val"]
    # Resample every lead to TARGET_LEN points along the time axis
    resampled_signals.append(resample(signal, TARGET_LEN, axis=1))
resampled_signals = np.array(resampled_signals)  # (n_recordings, 12, TARGET_LEN)

np.save("downsampled_signals.npy", resampled_signals)
artifact = wandb.Artifact(name="preprocessed_dataset", type="dataset")
artifact.add_file("downsampled_signals.npy")

wandb.init(project="PhysioNet_Challenge")
wandb.log_artifact(artifact)
wandb.finish()
```
The artifact generated by this code is available here.
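`scipy.signal.resample` works in the frequency domain: it treats the signal as periodic, keeps (or zero-pads) its Fourier coefficients, and inverts to the new length. The NumPy sketch below illustrates that idea only; `fourier_resample` is a hypothetical, simplified helper and does not reproduce SciPy's special handling of the Nyquist bin, so use SciPy for real preprocessing.

```python
import numpy as np


def fourier_resample(x: np.ndarray, num: int, axis: int = -1) -> np.ndarray:
    """Simplified Fourier-domain resampling of `x` to `num` points along `axis`.

    Keeps the lowest-frequency rFFT bins (truncating when downsampling,
    zero-padding when upsampling) and inverts to `num` points. Like
    scipy.signal.resample it assumes the signal is periodic, but it does not
    replicate SciPy's treatment of the Nyquist bin.
    """
    n = x.shape[axis]
    spectrum = np.fft.rfft(x, axis=axis)
    n_bins = num // 2 + 1
    keep = min(n_bins, spectrum.shape[axis])

    # Allocate the output spectrum and copy the lowest `keep` frequency bins
    new_shape = list(spectrum.shape)
    new_shape[axis] = n_bins
    new_spectrum = np.zeros(new_shape, dtype=spectrum.dtype)
    idx = [slice(None)] * x.ndim
    idx[axis] = slice(0, keep)
    new_spectrum[tuple(idx)] = spectrum[tuple(idx)]

    # irfft divides by the output length, so rescale to preserve amplitudes
    return np.fft.irfft(new_spectrum, n=num, axis=axis) * (num / n)


# Usage: a hypothetical 12-lead recording of 7,500 samples brought to 5,000
ecg = np.random.default_rng(0).normal(size=(12, 7500))
print(fourier_resample(ecg, 5000, axis=1).shape)
```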

TSAI: The Time Series Deep Learning Library 📈

Time series problems have not traditionally been tackled with deep learning techniques. Enter TSAI:
tsai is an open-source deep learning package built on top of Pytorch & fastai focused on state-of-the-art techniques for time series tasks like classification, regression, forecasting, imputation...
We can create a baseline quickly using tsai. First we get the preprocessed dataset from the wandb.Artifact.
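```python
from pathlib import Path

import wandb

api = wandb.Api()
artifact = api.artifact("timeseriesbois/PhysioNet_Challenge/preprocessed_dataset:latest")
dataset_path = Path(artifact.download())  # download() returns the local directory
```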
The artifact contains the preprocessed dataset: the resampled signals as a NumPy array and the labels as a CSV file, which we load into a pandas DataFrame where each row corresponds to one ECG sample.
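```python
import numpy as np
import pandas as pd
from fastcore.basics import range_of
from sklearn.model_selection import train_test_split

X = np.load(dataset_path / "downsampled_signals.npy")       # (n_samples, 12, 5000)
df = pd.read_csv(dataset_path / "labels.csv", index_col=0)  # one 0/1 column per condition
y = df.values

# sklearn friendly splitter ❤️: returns [train_indices, valid_indices]
split = train_test_split(range_of(X), test_size=1000)
```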
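Under the hood, this split is just a shuffled list of indices with a held-out slice. The sketch below reproduces that shape of output with NumPy only; `random_split` is a hypothetical helper and the fixed seed is our assumption (the original call does not set one).

```python
import numpy as np


def random_split(n_samples: int, test_size: int, seed: int = 42):
    """Shuffle indices and hold out `test_size` of them for validation.

    Returns [train_indices, valid_indices] as sorted lists of ints, the same
    layout that train_test_split(range_of(X), test_size=...) produces.
    """
    if not 0 < test_size < n_samples:
        raise ValueError("test_size must be between 1 and n_samples - 1")
    idx = np.random.default_rng(seed).permutation(n_samples)
    valid = sorted(idx[:test_size].tolist())
    train = sorted(idx[test_size:].tolist())
    return [train, valid]


# Usage with a hypothetical dataset of 5,000 recordings
train_idx, valid_idx = random_split(5000, 1000)
print(len(train_idx), len(valid_idx))
```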
Here we start using tsai. We build our data loaders with the get_ts_dls function. This function is very sklearn-inspired and expects a pair (X, y), along with a corresponding pair of transforms, tfms. We don't need to do anything to X, so we pass None. For y, we pass a transform that flags it as multi-label classification targets, so tsai knows how y needs to be treated.
We can also pass batch transforms (batch_tfms); these are applied once the batch is formed, that is, after the individual samples are collated and sent to the GPU. In our case, we pass TSStandardize, which standardizes the data by subtracting the mean and dividing by the standard deviation. We also pass bs (the batch size) and the split indices to generate a train/validation split.
You can also precompute the mean and standard deviation and pass them to the TSStandardize transform.
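```python
from fastai.data.transforms import EncodedMultiCategorize
from tsai.all import *

BS = 64  # batch size; assumed value, tune it for your GPU

# y is one-hot (multi-hot) encoded, so we use EncodedMultiCategorize with the condition names as vocab
tfms = [None, EncodedMultiCategorize(vocab=list(df.columns))]
batch_tfms = TSStandardize()
dls = get_ts_dls(X, y,
                 splits=split,
                 tfms=tfms,
                 batch_tfms=batch_tfms,
                 bs=BS)
```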
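Precomputing the statistics can be sketched in NumPy: compute one mean and one standard deviation per lead over the training samples only, then apply them everywhere. This is a per-channel choice we picked for illustration; we have not verified which array shape TSStandardize expects for its mean/std arguments (that depends on its by_var/by_sample settings), so check the tsai documentation before passing these in. The data and helpers below are hypothetical.

```python
import numpy as np


def channel_stats(X_train: np.ndarray):
    """Per-lead mean and std over all training samples and time steps.

    X_train has shape (n_samples, n_leads, seq_len); the results keep
    singleton dims, shape (1, n_leads, 1), so they broadcast over batches.
    """
    mean = X_train.mean(axis=(0, 2), keepdims=True)
    std = X_train.std(axis=(0, 2), keepdims=True)
    return mean, std


def standardize(X: np.ndarray, mean: np.ndarray, std: np.ndarray, eps: float = 1e-8):
    """Subtract the mean, then divide by the standard deviation."""
    return (X - mean) / (std + eps)


# Hypothetical data: 64 recordings, 12 leads, 500 time steps
rng = np.random.default_rng(0)
X = rng.normal(loc=3.0, scale=2.0, size=(64, 12, 500))
train_idx = np.arange(48)  # statistics come from the training split only

mean, std = channel_stats(X[train_idx])
X_std = standardize(X, mean, std)
print(mean.shape, std.shape)
```

Using training-only statistics keeps information from the validation set out of preprocessing.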
TSAI is built on top of fastai (😍), so to enable W&B logging we only need to start a run with wandb.init and pass fastai's WandbCallback to ts_learner, and we are good to go.
We use binary cross-entropy as the loss function, since each condition is an independent yes/no decision (read this article to understand more about BCE). We also pass a list of metrics to assess model performance: accuracy_multi, F1 (F1ScoreMulti), and recall (RecallMulti). For the PhysioNet Challenge, the F1 metric was used to score submissions. Finally, we call to_fp16 on the learner to enable mixed-precision training.
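```python
import wandb
from fastai.callback.wandb import WandbCallback
from fastai.losses import BCEWithLogitsLossFlat
from fastai.metrics import F1ScoreMulti, RecallMulti, accuracy_multi
from tsai.all import *

wandb.init(project="PhysioNet_Challenge")  # WandbCallback requires an active run

wandb_cb = WandbCallback()
learn = ts_learner(dls,
                   InceptionTimePlus,
                   loss_func=BCEWithLogitsLossFlat(),  # binary cross-entropy on logits, one term per label
                   metrics=[accuracy_multi, F1ScoreMulti(), RecallMulti()],
                   cbs=[wandb_cb]).to_fp16()
```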
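To see what the loss and the first metric compute, here is a NumPy sketch of binary cross-entropy on logits and of a thresholded multi-label accuracy. It mirrors our understanding of fastai's BCEWithLogitsLossFlat and accuracy_multi (sigmoid, then a 0.5 threshold) but is an independent illustration, not their implementation; the `_np` helper names are hypothetical.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def bce_with_logits_np(logits, targets):
    """Mean binary cross-entropy over every (sample, label) pair.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))].
    """
    z = np.asarray(logits, dtype=float)
    y = np.asarray(targets, dtype=float)
    return float(np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z)))))


def accuracy_multi_np(logits, targets, thresh=0.5):
    """Fraction of individual label decisions that are correct."""
    preds = sigmoid(np.asarray(logits, dtype=float)) > thresh
    return float(np.mean(preds == np.asarray(targets).astype(bool)))


# Usage: 2 recordings x 2 conditions, raw model outputs (logits)
logits = np.array([[2.0, -2.0], [2.0, 2.0]])
targets = np.array([[1, 0], [0, 1]])
print(bce_with_logits_np(logits, targets), accuracy_multi_np(logits, targets))
```

Each label contributes its own loss term, which is what lets a single recording be positive for several conditions at once.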

Evaluating the Model

We will use the same metrics as the competition, accuracy and F1, and also add balanced accuracy, since this dataset is highly imbalanced.
We observe that the InceptionTime model performs best; a version trained with a higher learning rate achieved the best results overall.
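For multi-label predictions, both F1 and balanced accuracy are built from per-label confusion counts. The NumPy sketch below shows micro-averaged F1 (one common variant; fastai's F1ScoreMulti exposes an `average` argument that selects how labels are combined) and a balanced accuracy defined, as our assumption, as the mean over labels of (sensitivity + specificity) / 2. The helpers and the toy predictions are hypothetical.

```python
import numpy as np


def confusion_counts(preds, targets):
    """Per-label TP, FP, FN, TN for 0/1 arrays of shape (n_samples, n_labels)."""
    p = np.asarray(preds).astype(bool)
    t = np.asarray(targets).astype(bool)
    tp = (p & t).sum(axis=0)
    fp = (p & ~t).sum(axis=0)
    fn = (~p & t).sum(axis=0)
    tn = (~p & ~t).sum(axis=0)
    return tp, fp, fn, tn


def f1_micro(preds, targets):
    """F1 computed from TP/FP/FN pooled over all labels."""
    tp, fp, fn, _ = confusion_counts(preds, targets)
    tp, fp, fn = tp.sum(), fp.sum(), fn.sum()
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


def balanced_accuracy_multi(preds, targets):
    """Mean over labels of (sensitivity + specificity) / 2."""
    tp, fp, fn, tn = confusion_counts(preds, targets)
    sensitivity = tp / np.maximum(tp + fn, 1)
    specificity = tn / np.maximum(tn + fp, 1)
    return float(np.mean((sensitivity + specificity) / 2))


# Usage: thresholded predictions vs. ground truth for 3 recordings x 2 conditions
preds = np.array([[1, 0], [1, 1], [0, 0]])
targets = np.array([[1, 0], [0, 1], [0, 1]])
print(f1_micro(preds, targets), balanced_accuracy_multi(preds, targets))
```

Balanced accuracy weights the rare positive cases as much as the common negatives, which is why it is more informative than plain accuracy on an imbalanced dataset.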



Looking at Individual Model Predictions

We can look at the model's individual predictions in a wandb.Table. Here, we log the input signal (a plotly graph) and the predictions alongside the ground truth (label). We chose a fixed number of samples to log (N=36).




Code to create this table and log it to your workspace

```python
import random

import numpy as np
import plotly.graph_objects as go
import wandb
from tqdm.auto import tqdm

# Gather model inputs, predicted probabilities and targets for the validation set
inps, preds, tars = learn.get_preds(with_input=True)

def make_plot(signal):
    "Plots the 12-lead signal with plotly; only channel 1 is visible by default"
    signal = np.asarray(signal)
    figure = go.Figure()
    x = list(range(signal.shape[1]))
    figure.add_trace(go.Scatter(x=x, y=signal[0], name="Channel 1"))

    for i in range(1, 12):
        figure.add_trace(go.Scatter(x=x, y=signal[i], name=f"Channel {i + 1}", visible="legendonly"))
    return figure

def map_disease(t, th=0.5):
    "Maps a multi-hot vector (or probabilities) to the names of the active diseases"
    diseases = list(df.columns)
    return [diseases[i] for i in np.where(np.asarray(t) > th)[0]]

def log_preds(N=36, seed=42):
    "Log `N` model predictions from the validation dataset"
    columns = ["Input", "Label", "Prediction"]
    table = wandb.Table(columns=columns)

    rng = random.Random(seed)
    idxs = rng.sample(range(len(preds)), k=N)

    for inp, p, t in tqdm(zip(inps[idxs], preds[idxs], tars[idxs]), total=N):
        figure = make_plot(inp)

        row = [
            wandb.Html(figure.to_html()),
            map_disease(t),  # ground-truth label
            map_disease(p),  # thresholded prediction
        ]
        table.add_data(*row)
    with wandb.init(entity="timeseriesbois", project="PhysioNet_Challenge", job_type="log_predictions"):
        wandb.log({"Predictions_Sample": table})

log_preds()
```


Conclusion

Here, we've created a full pipeline to train a state-of-the-art deep learning model for time series classification, going through all the steps necessary to ingest and preprocess the data and then train a neural network using the powerful tsai library. If you have any problems you're working on but aren't sure where to start, let us know in the comments. We're always keen to help!