
How the TorchData API Works: a Tutorial with Code

Let's explore the new way of building datasets in the latest PyTorch 1.11 with TorchData.
PyTorch 1.11 was released last week, and it adds a few important new functionalities to our machine learning toolbox:
  • TorchData: a library that enables building datasets with reusable building blocks called DataPipes
  • functorch: a JAX-like composable function API
For an intro to the other new PyTorch feature, functorch, check this report. Let's take the new TorchData for a spin.
💡 Check the official documentation of torchdata here.

What Is TorchData? 🧐

From PyTorch website:
torchdata is a library of common modular data loading primitives for easily constructing flexible and performant data pipelines.
Suppose we want to load data from folders containing image files with the following steps:
List all files in a directory ➡ Filter the image files ➡ Find the labels for each image
There are a few built-in DataPipes that can help us do exactly that (a minimal generic sketch follows this list):
  • FileLister - lists out files in a directory
  • Filter - filters the elements in a DataPipe based on a given function
  • FileOpener - consumes file paths and returns opened file streams
  • Mapper - applies a function over each item from the source DataPipe
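Before the real example, here is a minimal sketch of how three of these pieces chain together (FileOpener comes in when we need file streams rather than paths). The folder name and the labelling rule are hypothetical placeholders, not part of the CamVid example below:

import torchdata.datapipes.iter as pipes
from pathlib import Path

root = "some_dataset"  # hypothetical folder of images

dp = pipes.FileLister([root])                                # 1. list all files in the directory
dp = pipes.Filter(dp, lambda p: p.endswith(".png"))          # 2. keep only the image files
dp = pipes.Mapper(dp, lambda p: (p, Path(p).parent.name))    # 3. label each image, here by its parent folder name

for path, label in dp:
    ...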
✅ The code for this report can be found here: https://github.com/tcapelle/torchdata
Let's see how this looks in action:

A Computer Vision Example With The CamVid Dataset 🚗

We are using the Cambridge-driving Labeled Video Database or CamVid. It contains a collection of videos with object class semantic labels, complete with metadata. The database provides ground truth labels that associate each pixel with one of 32 semantic classes.
💡 If you want to know how to train a model on this dataset, check this report.
Let's grab CamVid from the fastai S3 bucket:
$ wget "https://s3.amazonaws.com/fast-ai-imagelocal/camvid.tgz"
$ tar -xf camvid.tgz  # x: extract, f: file
The decompressed folder looks like this:
$tree camvid
camvid
├── codes.txt
├── images
│   ├── 0001TP_006690.png
│   ├── 0001TP_006720.png
│   ├── 0001TP_006750.png
│   ...
├── labels
│   ├── 0001TP_006690_P.png
│   ├── 0001TP_006720_P.png
│   ├── 0001TP_006750_P.png
│   ...
└── valid.txt

NB: the labels are named with an extra _P before the extension.
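Several snippets below rely on a label_func helper that maps an image path to its mask path. A minimal sketch, reconstructed from the naming convention above (an assumption, not the article's verbatim code):

from pathlib import Path

def label_func(fname):
    "camvid/images/x.png -> camvid/labels/x_P.png"
    fname = Path(fname)
    return fname.parent.parent/"labels"/f"{fname.stem}_P{fname.suffix}"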

The Datasets Way ☕️
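The classic approach is to subclass torch.utils.data.Dataset and index into a list of files. The original class isn't reproduced here, so the snippet below is a representative sketch of what it usually looks like for CamVid (the class name and details are assumptions, using the label_func helper from above):

import torch
from PIL import Image
from pathlib import Path

class CamVidDataset(torch.utils.data.Dataset):
    "A map-style Dataset: store the image paths and open (image, mask) pairs on demand"
    def __init__(self, path):
        self.files = sorted((path/"images").glob("*.png"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        fname = self.files[idx]
        return Image.open(fname), Image.open(label_func(fname))

ds = CamVidDataset(Path("camvid"))
image, mask = ds[0]

This works, but every new dataset means another hand-written class. The DataPipes below compose the same steps out of reusable pieces.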

Using TorchData API 🔥

The new TorchData gives us multiple ways to achieve the above with a lot less fuss:

Option #1

We can mimic the Dataset class defined above. First, we list all .png files inside the images folder:
import torchdata.datapipes.iter as pipes
from pathlib import Path

camvid_path = Path("camvid")
files = pipes.FileLister([camvid_path/"images"], masks="*.png")
...then we derive the corresponding label mask from the filename (the mapping returns a (image, mask) tuple of paths):
def label_func2(fname):
    "Same as before, but returns (image, mask)"
    fname = Path(fname)
    return fname, label_func(fname)

labelled_files = pipes.Mapper(files, label_func2)
...and finally, we open both the image and the mask:
from PIL import Image

def PIL_open(data):
    "Open the (image, mask) pair of paths as PIL Images"
    return Image.open(data[0]), Image.open(data[1])

ds = pipes.Mapper(labelled_files, PIL_open)

# show_images is a helper function from fastai
show_images(next(iter(ds)))

This pipeline can also be written in a functional form (the preferred way):
ds = (pipes.FileLister([camvid_path/"images"], masks="*.png")
        .map(label_func2)
        .map(PIL_open))

Option #2

Alternatively, we can put two pipelines together using Zipper:
images = pipes.FileLister([camvid_path/"images"]).map(Image.open)
labels = pipes.FileLister([camvid_path/"labels"]).map(Image.open)

ds = pipes.Zipper(images, labels).shuffle().batch(3) # we can even shuffle and batch the outputs

batch = next(iter(ds))
for im, mk in batch:
    show_images([im, mk])


Option #3

Lastly, a more delicate option: reading directly from inside the .tar archive without decompressing it first.
camvid_itertable = pipes.IterableWrapper(["camvid.tgz"])
files = pipes.FileOpener(camvid_itertable, mode="b").load_from_tar()  # could not find a functional form of FileOpener

# imagehandler ships with the core datapipes decoder utilities
from torch.utils.data.datapipes.utils.decoder import imagehandler

# separate pipelines for images and masks
images = files.filter(lambda tup: Path(tup[0]).parent.name == "images")  # maybe there is a better way to do this?
images = images.routed_decode(imagehandler("pil"))

labels = files.filter(lambda tup: Path(tup[0]).parent.name == "labels")
labels = labels.routed_decode(imagehandler("pill"))  # "pill" decodes to a grayscale (mode "L") PIL image

# merge them together on id (the filename)
def get_image_id(data):
    path, _ = data
    return Path(path).name.split(".")[0]

def get_label_id(data):
    path, _ = data
    return Path(path).name.split("_P")[0]

ds = pipes.IterKeyZipper(images,
                         labels,
                         key_fn=get_image_id,      # function to get the id from images
                         ref_key_fn=get_label_id,  # function to get the id from masks
                         )  # this is utterly slow, probably doing something wrong

Visualizing your data: Logging images into a wandb.Table

To get insights into your data, you can use the power of wandb.Table. It's like a DataFrame but with rich media support: you can filter, sort, and preview your images along with your segmentation masks dynamically.
import numpy as np
import wandb

# class_labels maps integer ids to class names;
# here we assume camvid/codes.txt holds one class name per line
class_labels = dict(enumerate(Path("camvid/codes.txt").read_text().splitlines()))

dp_batched = ds.shuffle().batch(24)

def to_wandb_image(image, mask, class_labels=class_labels):
    "Cast PIL images to wandb.Image"
    wandb_image = wandb.Image(image)
    wandb_mask = wandb.Image(image, masks={"predictions":
                                           {"mask_data": np.array(mask),
                                            "class_labels": class_labels}})
    return wandb_image, wandb_mask

def create_table(samples):
    "Create a table with (Images, Masks)"
    table_data = []
    for img, mask in samples:
        table_data.append(to_wandb_image(img, mask))
    return wandb.Table(data=table_data,
                       columns=["Images", "Segmentation_Masks"])

samples = next(iter(dp_batched))
table = create_table(samples)

import wandb

# log to wandb
wandb.login()

with wandb.init(project="torchdata", job_type="log_dataset"):
    wandb.log({"sample_batch": table})



A Simple Timeseries Example 📈

Let's grab a Kaggle dataset containing stock prices for $AAPL (Apple Inc.) over the last 10 years:
import pandas as pd

df = pd.read_csv("HistoricalQuotes.csv")
df.head()


We will use the Close/Last price and create a pipeline capable of generating a rolling window.

Let's build the TorchData pipeline:
datapipe = pipes.IterableWrapper(["HistoricalQuotes.csv"])

# we skip the header row
csv = pipes.FileOpener(datapipe, mode='rt').parse_csv(delimiter=',', skip_lines=1)
Now, we have to define a way to convert the price to a numerical value:
def parse_price(dp):
    date, close, vol, open, high, low = dp
    return float(close.strip().replace("$", ""))

# then we map this function to the pipeline
prices = csv.map(parse_price)
Next, we need to define a custom pipeline to generate the rolling window. We can do it like this:
import numpy as np
from torchdata.datapipes import functional_datapipe

# we register the class so it can also be used as the .rolling() functional
@functional_datapipe("rolling")
class RollingWindow(pipes.IterDataPipe):
    def __init__(self, source_dp: pipes.IterDataPipe, window_size, step=1) -> None:
        super().__init__()
        self.source_dp = source_dp
        self.window_size = window_size
        self.step = step

    def __iter__(self):
        it = iter(self.source_dp)
        cur = []
        while True:
            try:
                # fill the window up to window_size elements, then yield it
                while len(cur) < self.window_size:
                    cur.append(next(it))
                yield np.array(cur)
                # slide the window forward by `step` elements
                for _ in range(self.step):
                    if cur:
                        cur.pop(0)
                    else:
                        next(it)
            except StopIteration:
                return
Followed by:
from itertools import islice

dp = RollingWindow(prices, 5, step=2)
it = iter(dp)

# let's grab the first 4 windows
list(islice(it, 4))


Now, let's put everything together:
datapipe = pipes.IterableWrapper(["HistoricalQuotes.csv"])
ds = (pipes.FileOpener(datapipe, mode='rt').parse_csv(delimiter=',', skip_lines=1)
        .map(parse_price)
        .rolling(window_size=5, step=2)
        .batch(4)
      )

next(iter(ds))
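To feed this pipeline into a training loop, the datapipe can be wrapped in a standard DataLoader. A minimal sketch, assuming we keep the .batch(4) step above and therefore disable the DataLoader's own batching:

from torch.utils.data import DataLoader

dl = DataLoader(ds, batch_size=None)  # batch_size=None: the pipeline already yields batches of 4 windows
batch = next(iter(dl))
len(batch)  # 4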



Final Thoughts On The TorchData API

The TorchData API is really fun to use and gives us all kinds of composable blocks for building our datasets. We already loved fastai's DataBlock API, so we are very happy to have something similar in plain PyTorch!
✅ The code for this report can be found here: https://github.com/tcapelle/torchdata


