An Introduction to Datasets and DataLoader in PyTorch
A tutorial covering how to write Datasets and DataLoader in PyTorch, complete with code and interactive visualizations.
In this article, we'll go through the PyTorch data primitives, namely torch.utils.data.DataLoader and torch.utils.data.Dataset, and understand how the pre-loaded datasets work and how to create our own DataLoader and Datasets by subclassing these modules. We'll also use Weights & Biases to log metrics and data.
Here's what we'll be covering:
Table of Contents
Why Write Good Data Loaders and Datasets?
The Basic PyTorch Dataset Structure
Implementing A Custom Dataset In PyTorch
The Flickr Dataset
RSNA Brain Tumor Competition Dataset
Best Practices For Creating Custom Datasets
The Basic PyTorch DataLoader Class Structure
Example: Creating A Data Loader From A Dataset
Using Custom Samplers For More Control Over Data Loading
Helpful Dataset And DataLoader Resources
Let's get going!
Why Write Good Data Loaders and Datasets?
Your training pipeline should be as modular as possible to aid quick prototyping and maintain usability. Using a poorly written data loader, or skipping one entirely in favor of a plain Python generator or function, can limit how well your code parallelizes data loading.
Dataset processing is a highly important part of any training pipeline and should be kept separate from modeling.
The same technique won't work for every problem. Some problems might require you to use image augmentations, so you'd prefer to have an argument (something like data = Dataset(..., fetch = True)) to toggle behavior while testing the model's performance. Or you might need to experiment with different sequence lengths and strides when fine-tuning an NLP model. To these ends, it's recommended to use custom Datasets and DataLoaders, as in the sketch below.
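Here is a minimal sketch of what such an argument might look like. The ConfigurableDataset class and its placeholder augmentation are hypothetical, not code from this article; the point is simply that one constructor flag lets you reuse the same class across experiments.

import torch
from torch.utils.data import Dataset

class ConfigurableDataset(Dataset):
    def __init__(self, samples, augment: bool = False):
        self.samples = samples    # e.g. a list of (input, target) tensor pairs
        self.augment = augment    # toggle augmentations per experiment

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        if self.augment:
            x = x + 0.01 * torch.randn_like(x)    # placeholder augmentation
        return x, y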
The Basic PyTorch Dataset Structure
The following code snippet contains the original implementation of the Dataset class from PyTorch. All pre-loaded Datasets inherit this basic structure.
class Dataset(...):
    # Raises NotImplementedError
    def __getitem__(self, index):
        ...

    # Allows us to Add/Concat Datasets
    def __add__(self, other):
        ...

    # Returns the attribute value or raises an AttributeError
    def __getattr__(self, attribute_name):
        ...

    # Utility method to "register" functions
    @classmethod
    def register_function(cls, ...):
        ...

    # Utility method to "register" functions as datapipes
    @classmethod
    def register_datapipe_as_function(cls, ...):
        ...
Since the structure is so simple, you don't always need to inherit from torch.utils.data.Dataset. In most cases, it's enough to implement a few key functions.
Implementing A Custom Dataset In PyTorch
Now, for most purposes, you will need to write your own implementation of a Dataset. So let's see how you can write a custom dataset by subclassing torch.utils.data.Dataset.
You'll need to implement 3 functions:
- __init__: This function is called when instantiating the object. It's typically used to store essentials like file paths and image transforms.
- __len__: This function returns the length of the dataset.
- __getitem__ : This is the big kahuna 🏅. This function is responsible for returning a sample from the dataset based on the index provided.
class CustomDataset(torch.utils.data.Dataset):
    # Basic instantiation
    def __init__(self, ..., *args, **kwargs):
        ...

    # Length of the dataset
    def __len__(self):
        ...

    # Fetch an item from the dataset
    def __getitem__(self, idx):
        ...
Let's walk through some examples of Custom Datasets.
The Flickr Dataset
This code snippet is taken from my Kaggle Kernel on Neural Image Captioning. Let's walk through the code:
- The __init__ method stores a reference to the data frame (which holds the image paths and captions) and a transforms variable holding the image augmentations.
- The __len__ method returns the length of the data frame. (pandas DataFrames implement the built-in len function.)
- The __getitem__ method reads the image using PIL, applies the transforms needed, encodes the comments, and returns a dictionary with custom key values.
class FlickrDataset(Dataset):
    def __init__(self, df, transforms):
        self.df = df
        self.transforms = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5]),
            T.Resize((256, 256)),
        ])

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        image_id = self.df.image_name.values[idx]
        image = Image.open(image_id).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)

        comments = self.df[self.df.image_name == image_id].values.tolist()[0][1:][0]
        encoded_inputs = tokenizer(
            comments,
            return_token_type_ids=False,
            return_attention_mask=False,
            max_length=100,
            padding="max_length",
            return_tensors="pt",
        )

        sample = {
            "image": image.to(device),
            "captions": encoded_inputs["input_ids"].flatten().to(device),
        }
        return sample
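To use the class, you would wrap it in a DataLoader. The usage sketch below is illustrative rather than part of the original kernel, and it assumes that df (a data frame with an image_name column plus captions), the tokenizer, and device are already defined as in the kernel.

from torch.utils.data import DataLoader

flickr_ds = FlickrDataset(df, transforms=None)    # the class builds its own transforms internally
flickr_dl = DataLoader(flickr_ds, batch_size=32, shuffle=True)

batch = next(iter(flickr_dl))
print(batch["image"].shape)       # e.g. torch.Size([32, 3, 256, 256])
print(batch["captions"].shape)    # e.g. torch.Size([32, 100])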
RSNA Brain Tumor Competition Dataset
This code snippet is taken from my Custom Wrapper for the RSNA-MICCAI Brain Tumor Radiogenomic Classification Kaggle Competition. Let's walk through the code:
- The __init__ method contains a reference to the paths, targets, the MRI type, and other such information.
- The __len__ method returns the number of paths, i.e. the number of scans in the dataset.
- The __getitem__ method reads the image using a custom function and returns a custom dictionary containing augmented images and targets (if needed).
class Dataset(torch_data.Dataset):
    def __init__(
        self,
        paths,
        targets=None,
        mri_type=None,
        label_smoothing: float = 0.01,
        split: str = "train",
        augment: bool = False,
    ):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        if self.targets is None:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split
            )
        else:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train"
            )

        if self.augment:
            data = seq(images=data)

        if self.targets is None:
            return {"X": torch.tensor(data).float(), "id": scan_id}
        else:
            y = torch.tensor(
                abs(self.targets[index] - self.label_smoothing), dtype=torch.float
            )
            return {"X": torch.tensor(data).float(), "y": y}
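A hypothetical usage sketch follows. The train_df variable and its column names are assumptions made for illustration, and load_dicom_images_3d plus the seq augmentation pipeline are defined elsewhere in the wrapper.

from torch.utils import data as torch_data

train_ds = Dataset(
    paths=train_df["BraTS21ID"].values,       # scan IDs, zero-padded inside __getitem__
    targets=train_df["MGMT_value"].values,    # labels; pass targets=None at inference time
    mri_type=["FLAIR"] * len(train_df),       # one MRI type per scan
    augment=True,
)
train_dl = torch_data.DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=2)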
Best Practices For Creating Custom Datasets
There are some general things you need to remember while creating custom datasets.
- The indices passed to your dataset should stay within the range implied by the __len__ function (0 to len - 1). Requesting an index outside this range will throw an error.
- You should override the __len__ method, since many Sampler implementations and the default options of DataLoader expect it to return the size of the dataset. (Reference: PyTorch docs)
- If you're working with data that arrives as a stream, subclass IterableDataset instead; a minimal sketch follows this list. For more information, refer to the docs.
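Here is a minimal IterableDataset sketch. The in-memory range standing in for a real stream is purely for illustration.

import torch
from torch.utils.data import IterableDataset, DataLoader

class StreamDataset(IterableDataset):
    def __init__(self, stream):
        self.stream = stream    # any iterable: lines from a file, a socket, a generator, ...

    def __iter__(self):
        # Samples are yielded one by one; there is no __getitem__ or __len__
        for item in self.stream:
            yield torch.tensor(item, dtype=torch.float)

loader = DataLoader(StreamDataset(range(10)), batch_size=4)
for batch in loader:
    print(batch)    # tensors of shape (4,), (4,), (2,)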
The Basic PyTorch DataLoader Class Structure
The following code snippet contains the original implementation of the DataLoader class from PyTorch.
class DataLoader(...):
    # Basic __init__ function
    def __init__(self, ...):
        ...

    # Returns either a single- or a multi-process iterator
    def _get_iterator(self):
        ...

    # Handle multiprocessing
    @property
    def multiprocessing_context(self):
        ...

    # Handle multiprocessing
    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        ...

    # Override the default __setattr__ method
    def __setattr__(self, attr, val):
        ...

    # Override the default __iter__ method
    def __iter__(self):
        ...

    # Helper property for collation
    @property
    def _auto_collation(self):
        ...

    # The actual sampler used for fetching
    @property
    def _index_sampler(self):
        ...

    # Returns the length of the index sampler (in the case of a map-style dataset)
    def __len__(self) -> int:
        ...

    # Checks whether the worker number is rational given system resources
    def check_worker_number_rationality(self):
        ...
Now, this does look complicated 🧐, but in most cases you don't need to understand most of it. Still, it's nice to know how PyTorch takes care of multiprocessing and handles the different types of iterators for you.
Example: Creating A Data Loader From A Dataset
Most pre-loaded datasets from Torchvision are torch.utils.data.Dataset objects, so we can feed them directly into the torch.utils.data.DataLoader class and then enumerate over the loader in our training loop.
For example, this code snippet from the PyTorch tutorials shows how easily we can create data loaders using pre-loaded datasets from torchvision.
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
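Once the loaders are created, enumerating over them in a training loop yields one mini-batch at a time. The loop below is only a skeleton; the model, loss, and optimizer step are left out.

for epoch in range(2):
    for batch_idx, (images, labels) in enumerate(train_dataloader):
        # For full batches of FashionMNIST: images is (64, 1, 28, 28), labels is (64,)
        ...    # forward pass, loss computation, backward pass, optimizer step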
Using Custom Samplers For More Control Over Data Loading
The aforementioned code example returns mini-batches of data with the provided batch size.
For even more control over your data loading, use custom Samplers. Every Sampler subclass must provide an __iter__ method that yields dataset indices, and typically a __len__ method as well; an example follows below.
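As a concrete (if contrived) example, here is a sampler that yields indices in reverse order, so the DataLoader fetches samples from last to first. ReverseSampler is not a built-in PyTorch class, and the snippet reuses training_data from the FashionMNIST example above.

from torch.utils.data import Sampler, DataLoader

class ReverseSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Yield indices from the last sample down to the first
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

# When a sampler is supplied, shuffle must be left at its default (False)
reverse_loader = DataLoader(training_data, batch_size=64, sampler=ReverseSampler(training_data))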
Helpful Dataset And DataLoader Resources