
Fine-Tuning Whisper ASR Models

This short article is a work log and guide for fine-tuning Whisper automatic speech recognition (ASR) models, using a W&B progress callback along the way.
Automatic speech recognition (ASR) is the task of turning a speech signal into the corresponding text, known as the transcript. In recent years, deep learning has been widely adopted for this task thanks to the availability of large datasets containing speech and the corresponding transcripts.
However, most deep learning approaches are data-hungry and can take a very long time to train. A recent trend in this area of research has been self-supervised learning. For instance, the Wav2Vec family of models uses a task similar to masked language modelling to pre-train a network on unlabelled speech before fine-tuning it on ASR. This lets the network learn contextual speech embeddings that can then be fine-tuned on a parallel corpus of speech and transcripts.
Whisper is an ASR model released by OpenAI. Unlike other recent speech recognition models, it is trained entirely with supervised learning on weakly labelled data: the researchers gathered 680k hours of multilingual speech and transcription data spanning 96 languages. The model is an encoder-decoder Transformer trained with multi-task learning on tasks that include transcription, translation, and timestamp prediction. The following image gives an overview of the model's architecture.
Architecture of the Whisper model. Source: https://openai.com/blog/whisper/
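To make this concrete, here's a minimal sketch (not part of the sprint code) of zero-shot transcription with a pretrained Whisper checkpoint through the transformers pipeline; the audio path is a placeholder:

from transformers import pipeline

# load the smallest pretrained Whisper checkpoint for a quick test
asr = pipeline("automatic-speech-recognition", model="openai/whisper-tiny")

# transcribe a local audio file (placeholder path)
print(asr("path/to/audio.mp3")["text"])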

Background

The awesome folks at Hugging Face ported the Whisper model into the transformers library within a few weeks of the model's release. However, the model is not without its limitations. The authors of the Whisper paper detail the following ones (AKA hackathon ideas):
  • Inaccurate timestamp predictions
  • Hallucinations
  • Low performance on low-resource languages
  • No speaker recognition
  • No real-time transcription
This leads us to the Whisper fine-tuning event run by Hugging Face and Lambda Labs. The event is organized as a community sprint from Dec 5 to Dec 19, with Hugging Face providing the models, starter code, and community support via Discord, and Lambda Labs providing the necessary compute (roughly 100 hours of A100 GPU time).
The main components of the sprint include:
  • OpenAI's state-of-the-art Whisper model
  • Public datasets like Common Voice 11, VoxPopuli, CoVoST2 and more
  • Real-world audio for evaluation
With the key outcomes being:
  • Fine-tuned Whisper checkpoint (e.g. Whisper-large)
  • Evaluation script for your fine-tuned checkpoint
  • Hugging Face space to demo your fine-tuned model

Fine-tuning for Tamil

As a machine learning practitioner, I have been in awe of the rapid increase in the performance of ASR systems. However, it's no secret that most recent ASR models target languages like English, where large amounts of annotated speech data are readily available. This has also left me quite unhappy, since the performance of these models on my native tongue, Tamil, has been quite poor so far. I decided to participate in the Hugging Face sprint to change this. The plan: collect ~1000 hours of annotated speech data in Tamil from various sources and fine-tune the Whisper models on the ASR task.

Data collection

I used the following data sources:
  • Common Voice 11: ~392 hours of speech (58% validated) from 811 speakers
  • Fleurs: ~10 hours of validated speech from 76 speakers
  • OpenSLR65: ~7 hours of validated speech from 50 speakers
  • OpenSLR127: ~150 hours of read speech collected from 531 speakers
  • UCLA Corpus: ~1000 hours of speech from various sources (news, radio, contributions, etc.)
Common Voice 11, Fleurs, and OpenSLR65 are available in Hugging Face datasets. To bridge the gap, I also downloaded and ported OpenSLR127 and the UCLA Corpus into the datasets hub. While the OpenSLR127 corpus was easily downloadable from the OpenSLR website, downloading the UCLA corpus was a huge challenge since I couldn't find any resource to download it from the official website. Luckily, I found this awesome repo that contained links to various ASR datasets for Indic languages. I used it as a starting point to download the Tamil ASR data, then wrote code to download, clean, and preprocess the dataset. The code uses multi-threading for IO-bound operations (downloading and reading files) and multi-processing for compute-bound operations (verifying and converting audio).
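The full script is long, so here is only a minimal sketch of that threading/multiprocessing split; the URL list, directory names, and the ffmpeg-based conversion step are placeholders assumed for illustration:

import subprocess
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from pathlib import Path
from typing import Optional

import requests

AUDIO_URLS = ["https://example.com/audio_0001.mp3"]  # placeholder list of file URLs
RAW_DIR, CLEAN_DIR = Path("raw"), Path("clean")


def download(url: str) -> Path:
    # IO-bound: fetch one file and write it to disk
    out = RAW_DIR / url.split("/")[-1]
    out.write_bytes(requests.get(url, timeout=60).content)
    return out


def verify_and_convert(path: Path) -> Optional[Path]:
    # compute-bound: re-encode to 16 kHz mono; ffmpeg fails on corrupt files
    out = CLEAN_DIR / path.name
    cmd = ["ffmpeg", "-y", "-i", str(path), "-ac", "1", "-ar", "16000", str(out)]
    ok = subprocess.run(cmd, capture_output=True).returncode == 0
    return out if ok else None


RAW_DIR.mkdir(exist_ok=True)
CLEAN_DIR.mkdir(exist_ok=True)
with ThreadPoolExecutor(max_workers=16) as pool:  # threads for the IO-bound downloads
    raw_files = list(pool.map(download, AUDIO_URLS))
with ProcessPoolExecutor(max_workers=8) as pool:  # processes for verification/conversion
    clean_files = [f for f in pool.map(verify_and_convert, raw_files) if f]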



The script took nearly half a day to run, and I ended up collecting roughly half a million audio files with transcripts.

Dataset Creation

With the main data downloaded, I needed a way to create a Hugging Face dataset that could easily be combined with the other hub datasets into a monolingual corpus large enough to fine-tune the model on Tamil. The datasets library provides three options for reading and sharing custom audio datasets:
  1. Create an audio dataset from local files. This is an easy way that requires only a few steps in Python.
  2. Create an audio dataset repository with the AudioFolder builder. This is a no-code solution for quickly creating a small dataset to experiment with.
  3. Create an audio dataset by writing a loading script. This requires more effort and coding, but provides greater flexibility over how a dataset is defined, downloaded, and generated.
Since I was working on multiple systems (Colab, Kaggle, AWS, Lambda Labs), I knew the created datasets needed to be accessible across all of them.

Gotchas and challenges

  • The 1st method is relatively easy. All we need to do is load the dataset from a Python object such as a dictionary or a pandas DataFrame with two columns (path, sentence), create the dataset using the class methods Dataset.from_dict() or Dataset.from_pandas(), and cast the audio column to the Audio datatype. For example:
from datasets import Audio, Dataset
import pandas as pd

audio_dataset = Dataset.from_dict(
    {"audio": ["path/to/audio_1", "path/to/audio_2", ..., "path/to/audio_n"]})
audio_dataset = audio_dataset.cast_column("audio", Audio())

## or alternatively
df = pd.DataFrame({"audio": ["path/to/audio_1", "path/to/audio_2", ..., "path/to/audio_n"]})
audio_dataset = Dataset.from_pandas(df)
audio_dataset = audio_dataset.cast_column("audio", Audio())
And then call `audio_dataset.push_to_hub()`. However, this has one disadvantage: the paths are tied to the system the dataset was uploaded from, so to reuse the dataset on another system you need the audio files present there as well. Obviously, this is not an option when you are dealing with roughly half a million files.
  • The next option that looked promising was the AudioFolder builder. To use it, we need to organize the audio files in the following directory structure:
metadata.csv
data/first_audio_file.mp3
data/second_audio_file.mp3
data/third_audio_file.mp3
and create a metadata.csv or metadata.jsonl file with two specific column names:
file_name,transcription
data/first_audio_file.mp3,<transcript1 ...>
data/second_audio_file.mp3,<transcript2 ...>
data/third_audio_file.mp3,<transcript3 ...>
And then call the following to load the audio dataset:
from datasets import load_dataset
dataset = load_dataset("audiofolder", data_dir="/path/to/data",)
When I initially tried this, only one record out of my half a million was loaded. I dug a little deeper into the dataset-loading code and found that the audiofolder configuration analyses file names for split patterns such as "train", "dev", etc. I had exactly one file whose name looked something like <restoffilename>dev.mp3, and this was the only file being loaded into my dataset. This gave me the idea to change the dataset structure to the following:
metadata.csv
data/train/first_audio_file.mp3
data/train/second_audio_file.mp3
data/train/third_audio_file.mp3
I also changed the loading command to the following
from datasets import load_dataset
dataset = load_dataset("audiofolder", data_dir="/path/to/data", drop_labels=True)
This finally loaded the dataset with all 500k records, and I pushed it to the hub using the `dataset.push_to_hub()` command.
Now that it was time to reuse the uploaded dataset in Colab, I ran `load_dataset()` with my hub dataset path. Alas, the loading failed with this related issue. The core takeaway: mp3 files cannot be loaded by the librosa or soundfile libraries on Linux systems due to licensing constraints, and libsndfile did not support mp3 at the time. This meant I either needed to convert the mp3 files to wav (supported by soundfile) or create a data loading script. I chose the latter, since converting half a million mp3 files to wav was not an option given how much more disk space wav files take.
  • I finally created a data loading script following this tutorial on Hugging Face. While this took some coding and file organization, I found it to be the best option. Here's the rough shape of the data loading script I created for the UCLA dataset.
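The script itself is fairly long, so the snippet below is only a minimal sketch of what such a loading script looks like; the download URL, feature names, and archive layout are assumptions for illustration rather than the actual UCLA script:

import csv
import os

import datasets

_DATA_URL = "https://example.com/ucla_tamil.tar.gz"  # placeholder URL


class UclaTamilASR(datasets.GeneratorBasedBuilder):
    """Sketch of a loading script for an audio + transcript corpus."""

    def _info(self):
        return datasets.DatasetInfo(
            features=datasets.Features(
                {
                    "audio": datasets.Audio(sampling_rate=16_000),
                    "sentence": datasets.Value("string"),
                }
            ),
        )

    def _split_generators(self, dl_manager):
        # download and extract the archive, then point the generator at it
        data_dir = dl_manager.download_and_extract(_DATA_URL)
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"data_dir": data_dir},
            )
        ]

    def _generate_examples(self, data_dir):
        # metadata.csv maps each audio file name to its transcript
        meta_path = os.path.join(data_dir, "metadata.csv")
        with open(meta_path, encoding="utf-8") as f:
            for idx, row in enumerate(csv.DictReader(f)):
                yield idx, {
                    "audio": os.path.join(data_dir, row["file_name"]),
                    "sentence": row["transcription"],
                }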


Streaming Datasets

Next, I began to fine-tune openai/whisper-tiny on Colab, using a custom dataset that combined all the datasets I had gathered. I created a function to interleave the datasets in streaming mode. Here's the rough shape of that function:
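The exact function is longer than what fits here, so this is a minimal sketch of interleaving hub datasets in streaming mode; the dataset names, configs, and text-column names are illustrative:

from datasets import Audio, interleave_datasets, load_dataset


def load_streaming_dataset(sources, split="train"):
    """sources: list of (hub_name, config_name, text_column) tuples."""
    streams = []
    for name, config, text_column in sources:
        ds = load_dataset(name, config, split=split, streaming=True)
        if text_column != "sentence":
            ds = ds.rename_column(text_column, "sentence")
        # drop any extra columns so every stream shares the same features
        extra = [c for c in (ds.column_names or []) if c not in ("audio", "sentence")]
        if extra:
            ds = ds.remove_columns(extra)
        # make sure every stream decodes audio at the same sampling rate
        streams.append(ds.cast_column("audio", Audio(sampling_rate=16_000)))
    # sample from all streams until every one of them is exhausted
    return interleave_datasets(streams, stopping_strategy="all_exhausted")


train_stream = load_streaming_dataset(
    [
        ("mozilla-foundation/common_voice_11_0", "ta", "sentence"),
        ("google/fleurs", "ta_in", "transcription"),
    ]
)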



I ran the fine-tuning on Colab with a GPU enabled and noticed that training was extremely slow: it took 3 hours to process ~10 batches on a P100 GPU. To check whether the issue was training or dataset streaming, I re-ran the fine-tuning without streaming on a subset of the data and found that the model trained quite fast (roughly 1 hour for 1000 steps). This convinced me that the slow training was related to streaming: the datasets are hosted in different regions, and network traffic and speed became a huge bottleneck. So I concatenated all the datasets into a single dataset, containing ~1100 hours of training data and ~20 hours of test data.
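A rough sketch of that consolidation step might look like the following; the source names are illustrative ("parambharat/ucla_tamil" is a hypothetical repo id), and the target repo matches the one cloned below:

from datasets import Audio, concatenate_datasets, load_dataset

# non-streaming loads of the already-aligned sources
cv = load_dataset("mozilla-foundation/common_voice_11_0", "ta", split="train")
ucla = load_dataset("parambharat/ucla_tamil", split="train")  # hypothetical repo id

parts = []
for ds in (cv, ucla):
    # keep only the shared columns and a common sampling rate so features match
    ds = ds.remove_columns([c for c in ds.column_names if c not in ("audio", "sentence")])
    parts.append(ds.cast_column("audio", Audio(sampling_rate=16_000)))

combined = concatenate_datasets(parts)
combined.push_to_hub("parambharat/tamil_asr_corpus")  # the repo cloned below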
However, streaming even this single dataset proved to be a bottleneck during training. So I cloned the dataset into Colab as follows:
!git lfs install
!git clone https://huggingface.co/datasets/parambharat/tamil_asr_corpus
This effectively hosted the dataset locally while still letting me stream it for training. But why stream when you have the dataset hosted locally, you ask? Well, the datasets library lets you map preprocessing functions over datasets, which helps preprocess large datasets quickly using the Arrow backend. However, this also generates huge cache files: I had over 700k records, and preprocessing them into input and label tensors created cache files that quickly killed my Colab session for lack of disk space. Caching can be disabled by calling `datasets.disable_caching()`, but that also meant waiting for the processing to complete every time I ran the notebook. When streaming from local files, on the other hand, there were effectively no cache files, and training time was very close to the quick non-streaming test, apart from the small issue that the first few batches took a few minutes to load.
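A minimal sketch of that local streaming setup, assuming the cloned repo can be loaded straight from its directory:

from datasets import load_dataset

# point load_dataset at the local clone and stream it: preprocessing runs on the
# fly per batch, and no Arrow cache files are written to disk
dataset = load_dataset("./tamil_asr_corpus", streaming=True)
train_stream = dataset["train"]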

Fine-tuning Whisper

During my initial experiments fine-tuning Whisper based on the code in the blog post "Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers", I found that the model quickly overfits on monolingual data. This prompted me to augment the audio with some noise. I used the audiomentations library to perform a few augmentations before converting the audio to spectrograms with the feature extractor. Here's the sample code I used to perform the augmentations:
from audiomentations import AddGaussianNoise, Compose, PitchShift, TimeStretch

# each augmentation is applied independently with probability p
augment_waveform = Compose([
    AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=0.2),
    TimeStretch(min_rate=0.8, max_rate=1.25, p=0.2, leave_length_unchanged=False),
    PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
])


def augment_dataset(batch):
    audio = batch["audio"]["array"]
    # apply augmentation
    augmented_audio = augment_waveform(samples=audio, sample_rate=16000)
    batch["audio"]["array"] = augmented_audio
    return batch


# call augment_dataset on the training set
dataset_dict["train"] = dataset_dict["train"].map(augment_dataset)
Next, I logged a few examples from the common_voice_11 dataset to a W&B Table to investigate the samples. Here's the code I used:
import pandas as pd
import wandb


def convert_dataset_to_table(dataset):
    records = []
    for item in dataset:
        record = {}
        record["audio"] = wandb.Audio(item["audio"]["array"], 16000)
        record["sentence"] = item["sentence"]
        records.append(record)
    records = pd.DataFrame(records)
    records = wandb.Table(dataframe=records)
    return records


eda_table = convert_dataset_to_table(dataset)
run = wandb.init(project="whisper_finetuning", job_type="eda")
run.log({"sample_commonvoice": eda_table})
run.finish()

The resulting EDA table is below.


Key Insights
  • The audio clips are quite short (mostly under 10s) compared to Whisper's 30s limit.
  • There is little or no noise in most of the audio samples, which means the augmentations should help real-world performance.
  • The sentences are not normalized for punctuation. While Whisper is trained to predict punctuation, this means more effort to get the distinction between `!`, `?`, and `.` right.
  • Some sentences begin and end with a `"` character.
To address these issues, I added a few preprocessing and normalization steps to the pipeline. Here's the code:
import string

# feature_extractor and tokenizer are the WhisperFeatureExtractor and
# WhisperTokenizer loaded earlier for the checkpoint being fine-tuned


def fix_sentence(sentence):
    transcription = sentence
    if transcription.startswith('"') and transcription.endswith('"'):
        # we can remove leading/trailing quotation marks as they do not affect the transcription
        transcription = transcription[1:-1]
    if transcription[-1] not in [".", "?", "!"]:
        # append a full-stop to sentences that do not end in punctuation
        transcription = transcription + "."
    # strip punctuation from everything except the final character
    transcription = transcription[:-1].translate(str.maketrans('', '', string.punctuation)) + transcription[-1]
    return transcription


def prepare_dataset(examples):
    # compute log-Mel input features from the input audio array
    audio = examples["audio"]
    examples["input_features"] = feature_extractor(
        audio["array"], sampling_rate=16000).input_features[0]
    sentences = fix_sentence(examples["sentence"])
    # encode target text to label ids
    examples["labels"] = tokenizer(sentences, max_length=225, truncation=True).input_ids
    return examples
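A quick sketch of applying this preprocessing to the streaming dataset (column names as in the corpus above):

# map the preprocessing over the stream, dropping the raw columns so that only
# the model inputs (input_features, labels) remain
train_stream = train_stream.map(
    prepare_dataset, remove_columns=["audio", "sentence"]
)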

Next, I also identified that the model initially wasn't generating the <|endoftranscript|> token, leading to word error rates well over 1000. I ran a few debugging steps by generating examples with the fine-tuned model and found that this was due to a bug in the tokenizer shipped in the then-current release of the transformers library: the tokenizer was not generating the right initial special tokens for the given language. Since the model was being fine-tuned, it was important to ensure that the tokenizer performed the same tokenization as in pre-training. After bringing this up on the Discord channel, I found out that the dev branch of the library had been updated to fix this issue. So I uninstalled the transformers library and installed the latest version using `pip install git+https://github.com/huggingface/transformers`, which resulted in the right tokenization.
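A quick way to sanity-check the fix is to round-trip a sentence and inspect the special prefix tokens; this is just a sketch, and the input string is an arbitrary example:

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="Tamil", task="transcribe"
)
ids = tokenizer("வணக்கம்").input_ids
# the decoded string should start with <|startoftranscript|>, the Tamil language
# token and <|transcribe|>, and end with the end-of-text token
print(tokenizer.decode(ids, skip_special_tokens=False))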
Another issue I faced was that my Colab kept crashing after a few hours of running. These issues prompted me to:
  1. Track training progress by visualizing the errors the model was generating alongside the logged metrics.
  2. Save the model periodically to a wandb.Artifact so that training runs could be resumed.
To achieve this, I built a custom callback that visualizes training progress every `logging_steps`. The callback, built on top of WandbCallback, does this by first sampling a small part of the dataset and then logging a table of predictions and reference sentences along with the WER metric for each sentence. It also adds the different error types (substitutions, insertions, deletions) to the table. The callback additionally saves the model as an artifact every `save_steps`, which enabled me to resume training whenever Colab crashed. Here's the main chunk of code from the callback.

The Progress callback

import numbers
import tempfile
from pathlib import Path

import pandas as pd
from transformers.integrations import WandbCallback

# dataset_to_records, decode_predictions and compute_measures are helper
# functions defined elsewhere in the notebook


class WandbProgressResultsCallback(WandbCallback):
    def __init__(self, trainer, sample_dataset):
        super().__init__()
        self.trainer = trainer
        self.sample_dataset = sample_dataset
        self.records_df = dataset_to_records(sample_dataset)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        super().on_log(args, state, control, model, logs)
        # predict on the held-out sample and log per-sentence error measures
        predictions = self.trainer.predict(self.sample_dataset)
        predictions = decode_predictions(self.trainer, predictions)
        measures_df = compute_measures(predictions, self.records_df["sentence"].tolist())
        records_df = pd.concat([self.records_df, measures_df], axis=1)
        records_df["prediction"] = predictions
        records_df["step"] = state.global_step
        records_table = self._wandb.Table(dataframe=records_df)
        self._wandb.log({"sample_predictions": records_table})

    def on_save(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            with tempfile.TemporaryDirectory() as temp_dir:
                self.trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(
                    name=f"model-{self._wandb.run.id}",
                    type="model",
                    metadata=metadata,
                )
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)
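Hooking the callback into the trainer looks roughly like this (a sketch; `trainer` is the Seq2SeqTrainer from the fine-tuning setup and `sample_dataset` is the small evaluation sample mentioned above):

# register the callback so it fires on every logging and save step
progress_callback = WandbProgressResultsCallback(trainer, sample_dataset)
trainer.add_callback(progress_callback)
trainer.train()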

With these steps I was able to sit back, relax, and monitor the training progress. Here's the table generated by the callback:

audio_with_spec
sentence
prediction
wer
hits
substitutions
deletions
insertions
step
1
2
3
4
5
Finally, here are the training and eval logs for my fine-tuning runs.


And the generated model artifact:

model-2k10w4qq (artifact type: model, created December 8th, 2022)

Version | Aliases | Created | TTL Remaining | # of Consuming Runs | Size | m.eval/wer | m.train/total_floss
v5 | latest, v5 | Fri Dec 09 2022 | Inactive | 0 | 151.1MB | 30.103 | 15756086476800000000
v4 | v4 | Fri Dec 09 2022 | Inactive | 0 | 151.1MB | 30.183 | 14180477829120000000
v3 | v3 | Fri Dec 09 2022 | Inactive | 0 | 151.1MB | 30.516 | 12604869181440000000
v2 | v2 | Fri Dec 09 2022 | Inactive | 0 | 151.1MB | 31.134 | 11029260533760000000
v1 | v1 | Thu Dec 08 2022 | Inactive | 0 | 151.1MB | 31.4 | 9453651886080000000
v0 | v0 | Thu Dec 08 2022 | Inactive | 0 | 152.9MB | 31.4 | 7878043238400000000
Finally, I pushed the model to the Hugging Face Hub as whisper-tiny-ta.
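The push itself can be done directly from the fine-tuned model and its feature extractor/tokenizer; a minimal sketch, assuming the repo id mirrors the model name above:

# upload the fine-tuned weights and the preprocessing components to the hub
model.push_to_hub("parambharat/whisper-tiny-ta")
feature_extractor.push_to_hub("parambharat/whisper-tiny-ta")
tokenizer.push_to_hub("parambharat/whisper-tiny-ta")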

Next Steps

Now that I have a working pipeline, I intend to spend the next few days of the sprint doing the following:
  • Add spectrogram augmentation
  • Train base and small models
  • Add multi-lingual models for Telugu, Kannada, Malayalam and Tamil (if possible)
  • Add a Hugging Face space so that people can try out the models.
  • Try model distillation into the tiny model once the small model is trained.

Here are the links to the Colab notebook and GitHub repo that document all the code used in this report.

