Fine-Tuning Whisper for Low-Resource Dravidian Languages
This article gives an overview of the project I completed as part of the Whisper fine-tuning sprint hosted by Hugging Face and Lambda Labs in December 2022.
Created on December 16 | Last edited on January 27
Introduction
We've recently seen deep learning widely adopted for automatic speech recognition (ASR) due to the availability of large datasets that contain speech and corresponding transcripts.
A recent trend in this domain has been self-supervised learning. For instance, the Wav2Vec family of models uses a task similar to masked language modeling to pre-train a network on unlabelled speech before fine-tuning on the ASR task. This allows the network to learn contextual speech embeddings that can then be fine-tuned using a parallel corpus of speech and transcripts.
Whisper is a different type of ASR model released by OpenAI. Unlike other recent models, Whisper is completely trained using supervised learning on weakly labeled data.
Specifically, the researchers gathered 680k hours of multilingual speech and transcription data in 96 languages from the web and trained the model to directly predict the transcript from audio. It's an encoder-decoder transformer trained using multi-task learning with tasks that include transcription, translation, and timestamp prediction. Here's an overview of the model architecture:

Here's what we'll be covering in this article:
Table of Contents
- Introduction
- Background
- Problem Statement
- Description of Software/Tools
- Description of Data
- Modeling
- Released Models
- Training Curves
- Evaluation Results
- Demonstrations
- Summary
Let's get started.
Background
The awesome folks at Hugging Face ported the Whisper model into the transformers library within a few weeks of the model's release. However, the model is not without limitations. The authors of the Whisper paper detail the following limitations (a.k.a. hackathon ideas):
- Inaccurate timestamp predictions
- Hallucinations
- Low performance on low-resource languages
- No speaker recognition
- No real-time transcription
This leads us to the Whisper fine-tuning event run by Hugging Face and Lambda Labs. The event was organized as a community sprint from December 5th to the 19th. Hugging Face provided the models, starter code, and community support via Discord, while Lambda Labs provided the necessary compute (roughly 100 hours of compute on A100 GPUs).
The main components of the sprint included:
- Open AI’s state-of-the-art Whisper model
- Public datasets like Common Voice 11, VoxPopuli, CoVoST2 and more
- Real-world audio for evaluation
With the key outcomes being:
- Fine-tuned Whisper checkpoint (e.g., Whisper-large)
- Evaluation script for fine-tuned checkpoints
- Hugging Face space to demo fine-tuned models
Problem Statement
The goal of this project is to fine-tune a sequence-to-sequence transformer (OpenAI's Whisper) for the speech-to-text task in four Dravidian languages: Tamil, Malayalam, Telugu, and Kannada. Unlike languages such as English and Spanish, these languages have very few datasets and benchmarks available; we call them low-resource languages.
To this end, the project aimed to perform rigorous data collection in these languages and train deep learning models to transcribe their speech. In doing so, it aims to extend the reach of AI technologies to speakers of these languages by removing the barrier posed by the lack of available models.
Description of Software/Tools
This project was mostly created to run on Ubuntu Linux machines with NVIDIA Graphics Cards. This is primarily due to the following reasons:
- We'll process a large amount of audio data and rely on the robust audio processing libraries that are well supported on Linux.
- Availability of a CUDA GPU (minimum 12 GB of GPU RAM) to train and evaluate the models.
- A large part of the model training happens on cloud GPU machines with enough GPU RAM to load and process speech data.
- The project uses an Anaconda virtual environment with Python 3.8. I'll provide a shell script to create the environment in the install, config, and setup section below.
- It's recommended to use a machine with a large number of CPU cores when recreating this project, since the code makes effective use of multi-processing and multi-threading libraries in Python.
- Some cloud GPU providers I used for this project include:
Description of Data
Initially, I collected data for all the South Indian languages from various heterogeneous sources. These include publicly available datasets released by open-source organizations such as the Mozilla Foundation, Google, and the Government of India.
Some datasets were curated and formatted while others had to be downloaded and preprocessed into the proper format. The following table gives you an overview of the various data sources used to create the datasets (note, you can scroll right for additional information about these datasets):
Downloading Source Datasets
While many data sources are available on the Hugging Face Datasets Hub and can be accessed and downloaded using the datasets library, others (such as the UCLA corpus) needed to be downloaded manually. Here's how we handled that dataset:
Create source file lists:
- Save the URLs as individual text files for each language. These files can be seen in the code repository here.
Next, download and save the zip files from the source file lists. This is done with the following function:
```python
import os

from keras.utils import get_file


def download_archive(url):
    filename = os.path.split(url)[-1]
    try:
        filename = get_file(fname=filename, origin=url, cache_dir="./", extract=False)
    except:
        print(f"Unable to download {url}")
        return None
    return filename
```
Extract each zip archive into its own directory, named after the archive. Each archive contains audio data in wav format and a metadata file named data.json. We extract the audio data into a wav sub-directory and the data.json file into the archive's directory. Here's the code to achieve this:
```python
import pathlib
from zipfile import ZipFile


def extract_zipfile(filename):
    filename = pathlib.Path(filename)
    dirname = filename.parent / filename.stem
    with ZipFile(filename, "r") as zipf:
        for zipinfo in zipf.infolist():
            if zipinfo.filename[-1] == "/":
                continue
            zipinfo.filename = pathlib.Path(zipinfo.filename).name
            if pathlib.Path(zipinfo.filename).suffix == ".wav":
                zipf.extract(zipinfo, dirname / "wav")
            else:
                zipf.extract(zipinfo, dirname)
    wavfiles = list((dirname / "wav").rglob("*.wav"))
    return wavfiles
```
Since the wav format is uncompressed, this generates a lot of data and can quickly fill up even a 1TB hard drive. Therefore, we convert the wav data into the compressed mp3 format using the pydub library and delete each wav file once it has been converted. Any wav files that cannot be converted (for instance, due to data corruption) are dropped and deleted. The following code achieves this:
```python
from pydub import AudioSegment, effects


def get_mp3_path(wavfile):
    mp3dir = list(wavfile.parents)[1] / "mp3"
    mp3dir.mkdir(parents=True, exist_ok=True)
    mp3file = mp3dir / wavfile.with_suffix(".mp3").name
    return mp3file


def convert_wav_to_mp3(wavfile):
    mp3file = get_mp3_path(wavfile)
    try:
        wavaudio = AudioSegment.from_wav(wavfile)
        wavaudio = wavaudio.set_frame_rate(16000).set_channels(1)
        wavaudio = effects.normalize(wavaudio)
        wavaudio.export(mp3file, format="mp3", bitrate="16k")
        wavfile.unlink()
        return mp3file
    except Exception as e:
        print(f"Unable to convert {wavfile} to mp3")
        wavfile.unlink()
```
Finally, we put together all of the above steps for each source list file from step 1, using multi-processing and multi-threading to download, extract, and convert the source data files:
```python
from concurrent.futures import as_completed, ThreadPoolExecutor
from multiprocessing import cpu_count, Pool

from tqdm import tqdm


def main():
    urls = list(map(lambda x: x.strip(), open("../data/file_list.txt").readlines()))

    print("Downloading archives")
    zipfiles = []
    with ThreadPoolExecutor(cpu_count() * 4) as executor:
        results = executor.map(download_archive, urls)
        for url in tqdm(results):
            zipfiles.append(url)

    zipfiles = list(pathlib.Path("datasets").rglob("*.zip"))
    mp3files = []
    with ThreadPoolExecutor(max_workers=len(zipfiles)) as executor:
        print("Extracting archives")
        extracted = [executor.submit(extract_zipfile, file) for file in zipfiles]
        for wavfiles in as_completed(extracted):
            wavfiles = wavfiles.result()
            with Pool(cpu_count() - 1) as pool:
                results = pool.imap_unordered(convert_wav_to_mp3, wavfiles)
                print("Converting wav to mp3")
                for mp3file in tqdm(results, total=len(wavfiles)):
                    mp3files.append(mp3file)
    return mp3files
```
Note: This takes a very long time (about 3 days) and requires a high-speed internet connection and a lot of disk space!
One final step remains: collate all the data from the individual sub-directories into a single directory per language. The code for this can be found in the create_lang_dirs.py file in the repository. It collates all the audio files for a language into a single folder with a metadata.jsonl file in the same directory. For instance, it creates a tamil directory for the Tamil language with a tamil/train sub-directory for the audio data and tamil/metadata.jsonl containing all the metadata for the audio files; a rough sketch of this step follows.
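For reference, here's a minimal sketch of what that collation step might look like. The exact logic lives in create_lang_dirs.py; the directory layout and the structure assumed for each data.json file here are my assumptions, not the repository's code.

```python
import json
import pathlib
import shutil


def collate_language(source_dirs, language, out_root="../data"):
    """Hypothetical sketch of create_lang_dirs.py: gather mp3 files and metadata into <language>/."""
    train_dir = pathlib.Path(out_root) / language / "train"
    train_dir.mkdir(parents=True, exist_ok=True)
    records = []
    for source_dir in map(pathlib.Path, source_dirs):
        # assumed: each source directory holds a data.json mapping file stems to transcripts
        metadata = json.loads((source_dir / "data.json").read_text(encoding="utf-8"))
        for mp3file in (source_dir / "mp3").glob("*.mp3"):
            shutil.copy2(mp3file, train_dir / mp3file.name)
            records.append({"file_name": f"train/{mp3file.name}", "sentence": metadata.get(mp3file.stem, "")})
    # write one metadata.jsonl per language, next to the train/ folder
    with open(train_dir.parent / "metadata.jsonl", "w", encoding="utf-8") as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + "\n")
```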
Preprocessing Datasets
Once the source datasets have been collected, we are ready to preprocess and clean them. Since the datasets come from various heterogeneous sources, they arrive in multiple formats. For instance, the sampling rate of audio from each source varies from 8 kHz to 48 kHz.
Additionally, the metadata for each file includes the transcription, gender, source, duration, etc. Some data sources also have very short (less than 3 seconds) or very long (more than 30 seconds) recordings. Therefore, we clean the data and create a homogeneous dataset for each language using the following steps:
First, load the source datasets. We use the datasets library to load and process them. The library uses a memory-mapped PyArrow backend that speeds up the processing of large amounts of data without requiring large amounts of RAM. We also resample the audio to 16 kHz at this step. Here's an example of this step from the Kannada language conversion script:
```python
from datasets import load_dataset, Audio, DatasetDict, concatenate_datasets


def load_data_splits(is_streaming=True, stopping_strategy="all_exhausted"):
    data_dict = {}
    data_dict["openslr_dataset_train"] = load_dataset("openslr", "SLR79", split="train", use_auth_token=True)
    data_dict["ucla_dataset_train"] = load_dataset("audiofolder", data_dir="../data/kannada/", drop_labels=True)["train"]
    data_dict["fleurs_dataset_train"] = load_dataset(
        "google/fleurs", "kn_in", split="train", use_auth_token=True
    ).rename_column("transcription", "sentence")
    data_dict["fleurs_dataset_val"] = load_dataset(
        "google/fleurs", "kn_in", split="validation", use_auth_token=True
    ).rename_column("transcription", "sentence")
    data_dict["fleurs_dataset_test"] = load_dataset(
        "google/fleurs", "kn_in", split="test", use_auth_token=True
    ).rename_column("transcription", "sentence")

    for k in data_dict:
        data_dict[k] = data_dict[k].remove_columns(
            [col for col in data_dict[k].column_names if col not in ["audio", "sentence"]]
        )
        data_dict[k] = data_dict[k].cast_column("audio", Audio(sampling_rate=16000))

    dataset_dict = DatasetDict()
    train_datasets = []
    test_datasets = []
    for k in data_dict:
        if k.endswith("train") or k.endswith("val"):
            train_datasets.append(data_dict[k])
        if k.endswith("test"):
            test_datasets.append(data_dict[k])
    dataset_dict["train"] = concatenate_datasets(train_datasets)
    dataset_dict["test"] = concatenate_datasets(test_datasets)
    return dataset_dict
```
Next, we convert the audio from the source format into 16 kHz single-channel mp3 files and filter out any records that have null transcripts or very short or very long audio. Here's the code to do this:
```python
import pathlib
from io import BytesIO

import scipy.io.wavfile as wavf
from pydub import AudioSegment, effects


def audio_from_array(array):
    # wrap the raw array in an in-memory wav file and load it with pydub
    file = BytesIO()
    wavf.write(file, 16000, array)
    audio_segment = AudioSegment.from_file(file)
    audio_segment = audio_segment.set_frame_rate(16000).set_channels(1)
    audio_segment = effects.normalize(audio_segment)
    return audio_segment


def filter_nans_and_short(example):
    sentence = example["sentence"]
    length = example["length"]
    if sentence is None:
        return False
    elif length < 3 or length > 30:
        return False
    else:
        return True


def export_audio_and_sentence(example, split):
    audio, sentence = example["audio"], example["sentence"]
    # `uuid` is a helper from the preprocessing script that derives a unique filename from the source path
    new_name = pathlib.Path(uuid(audio["path"])).with_suffix(".mp3").name
    new_path = pathlib.Path(f"../data/filtered_datasets/kannada/{split}/{new_name}")
    new_path.parent.mkdir(parents=True, exist_ok=True)
    length = len(audio["array"]) / 16000
    if new_path.is_file():
        return {"path": str(new_path), "sentence": sentence, "length": length, "split": split}
    else:
        if not split == "test":
            is_not_filtered = filter_nans_and_short({"length": length, "sentence": sentence})
        else:
            is_not_filtered = True
        if is_not_filtered:
            try:
                new_path.parent.mkdir(exist_ok=True, parents=True)
                audio_segment = audio_from_array(audio["array"])
                audio_segment.export(new_path, format="mp3", bitrate="16k")
                segment = AudioSegment.from_mp3(new_path)  # sanity check that the export is readable
            except:
                return None
            return {"path": str(new_path), "sentence": sentence, "length": length, "split": split}
        else:
            return None
```
Finally, we put together all the above code for each language and create a single homogeneous dataset for each language. Here’s an example for the Kannada language:
```python
from functools import partial
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
from tqdm import tqdm

if __name__ == "__main__":
    dataset_dict = load_data_splits(is_streaming=False)
    exports = []
    with Pool(cpu_count() - 1) as pool:
        for k in dataset_dict:
            dataset = dataset_dict[k].shuffle(seed=np.random.randint(1000))
            exporter = partial(export_audio_and_sentence, split=k)
            results = pool.imap_unordered(exporter, dataset, chunksize=10)
            for result in tqdm(results, total=len(dataset)):
                if result:
                    exports.append(result)
    exports = pd.DataFrame(exports)
    exports.to_json("../data/filtered_datasets/kannada/metadata.jsonl", lines=True, orient="records")
```
The full source code for cleaning up the data for each language is available in the following files:
While the above results in some duplication of code, it was the quickest way to collate data for each language from multiple data sources. I hope to further refactor the code to take a language parameter and build every language's dataset from a single source file. The final statistics of the collected and cleaned data are in the table below:
| Language | Duration Total (h) | Duration Mean (h) | Duration Min (h) | Duration Max (h) | Chars Total | Chars Mean | Chars Min | Chars Max | Words Total | Words Mean | Words Min | Words Max |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Tamil | 1293.71 | 0.0019 | 0.0008 | 0.0083 | 70519146 | 103.26 | 2 | 524 | 8018783 | 11.74 | 1 | 62 |
| Telugu | 387.21 | 0.0019 | 0.0008 | 0.0042 | 17061166 | 81.53 | 4 | 289 | 2115563 | 10.11 | 1 | 36 |
| Malayalam | 10.13 | 0.0015 | 0.0008 | 0.0042 | 480309 | 72.00 | 4 | 303 | 53283 | 7.99 | 1 | 36 |
| Kannada | 358.84 | 0.0021 | 0.0008 | 0.0074 | 12973393 | 75.11 | 2 | 249 | 1592203 | 9.22 | 1 | 31 |
Hosting the Datasets
Once cleaned and converted, we are ready to host the datasets somewhere easily accessible to anyone across the globe. I chose the Hugging Face Datasets Hub.
To do this, the audio files were compressed into tar archives, and a data-loading script was built for each language. As an example, here’s the main source code of the data-loading script for the Tamil language:
```python
# https://huggingface.co/datasets/parambharat/tamil_asr_corpus/blob/main/tamil_asr_corpus.py
class TamilASRCorpus(datasets.GeneratorBasedBuilder):
    """Tamil ASR Corpus contains transcribed speech corpus for training ASR systems for Tamil language."""

    VERSION = datasets.Version("1.1.0")

    def _info(self):
        features = datasets.Features(
            {
                "audio": datasets.Audio(sampling_rate=16_000),
                "path": datasets.Value("string"),
                "sentence": datasets.Value("string"),
                "length": datasets.Value("float"),
            }
        )
        return datasets.DatasetInfo(
            description=_DESCRIPTION,
            features=features,
            supervised_keys=("sentence", "label"),
            homepage=_HOMEPAGE,
            license=_LICENSE,
            citation=_CITATION,
        )

    def _split_generators(self, dl_manager):
        metadata_paths = dl_manager.download(_METADATA_URLS)
        train_archive = dl_manager.download(_URLS["train"])
        test_archive = dl_manager.download(_URLS["test"])
        local_extracted_train_archive = dl_manager.extract(train_archive) if not dl_manager.is_streaming else None
        local_extracted_test_archive = dl_manager.extract(test_archive) if not dl_manager.is_streaming else None
        test_archive = dl_manager.download(_URLS["test"])
        train_dir = "train"
        test_dir = "test"

        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={
                    "metadata_path": metadata_paths["train"],
                    "local_extracted_archive": local_extracted_train_archive,
                    "path_to_clips": train_dir,
                    "audio_files": dl_manager.iter_archive(train_archive),
                },
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={
                    "metadata_path": metadata_paths["test"],
                    "local_extracted_archive": local_extracted_test_archive,
                    "path_to_clips": test_dir,
                    "audio_files": dl_manager.iter_archive(test_archive),
                },
            ),
        ]

    def _generate_examples(self, metadata_path, local_extracted_archive, path_to_clips, audio_files):
        """Yields examples as (key, example) tuples."""
        examples = {}
        with open(metadata_path, encoding="utf-8") as f:
            for key, row in enumerate(f):
                data = json.loads(row)
                examples[data["path"]] = data

        inside_clips_dir = False
        id_ = 0
        for path, f in audio_files:
            if path.startswith(path_to_clips):
                inside_clips_dir = True
                if path in examples:
                    result = examples[path]
                    path = os.path.join(local_extracted_archive, path) if local_extracted_archive else path
                    result["audio"] = {"path": path, "bytes": f.read()}
                    result["path"] = path
                    yield id_, result
                    id_ += 1
            elif inside_clips_dir:
                break
```
Finally, the datasets were pushed to the Datasets Hub; the following datasets were created in this manner.
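For completeness, here's a hedged sketch of that upload step using the huggingface_hub client. The repo ID matches one of this project's datasets, but the local folder layout is an assumption.

```python
from huggingface_hub import HfApi

api = HfApi()
# create the dataset repo on the Hub if it doesn't already exist
api.create_repo("parambharat/tamil_asr_corpus", repo_type="dataset", exist_ok=True)
# upload the tar archives, the metadata files, and the loading script in one call
api.upload_folder(
    folder_path="../data/hub/tamil_asr_corpus",  # assumed local layout holding the archives and the script
    repo_id="parambharat/tamil_asr_corpus",
    repo_type="dataset",
)
```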
Here's a table preview of the Tamil dataset:
These datasets are provided under a CC 4.0 license and can be reused, adapted, and shared by anyone, anywhere, with just a single line of code. The line below not only downloads the dataset but also caches it for future reuse:
```python
from datasets import load_dataset

dataset = load_dataset("parambharat/telugu_asr_corpus", split="train")
```
Additionally, if you're constrained by disk space, you can pass the streaming parameter to the load_dataset function and access the data as you would a Python iterator. Here's an example:
```python
from datasets import load_dataset

dataset = load_dataset("parambharat/telugu_asr_corpus", split="train", streaming=True)
next(iter(dataset))
```

Output:

```
{'audio': {'path': 'train/9zKDqLgBykz9FSyL6bFPRQ.mp3',
  'array': array([ 1.2912806e-03,  2.4933945e-03,  3.6223403e-03, ...,
          5.2775860e-05, -1.2118297e-05,  8.1833590e-05], dtype=float32),
  'sampling_rate': 16000},
 'path': 'train/9zKDqLgBykz9FSyL6bFPRQ.mp3',
 'sentence': 'అక్కడ మీడియాతో మాట్లాడిన అనంతరం తిరిగి హైదరాబాద్ బయలుదేరుతారు',
 'length': 3.87}
```
Dataset EDA
With the cleaned datasets in hand, we can now perform exploratory data analysis to get an understanding of the data. You can see the full analysis of the datasets in this notebook.
Note: Run git clone <dataset_link> to fetch the datasets locally if you would like to re-run the notebook. The notebook expects all datasets to be in the data directory. For instance, to fetch the Malayalam dataset you would run git clone https://huggingface.co/datasets/parambharat/malayalam_asr_corpus
Here are a few plots to show what we're working with here:
Key Insights
- All the histograms are skewed toward shorter clips, with most of the data under 15 seconds. This is useful for deciding the truncation and padding sizes when feeding the data into the network.
- The violin plots show that the audio is most dense in the 5-10 seconds region.
- Most of the transcripts across languages contain between 3 and 15 words. This is helpful for deciding the decoding sequence length of the models.
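If you'd like to reproduce these distributions without the notebook, a minimal sketch using the per-clip metadata written during preprocessing is below; the path is an assumption based on the layout described earlier.

```python
import matplotlib.pyplot as plt
import pandas as pd

# each metadata.jsonl row has path, sentence, length (seconds), and split
df = pd.read_json("../data/filtered_datasets/kannada/metadata.jsonl", lines=True)
df["n_words"] = df["sentence"].str.split().str.len()

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
df["length"].plot.hist(bins=50, ax=axes[0], title="Clip duration (seconds)")
df["n_words"].plot.hist(bins=50, ax=axes[1], title="Words per transcript")
plt.tight_layout()
plt.show()
```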
Modeling
The model takes log-mel spectrogram features as input and processes them through two 1D convolution layers before applying an encoder-decoder transformer with multi-headed cross- and self-attention to directly predict the transcript.
The model makes use of special tokens to predict the language, timestamps, and transcripts in an autoregressive fashion. The authors report near human-level performance and a significant decrease in word error rate (WER), setting new state-of-the-art (SOTA) results on multiple languages and benchmarks. The model was evaluated across 96 languages, including the Tamil, Kannada, Malayalam, and Telugu languages that we're interested in for this report.
However, the datasets used to train the model on these languages are quite small compared to the datasets we were able to gather. The screenshot below of a table from the paper shows performance on the google/fleurs benchmark dataset, with the model's performance on our languages of interest highlighted.

Table from the Whisper paper with performance on the languages of this project highlighted.
Models
In this project, we train three variants of the Whisper model for each language. The details of the original pre-trained checkpoints are in the following table:
| Model | Layers | Embedding Size | Attention Heads | Parameters | Checkpoint |
|---|---|---|---|---|---|
| whisper-tiny | 4 | 384 | 6 | 39M | https://huggingface.co/openai/whisper-tiny |
| whisper-base | 6 | 512 | 8 | 74M | https://huggingface.co/openai/whisper-base |
| whisper-small | 12 | 768 | 12 | 244M | https://huggingface.co/openai/whisper-small |
Note: Here, Layers refers to a symmetric number of encoder and decoder transformer layers, i.e., 4 layers means 4 encoder layers and 4 decoder layers.
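If you want to verify these numbers yourself, you can load a pre-trained checkpoint and inspect its configuration; this is a quick sanity check rather than part of the training code:

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
config = model.config
print(config.encoder_layers, config.decoder_layers)    # 4 encoder and 4 decoder layers
print(config.d_model, config.encoder_attention_heads)  # embedding size 384, 6 attention heads
print(f"{sum(p.numel() for p in model.parameters()) / 1e6:.1f}M parameters")
```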
Loading the Data
Since we're dealing with very large datasets that need to be turned into mel-spectrograms, we use the datasets library to lazily load and stream the datasets from disk while mapping the feature extraction process across the entire dataset.
While this potentially slows down model training, the slowdown is small compared to the RAM that would be required to hold and process nearly half a million audio files at once. The code snippet below shows how to load the dataset in streaming mode:
```python
def load_data_splits(is_streaming=True):
    data_dict = load_dataset("parambharat/tamil_asr_corpus", streaming=is_streaming)
    return data_dict
```
Data Augmentation
The authors of the Whisper paper note that the data used to pre-train the model was not augmented in any way. Since we want the model to be robust in noisy settings and to generalize well across speakers, we choose to augment the dataset. We use the audiomentations library to apply AddGaussianNoise, TimeStretch, and PitchShift augmentations to the audio waveform. This is performed only on the training dataset with the following code:
```python
from audiomentations import AddGaussianNoise, Compose, PitchShift, TimeStretch

augment_waveform = Compose(
    [
        AddGaussianNoise(min_amplitude=0.005, max_amplitude=0.015, p=0.2),
        TimeStretch(min_rate=0.8, max_rate=1.25, p=0.2, leave_length_unchanged=False),
        PitchShift(min_semitones=-4, max_semitones=4, p=0.2),
    ]
)


def augment_dataset(batch):
    audio = batch["audio"]["array"]
    # apply augmentation
    augmented_audio = augment_waveform(samples=audio, sample_rate=16000)
    batch["audio"]["array"] = augmented_audio
    return batch


# call augment_dataset on the training set
dataset_dict["train"] = dataset_dict["train"].map(augment_dataset)
```
Note: Each augmentation is applied randomly and independently (with probability 0.2 in the code above), so at most three augmentations are applied to any given waveform.
Preprocessing
While our data collection stage mostly preprocessed the audio data and removed noisy data, we still ensure that nulls and short transcripts are filtered out when training the models. Additionally, we need to normalize the transcripts by stripping punctuation before feeding them into the model. The following code snippet is used to filter and clean the dataset:
```python
import string


def fix_sentence(sentence):
    transcription = sentence
    if transcription.startswith('"') and transcription.endswith('"'):
        # we can remove trailing quotation marks as they do not affect the transcription
        transcription = transcription[1:-1]
    if transcription[-1] not in [".", "?", "!"]:
        # append a full-stop to sentences that do not end in punctuation
        transcription = transcription + "."
    transcription = transcription[:-1].translate(str.maketrans("", "", string.punctuation)) + transcription[-1]
    return transcription


def filter_empty_strings(sentence):
    if len(sentence) < 2:
        return False
    else:
        return True


for k in dataset_dict:
    dataset_dict[k] = dataset_dict[k].filter(filter_empty_strings, input_columns=["sentence"])
```
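As a quick illustration of what fix_sentence does (the string here is just an illustrative example, not from the datasets):

```python
# surrounding quotes and inner punctuation are stripped, and a trailing full-stop is added when missing
print(fix_sentence('"hello, world"'))  # -> hello world.
```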
Feature Extraction
Since the model takes mel-spectrogram features as input and predicts a transcription, we need to convert the audio into features and the transcription text into tokenized inputs that the model can predict. The transformers library provides two utility classes to achieve this: WhisperFeatureExtractor and WhisperTokenizer.
We use these classes to extract features from the audio and tokenize the transcripts into byte-pair tokens. The following snippet runs this step over the entire dataset, here for the Telugu language:
```python
from transformers import WhisperFeatureExtractor, WhisperTokenizer

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
tokenizer = WhisperTokenizer.from_pretrained(
    "openai/whisper-tiny", language="Telugu", task="transcribe", model_max_length=225
)


def prepare_dataset(examples):
    # compute log-Mel input features from input audio array
    audio = examples["audio"]
    examples["input_features"] = feature_extractor(audio["array"], sampling_rate=16000).input_features[0]

    sentences = fix_sentence(examples["sentence"])
    # encode target text to label ids
    examples["labels"] = tokenizer(sentences, max_length=225, truncation=True).input_ids
    return examples


for k in dataset_dict:
    dataset_dict[k] = dataset_dict[k].map(prepare_dataset).with_format("torch")
```
Note: The above code also converts the dataset into an instance of IterableDataset for easy loading into the model. Furthermore, you'll notice that the tokenizer adds some special tokens to the start and end of the transcript, and that the shape of the input features is (80, 3000), i.e. 80 mel bins over 3000 time frames (30 seconds of audio). Here's a sample:
```python
features = feature_extractor(sample["audio"]["array"], sampling_rate=16000)["input_features"][0]
print(features)
print(features.shape)
```

Output:

```
[[ 0.03961742 -0.4417838  -0.423414   ... -0.6383085  -0.6383085  -0.6383085 ]
 [ 0.04983616 -0.28957796 -0.087901   ... -0.6383085  -0.6383085  -0.6383085 ]
 [ 0.13291407 -0.16156828 -0.17585051 ... -0.6383085  -0.6383085  -0.6383085 ]
 ...
 [-0.6383085  -0.6383085  -0.6383085  ... -0.6383085  -0.6383085  -0.6383085 ]
 [-0.6383085  -0.6383085  -0.6383085  ... -0.6383085  -0.6383085  -0.6383085 ]
 [-0.6383085  -0.6383085  -0.6383085  ... -0.6383085  -0.6383085  -0.6383085 ]]
(80, 3000)
```

```python
input_str = sample["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)
print(f"Input: {input_str}")
print(f"Decoded w/ special: {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal: {input_str == decoded_str}")
```

Output:

```
Input: అక్కడ మీడియాతో మాట్లాడిన అనంతరం తిరిగి హైదరాబాద్ బయలుదేరుతారు
Decoded w/ special: <|startoftranscript|><|te|><|transcribe|><|notimestamps|>అక్కడ మీడియాతో మాట్లాడిన అనంతరం తిరిగి హైదరాబాద్ బయలుదేరుతారు<|endoftext|>
Decoded w/out special: అక్కడ మీడియాతో మాట్లాడిన అనంతరం తిరిగి హైదరాబాద్ బయలుదేరుతారు
Are equal: True
```
Batching
So far, we have loaded, preprocessed, and converted the dataset into the model's input format.
However, we also need to feed the model batches of data during training. Dealing with variable transcription lengths means we need to dynamically pad each batch of tensors to the same size. To do this, we create a utility class to perform the data collation. Here's the code for the class:
```python
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained(
    "openai/whisper-tiny", language="Telugu", task="transcribe", model_max_length=225
)


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [
            {"input_ids": self.processor.tokenizer.truncate_sequences(feature["labels"])[0]}
            for feature in features
        ]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        return batch
```
Here, WhisperProcessor is a utility class in the transformers library that combines the tokenizer and the feature extractor from the feature extraction step.
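With the class defined, creating the collator is a one-liner; this is the data_collator instance passed to the trainer in the training section below:

```python
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)
```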
Computing Metrics
Before we begin training the model, we need a way to compute model metrics. Since we're interested in predicting the transcript for a given audio clip, we need to quantify how much error the model makes in its predictions.
A common metric for this is the word error rate (WER), which computes the ratio of word-level errors (substitutions, insertions, and deletions) in the generated transcript to the total number of words in the reference. We use the evaluate library to compute this metric. The following code snippet computes the normalized word error rate:
```python
import evaluate

metric = evaluate.load("wer")

# evaluate with the 'normalised' WER
do_normalize_eval = True


def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=do_normalize_eval)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True, normalize=do_normalize_eval)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
```
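To build some intuition for the metric, here's a tiny standalone example with toy sentences (not from the datasets):

```python
import evaluate

wer_metric = evaluate.load("wer")
# one substituted word out of four reference words gives a WER of 0.25
print(wer_metric.compute(references=["the cat sat down"], predictions=["the cat sat up"]))  # 0.25
```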
Training
To train the model, we use the Seq2SeqTrainer utility class from the transformers library. It wraps the common training logic required to train encoder-decoder models and takes training arguments, a model, datasets, and callbacks. The following code snippet shows the training arguments used to train a whisper-tiny Telugu model:
```python
from transformers import WhisperForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base", use_cache=False)

training_args = Seq2SeqTrainingArguments(
    output_dir="../models/whisper-tiny-te",  # change to a repo name of your choice
    per_device_train_batch_size=72,  # batch size provided to each GPU
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,  # learning rate
    save_total_limit=4,  # no. of last model checkpoints to save
    warmup_steps=500,  # no. of batches for learning rate warmup
    max_steps=3000,  # total number of training batches to train the model
    gradient_checkpointing=True,  # whether or not to use gradient checkpointing on activations
    fp16=True,  # whether to use mixed precision training
    optim="adamw_bnb_8bit",  # 8-bit AdamW optimizer (bitsandbytes) for lower GPU RAM usage
    evaluation_strategy="steps",  # evaluate every eval_steps batches rather than every epoch
    per_device_eval_batch_size=36,  # evaluation batch size (lower than training for seq2seq models)
    predict_with_generate=True,  # whether to run predictions with model.generate (autoregressively)
    generation_max_length=225,  # maximum length of the generated sequences
    save_steps=300,  # number of batches after which to checkpoint the model
    eval_steps=300,  # number of batches after which to evaluate the model
    logging_steps=100,  # number of batches after which to log metrics from the model
    report_to="none",  # whether to report the model progress to tensorboard or any other integration
    load_best_model_at_end=True,  # whether to load the best model weights at the end of training
    metric_for_best_model="wer",  # metric used to identify the best model
    greater_is_better=False,  # needs to be lower for WER
    hub_strategy="checkpoint",  # push model checkpoints to the Hugging Face Hub
    push_to_hub=True,  # whether to push the model to the Hugging Face Hub
    remove_unused_columns=False,  # keep all dataset columns so the custom collator receives what it needs
    ignore_data_skip=True,  # whether to ignore the first n batches when resuming model training
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset_dict["train"],
    eval_dataset=samples_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    callbacks=[ShuffleCallback()],
)
```
Callbacks
In addition to the built-in callbacks, we add a callback to shuffle the iterable datasets at the start of each epoch and a callback to monitor model generations periodically. For the latter, we use Weights & Biases Tables to track, organize, and visualize model training and evaluations. We create custom callbacks that we then add to the trainer.
The code for the custom callbacks:
```python
import numbers
import tempfile
from pathlib import Path

import pandas as pd
from datasets import IterableDataset
from transformers import TrainerCallback
from transformers.integrations import WandbCallback
from transformers.trainer_pt_utils import IterableDatasetShard

# dataset_to_records, decode_predictions, compute_measures, load_samples_dataset, and
# compute_spectrograms are helper functions defined in the project notebooks.


# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch
class ShuffleCallback(TrainerCallback):
    def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):
        if isinstance(train_dataloader.dataset, IterableDatasetShard):
            pass  # set_epoch() is handled by the Trainer
        elif isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)


class WandbProgressResultsCallback(WandbCallback):
    def __init__(self, trainer, sample_dataset):
        super().__init__()
        self.trainer = trainer
        self.sample_dataset = sample_dataset
        self.records_df = dataset_to_records(sample_dataset)

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        super().on_log(args, state, control, model, logs)
        predictions = self.trainer.predict(self.sample_dataset)
        predictions = decode_predictions(self.trainer, predictions)
        measures_df = compute_measures(predictions, self.records_df["sentence"].tolist())
        records_df = pd.concat([self.records_df, measures_df], axis=1)
        records_df["prediction"] = predictions
        records_df["step"] = state.global_step
        records_table = self._wandb.Table(dataframe=records_df)
        self._wandb.log({"sample_predictions": records_table})

    def on_save(self, args, state, control, model=None, tokenizer=None, **kwargs):
        if self._wandb is None:
            return
        if self._log_model and self._initialized and state.is_world_process_zero:
            with tempfile.TemporaryDirectory() as temp_dir:
                self.trainer.save_model(temp_dir)
                metadata = (
                    {
                        k: v
                        for k, v in dict(self._wandb.summary).items()
                        if isinstance(v, numbers.Number) and not k.startswith("_")
                    }
                    if not args.load_best_model_at_end
                    else {
                        f"eval/{args.metric_for_best_model}": state.best_metric,
                        "train/total_floss": state.total_flos,
                    }
                )
                artifact = self._wandb.Artifact(name=f"model-{self._wandb.run.id}", type="model", metadata=metadata)
                for f in Path(temp_dir).glob("*"):
                    if f.is_file():
                        with artifact.new_file(f.name, mode="wb") as fa:
                            fa.write(f.read_bytes())
                self._wandb.run.log_artifact(artifact)


# pass a sample from the test dataset to the callback
samples_dataset = load_samples_dataset(dataset_dict["test"]).map(compute_spectrograms)

# create the callback
progress_callback = WandbProgressResultsCallback(trainer, samples_dataset)

# add the callback to the trainer
trainer.add_callback(progress_callback)
```
In addition to logging training progress, the callback also generates visualizations and saves model artifacts. These artifacts can be used to resume model training whenever a run crashes. Here's a sample artifact:
(W&B model artifact panel: parambharat/whisper_finetuning/model-2k10w4qq:v5, aliases latest and v5, 10 files, 151.1 MB, created December 9th, 2022.)
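Here's a hedged sketch of how one of these artifacts could be pulled down after a crash. The artifact path matches the panel above, but the project name passed to wandb.init and the idea of reloading the weights into a fresh trainer are assumptions about how you'd wire it up.

```python
import wandb
from transformers import WhisperForConditionalGeneration

run = wandb.init(project="whisper_finetuning", resume="allow")
# download the model files saved by the custom on_save callback
artifact = run.use_artifact("parambharat/whisper_finetuning/model-2k10w4qq:v5", type="model")
checkpoint_dir = artifact.download()
# reload the weights and rebuild the trainer around this model to continue training
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_dir)
```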
Finally, we can train the model by simply calling the train method of the trainer:
trainer.train()
This outputs the training progress, including a table of the training loss, evaluation loss, and the evaluation WER metrics. Here's an example from the whisper-tiny-telugu model:

Training progress of the whisper-tiny-te model
Evaluation
We evaluate each model on the test sets of the mozilla-foundation/common_voice_11 and google/fleurs datasets. The full code for evaluation can be seen in the run_streaming_evaluation.py file.
To run the evaluations, we create a transformers.AutomaticSpeechRecognitionPipeline and calculate the word error rate (WER) over each full test set. Here's the main code snippet for this process. For the full code, please refer to the evaluation script provided by the Whisper community event organizers here:
```python
batch_size = args.batch_size
whisper_asr = pipeline("automatic-speech-recognition", model=args.model_id, device=args.device)
whisper_asr.model.config.forced_decoder_ids = whisper_asr.tokenizer.get_decoder_prompt_ids(
    language=args.language, task="transcribe"
)

dataset = load_dataset(
    args.dataset,
    args.config,
    split=args.split,
    streaming=False,
    use_auth_token=True,
)
dataset = dataset.cast_column("audio", Audio(sampling_rate=16000))
dataset = dataset.map(normalise)
dataset = dataset.filter(is_target_text_in_range, input_columns=["norm_text"])

predictions = []
references = []

# run streamed inference
for out in tqdm(whisper_asr(data(dataset), batch_size=batch_size)):
    predictions.append(whisper_norm(out["text"]))
    references.append(out["reference"][0])

wer = wer_metric.compute(references=references, predictions=predictions)
wer = round(100 * wer, 2)
print("WER:", wer)
```
Released Models
Since I trained multiple models across multiple machines, I created separate model training notebooks, each generating its own logs. The following table contains the links to each notebook, the corresponding training logs, and the models:
Training Curves
The interactive Weights & Biases dashboards linked in the table above are the best way to view all the training and validation logs. However, for ease of viewing, I also include the training and validation plots of each model, along with the sample prediction tables logged by the custom callback:
Tamil
Whisper Tiny
Whisper Base
Whisper Small
Malayalam
Whisper Tiny
Whisper Base
Whisper Small
Telugu
Whisper Tiny
Whisper Base
Whisper Small
Kannada
Whisper Tiny
Whisper Base
Whisper Small
Evaluation Results
As discussed earlier in the evaluation section, we report evaluation results on the mozilla-foundation/common_voice_11 and google/fleurs test sets. These results are also published along with the model cards. The following table reports the evaluation of each model on the test sets.
(While Tamil and Malayalam are available in both common_voice and google/fleurs datasets, Telugu and Kannada are only available in the google/fleurs dataset.)
Language | Model | Common Voice(WER) | Fleurs(WER) |
---|---|---|---|
Tamil | https://huggingface.co/parambharat/whisper-tiny-ta | 30.103 | 26.07 |
Tamil | https://huggingface.co/parambharat/whisper-base-ta | 15.780 | 20.410 |
Tamil | https://huggingface.co/parambharat/whisper-small-ta | 11.150 | 15.800 |
Malayalam | https://huggingface.co/parambharat/whisper-tiny-ml | 45.720 | 62.150 |
Malayalam | https://huggingface.co/parambharat/whisper-base-ml | 34.160 | 53.290 |
Malayalam | https://huggingface.co/parambharat/whisper-small-ml | 25.800 | 48.160 |
Telugu | https://huggingface.co/parambharat/whisper-tiny-te | N/A | 52.670 |
Telugu | https://huggingface.co/parambharat/whisper-base-te | N/A | 39.090 |
Telugu | https://huggingface.co/parambharat/whisper-small-te | N/A | 30.260 |
Kannada | https://huggingface.co/parambharat/whisper-tiny-kn | N/A | 43.700 |
Kannada | https://huggingface.co/parambharat/whisper-base-kn | N/A | 30.260 |
Kannada | https://huggingface.co/parambharat/whisper-small-kn | N/A | 25.540 |
Demonstrations
I provide an inference notebook where running the code initializes a Gradio application to interactively generate transcripts. These applications are also publicly available via Hugging Face Spaces; see the table below.
A web app for each model developed in this project is released as a Space on the Hugging Face Hub. The apps can transcribe audio from the microphone and from YouTube. Below are interactive widgets that you can try.
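The Spaces are built with Gradio; a minimal sketch of such an app (using one of the released checkpoints, with microphone input only and no YouTube handling) looks roughly like this:

```python
import gradio as gr
from transformers import pipeline

# load one of the fine-tuned checkpoints released by this project
asr = pipeline("automatic-speech-recognition", model="parambharat/whisper-small-ta", chunk_length_s=30)


def transcribe(audio_path):
    # gradio passes the recorded audio as a filepath when type="filepath"
    return asr(audio_path)["text"]


demo = gr.Interface(
    fn=transcribe,
    inputs=gr.Audio(source="microphone", type="filepath"),
    outputs="text",
    title="Tamil Whisper ASR",
)

if __name__ == "__main__":
    demo.launch()
```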
Summary
This project was a great opportunity for me to fine-tune automatic speech recognition models for low-resource languages using Whisper. I went through the entire process, from data sourcing, collection, cleaning, and preprocessing to model training, evaluation, and even deployment.
In the process, I was able to release four new datasets and twelve new models for languages that have often been overlooked by the research community, mainly due to the lack of large quantities of annotated data. This has left speakers of these languages without access to new technologies enabled by deep learning that are often readily available to speakers of more widely spoken languages.
This project demonstrates that the gap can be bridged with the right set of tools and technologies. In addition to training deep learning models, I was able to make a meaningful contribution to my native language, Tamil, by releasing open-source speech recognition models and datasets that can be easily accessed by its speakers. I was also able to release models for other languages of my country's region and neighboring states, including Kannada, Telugu, and Malayalam.