
An Introduction to Datasets and DataLoader in PyTorch

A tutorial covering how to write Datasets and DataLoaders in PyTorch, complete with code and interactive visualizations.
In this article, we'll go through the PyTorch data primitives, namely torch.utils.data.DataLoader and torch.utils.data.Dataset, understand how the pre-loaded datasets work, and learn how to create our own DataLoaders and Datasets by subclassing these classes. We'll also use Weights & Biases to log metrics and data.
Let's get going!

Why Write Good Data Loaders and Datasets?

Your training pipeline should be as modular as possible in order to aid quick prototyping and maintain usability. Using a poorly written data loader, or skipping the data loader entirely in favor of a plain Python generator or function, can hurt your code's ability to load data in parallel.
Dataset processing is a crucial part of any training pipeline and should be kept separate from modeling.
The same technique won't work everywhere. Some problems might require image augmentations, so you'd prefer to have an argument (something like data = Dataset(..., fetch = True)) to toggle that behavior and test the model's performance. Or you might need to experiment with different sequence lengths and strides when fine-tuning an NLP model. To these ends, it's recommended to write custom Datasets and DataLoaders.

The Basic PyTorch Dataset Structure

The following code snippet outlines the structure of the Dataset class from PyTorch. All pre-loaded Datasets inherit this basic structure.
class Dataset(...):

    # Raises NotImplementedError
    def __getitem__(self, index):
        ...

    # Allows us to add/concatenate Datasets
    def __add__(self, other):
        ...

    # Returns the attribute value or raises an AttributeError
    def __getattr__(self, attribute_name):
        ...

    # Utility method to "register" functions
    @classmethod
    def register_function(cls, ...):
        ...

    # Utility method to "register" DataPipes as functions
    @classmethod
    def register_datapipe_as_function(cls, ...):
        ...
As it has such a simple structure, you don't always need to inherit from torch.utils.data.Dataset: in most cases, it's enough to write an object with the key methods (__getitem__ and __len__) yourself.
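For example, here's a minimal sketch (the dataset itself is a made-up example) showing that any object implementing __getitem__ and __len__ can be handed straight to a DataLoader:

import torch
from torch.utils.data import DataLoader

# A hypothetical map-style dataset that never inherits from torch.utils.data.Dataset
class SquaresDataset:
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        # Return the index and its square as a single tensor
        return torch.tensor([idx, idx ** 2], dtype=torch.float)

loader = DataLoader(SquaresDataset(), batch_size=10)
print(next(iter(loader)).shape)  # torch.Size([10, 2])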

Implementing A Custom Dataset In PyTorch

Now, for most purposes, you will need to write your own implementation of a Dataset. So let's see how you can write a custom dataset by subclassing torch.utils.data.Dataset.
You'll need to implement 3 functions:
  1. __init__: This function is called when instantiating the object. It's typically used to store essentials such as file paths and image transforms.
  2. __len__: This function returns the length of the dataset.
  3. __getitem__ : This is the big kahuna 🏅. This function is responsible for returning a sample from the dataset based on the index provided.
class CustomDataset(torch.utils.data.Dataset):

    # Basic instantiation
    def __init__(self, ..., *args, **kwargs):
        ...

    # Length of the dataset
    def __len__(self):
        ...

    # Fetch an item from the dataset
    def __getitem__(self, idx):
        ...
Let's walk through some examples of Custom Datasets.



The Flickr Dataset

This code snippet is taken from my Kaggle Kernel on Neural Image Captioning. Let's walk through the code:
  1. The __init__ method stores the data frame containing the image paths and captions, along with a transforms variable holding the image augmentations.
  2. The __len__ method returns the length of the data frame. (pandas DataFrames support Python's built-in len function.)
  3. The __getitem__ method reads the image using PIL, applies the needed transforms, encodes the caption, and returns a dictionary with custom keys.
class FlickrDataset(Dataset):
    def __init__(self, df, transforms=None):
        self.df = df
        # Fall back to a default pipeline if no transforms are provided
        self.transforms = transforms if transforms is not None else T.Compose([
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5]),
            T.Resize((256, 256)),
        ])

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, idx: int):
        # Read the image and apply the transforms
        image_id = self.df.image_name.values[idx]
        image = Image.open(image_id).convert('RGB')
        if self.transforms is not None:
            image = self.transforms(image)

        # Fetch the caption for this image and tokenize it
        # (tokenizer and device are defined earlier in the kernel)
        comments = self.df[self.df.image_name == image_id].values.tolist()[0][1:][0]
        encoded_inputs = tokenizer(comments,
                                   return_token_type_ids=False,
                                   return_attention_mask=False,
                                   max_length=100,
                                   padding="max_length",
                                   return_tensors="pt")

        sample = {"image": image.to(device),
                  "captions": encoded_inputs["input_ids"].flatten().to(device)}

        return sample

RSNA Brain Tumor Competition Dataset

This code snippet is taken from my Custom Wrapper for the RSNA-MICCAI Brain Tumor Radiogenomic Classification Kaggle Competition. Let's walk through the code:
  1. The __init__ method contains a reference to the paths, targets, the MRI type, and other such information.
  2. The __len__ method returns the number of scan paths (the length of the dataset).
  3. The __getitem__ method reads the image using a custom function and returns a custom dictionary containing augmented images and targets (if needed).
class Dataset(torch_data.Dataset):
    def __init__(
        self,
        paths,
        targets=None,
        mri_type=None,
        label_smoothing: float = 0.01,
        split: str = "train",
        augment: bool = False,
    ):
        self.paths = paths
        self.targets = targets
        self.mri_type = mri_type
        self.label_smoothing = label_smoothing
        self.split = split
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        scan_id = self.paths[index]
        # load_dicom_images_3d is a helper function defined elsewhere in the wrapper
        if self.targets is None:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split=self.split
            )
        else:
            data = load_dicom_images_3d(
                str(scan_id).zfill(5), mri_type=self.mri_type[index], split="train"
            )

        # seq is an image-augmentation pipeline defined elsewhere in the wrapper
        if self.augment:
            data = seq(images=data)

        if self.targets is None:
            return {"X": torch.tensor(data).float(), "id": scan_id}
        else:
            y = torch.tensor(
                abs(self.targets[index] - self.label_smoothing), dtype=torch.float
            )
            return {"X": torch.tensor(data).float(), "y": y}



Best Practices For Creating Custom Datasets

There are some general things you need to remember while creating custom datasets.
  • The indices your dataset handles should range from 0 up to one less than the length returned by the __len__ function; indexing outside that range will throw an error.
  • You should override the __len__ function, because many Sampler implementations and the default options of DataLoader expect it to return the size of the dataset. (Reference: PyTorch docs)
  • In case you're working with data that comes from a stream, subclass torch.utils.data.IterableDataset instead; a minimal sketch follows below. For more information, refer to the docs.
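Here's a minimal sketch of such a streaming dataset (the class and its data source are made up for illustration), where __iter__ takes the place of __getitem__ and __len__:

import torch
from torch.utils.data import IterableDataset, DataLoader

class StreamDataset(IterableDataset):
    def __init__(self, source):
        # source can be any iterable, e.g. a generator reading lines from a file
        self.source = source

    def __iter__(self):
        for item in self.source:
            yield torch.tensor(item, dtype=torch.float)

# The DataLoader still batches the streamed items for us
loader = DataLoader(StreamDataset(range(10)), batch_size=4)
for batch in loader:
    print(batch)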

The Basic PyTorch DataLoader Class Structure

The following code snippet outlines the structure of the DataLoader class from PyTorch.
class DataLoader(...):

    # Basic __init__ function
    def __init__(self, ...):
        ...

    # Returns either a single- or a multi-process iterator
    def _get_iterator(self):
        ...

    # Handle multiprocessing
    @property
    def multiprocessing_context(self):
        ...

    # Handle multiprocessing
    @multiprocessing_context.setter
    def multiprocessing_context(self, multiprocessing_context):
        ...

    # Override the default __setattr__ method
    def __setattr__(self, attr, val):
        ...

    # Override the default __iter__ method
    def __iter__(self):
        ...

    # Helper property for collation
    @property
    def _auto_collation(self):
        ...

    # The actual sampler used for fetching
    @property
    def _index_sampler(self):
        ...

    # Returns the length of the index sampler (in the case of a map-style dataset)
    def __len__(self) -> int:
        ...

    # Checks whether the worker number is rational given system resources
    def check_worker_number_rationality(self):
        ...
Now, this does look complicated 🧐, but in most cases, you don't need to understand most of it. Still, it's nice to know how PyTorch takes care of multiprocessing and handles different kinds of iterators under the hood.
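For instance, the switch between a single-process and a multi-process data-loading iterator is controlled entirely by the num_workers argument. Here's a quick sketch (the dataset is just dummy random data):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy dataset: 1,000 random "images" with integer labels
dataset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

# num_workers > 0 makes the DataLoader spawn worker processes for loading;
# pin_memory=True can speed up host-to-GPU transfers.
# (On Windows/macOS, iterate the loader inside an `if __name__ == "__main__":` guard when num_workers > 0.)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4, pin_memory=True)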

Example: Creating A Data Loader From A Dataset

Most pre-loaded datasets from Torchvision are torch.utils.data.Dataset objects, so we can feed them directly into the torch.utils.data.DataLoader class and then enumerate through the mini-batches in our training loop.
For example, this code snippet from the PyTorch tutorials shows how easily we can create data loaders using pre-loaded datasets from torchvision.
from torchvision import datasets
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
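Continuing from the snippet above, we can then enumerate the DataLoader inside a training loop; each iteration yields a mini-batch of images and labels:

# Each iteration yields a batch of images and their labels
for batch_idx, (images, labels) in enumerate(train_dataloader):
    # images: (batch_size, 1, 28, 28), labels: (batch_size,)
    if batch_idx == 0:
        print(images.shape, labels.shape)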


Using Custom Samplers For More Control Over Data Loading

The aforementioned code example returns mini-batches of data with the provided batch size.
For even more control over your data loading, use a custom Sampler. Every Sampler subclass must implement an __iter__ method, which yields the indices to fetch, and a __len__ method.
For more information, refer to the PyTorch docs.
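As a quick illustration, here's a minimal custom sampler (a hypothetical example that reuses training_data from the FashionMNIST snippet above) which walks the dataset in reverse order:

from torch.utils.data import Sampler

# Yields indices from the last sample down to the first
class ReverseSampler(Sampler):
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        return iter(range(len(self.data_source) - 1, -1, -1))

    def __len__(self):
        return len(self.data_source)

# Pass the sampler to the DataLoader (note: a sampler can't be combined with shuffle=True)
reverse_loader = DataLoader(training_data, batch_size=64, sampler=ReverseSampler(training_data))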


Helpful Dataset And DataLoader Resources

  1. DataLoader architecture RFC (For more advanced users)