Metric Learning for Image Search With Weights & Biases
In this article, we will explore supervised metric learning and extend it to image similarity search using Weights & Biases to track the results of our experiments.
Metric learning is a broad field with many definitions. Primarily, it aims to measure the similarity among data samples and to learn embedding models. In a familiar classification setting, we give our model some input x and learn to predict its class.
In the context of metric learning to learn embedding models, the motivation is to embed data samples in an embedding space such that similar samples are close together in that space while dissimilar ones are far apart. We are often not interested in how the embedding space looks, as long as the samples we want to be close together (similar) form a cluster in that space.
Euclidean distance is a popular distance metric. One could argue that, given images, we can represent them as vectors of abstract features using a pre-trained image classifier and compare them with the Euclidean distance between those features. However, most practical data is not linear and requires task- and dataset-specific distance metrics. Metric learning thus aims at automatically constructing task-specific distance metrics, as sketched below.
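To make this concrete, here is a minimal sketch of that pre-trained-features approach; the choice of ResNet50, the input size, and the omission of any preprocessing are assumptions made for brevity:

import numpy as np
import tensorflow as tf

# A hypothetical feature extractor: a ResNet50 pre-trained on ImageNet,
# with global average pooling so each image maps to a single vector.
feature_extractor = tf.keras.applications.ResNet50(
    include_top=False, weights="imagenet", pooling="avg",
    input_shape=(224, 224, 3),
)

def euclidean_distance(image_a, image_b):
    # Each image is assumed to be a (224, 224, 3) float array.
    features = feature_extractor.predict(np.stack([image_a, image_b]))
    return np.linalg.norm(features[0] - features[1])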
The field of metric learning is incredibly important and useful because the distance metric/embeddings learned can be useful for many downstream tasks. In the literature, metric learning is often tied to model pre-training.
Metric learning falls into three categories:
- Supervised learning: The metric learning algorithm has access to a set of data points, each of them belonging to a class (label) as in a standard classification problem. This setting aims to learn distance metrics that bring points with the same label close together.
- Weakly supervised learning: The metric learning algorithm has access to a set of data points with supervision only at the tuple level - typically pairs, triplets, or quadruplets of data points. A Siamese network trained with the triplet loss is a popular example in this setting (see the sketch after this list).
- Unsupervised learning: The metric learning algorithm in this setting has access only to unlabeled data points. In recent times, the contrastive loss has gained much traction for learning state-of-the-art embeddings for downstream tasks. The recent developments in unsupervised visual representation learning can be tied to the success of metric learning.
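For instance, the triplet loss from the weakly supervised setting can be sketched in a few lines; the margin value here is an assumption:

import tensorflow as tf

def triplet_loss(anchor, positive, negative, margin=0.2):
    # Squared Euclidean distances between batches of embeddings.
    pos_dist = tf.reduce_sum(tf.square(anchor - positive), axis=-1)
    neg_dist = tf.reduce_sum(tf.square(anchor - negative), axis=-1)
    # The anchor should be closer to the positive than to the negative
    # by at least the margin; violations contribute to the loss.
    return tf.reduce_mean(tf.maximum(pos_dist - neg_dist + margin, 0.0))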
In this report, we explore supervised metric learning and extend it to image search.
The Dataset
For simplicity, we will be using the CIFAR-10 dataset. There are 10 classes, as listed in CLASS_NAMES.
import numpy as np
from tensorflow.keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)

Figure 1: Samples from the CIFAR-10 dataset.
Dataset Preparation
In our supervised metric learning setting, we do not want explicit (x, y) pairs where y is the label for the corresponding x. Instead, we want multiple instances (pairs) of images that are related such that they express semantic similarity. Thus one training instance will be a pair of images, not a single image. The two images in a pair belong to the same class, as guided by their labels.
When referring to the images in such a pair, we use the common terms anchor (a randomly sampled image from the dataset) and positive (another randomly sampled image of the same class). Thus each training sample consists of an anchor and a positive pair.
The code snippet shown below builds a lookup table that maps each class to the indices of its instances.
from collections import defaultdict

class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):
    class_idx_to_train_idxs[y].append(y_train_idx)

class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):
    class_idx_to_test_idxs[y].append(y_test_idx)
Dataloader
The dataloader will produce batches of (anchor, positive) pairs spread across the classes (10 in our case). Since the goal is to learn a metric that brings images from the same class closer to each other in the embedding space and pushes away images from different classes, each batch will have one (anchor, positive) pair from each class. Thus our batch size will be 10 in the case of CIFAR-10.
import random
import numpy as np
from tensorflow import keras

num_classes = 10
height_width = 32

# Ref: https://keras.io/examples/vision/metric_learning/
class AnchorPositivePairs(keras.utils.Sequence):
    def __init__(self, num_batchs):
        self.num_batchs = num_batchs

    def __len__(self):
        return self.num_batchs

    def __getitem__(self, _idx):
        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            examples_for_class = class_idx_to_train_idxs[class_idx]
            anchor_idx = random.choice(examples_for_class)
            positive_idx = random.choice(examples_for_class)
            # Resample until the positive differs from the anchor.
            while positive_idx == anchor_idx:
                positive_idx = random.choice(examples_for_class)
            x[0, class_idx] = x_train[anchor_idx]
            x[1, class_idx] = x_train[positive_idx]
        return x

Figure 2: One batch generated by AnchorPositivePairs. The first row shows anchor images and the second row shows the randomly chosen positive images.
The Model
Now that our input pipeline is ready, let's build a model architecture suited for the task. Since the objective is to learn a metric that brings together images from the same class in the embedding space, we first pass the anchor image and its positive through the convolutional block (one at a time) to get their feature representations. Each representation is then projected using a linear layer (without activation) and L2-normalized so that we can use simple dot products to measure similarity.
We will use a simple convolutional feature extractor for the sake of simplicity.
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D, Dense, GlobalAveragePooling2D

def get_model():
    inputs = Input(shape=(height_width, height_width, 3))
    x = Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
    x = Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
    x = Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
    x = GlobalAveragePooling2D()(x)
    # The number of units in the projection layer is a hyperparameter.
    embeddings = Dense(units=64, activation=None)(x)
    embeddings = tf.nn.l2_normalize(embeddings, axis=-1)
    return EmbeddingModel(inputs, embeddings)
The EmbeddingModel encapsulates the training logic for our metric learning task. It subclasses keras.Model and overrides train_step:
- First, we get the embeddings for the anchor and positive images.
- Since the embeddings are already normalized (notice tf.nn.l2_normalize in get_model), we can compute the similarity between the anchor and positive images with a pairwise dot product (cosine similarity).
- Since the similarity matrix is used as logits in our case, we need to scale it. The temperature hyperparameter does exactly that.
- The similarity matrix has shape (NUM_CLASS x NUM_CLASS). The diagonal of this matrix corresponds to the similarity between each anchor and its positive. We thus want to maximize the diagonal entries.
- The labels can be [0, 1, 2, ..., NUM_CLASS - 1] if we compile our model with sparse_categorical_crossentropy. This loss moves embeddings for the anchor/positive pairs together and moves all other pairs apart.
# Ref: https://keras.io/examples/vision/metric_learning/
class EmbeddingModel(keras.Model):
    def train_step(self, data):
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through the model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between anchors and positives.
            similarities = tf.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Scale using temperature. Temperature is a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # Compute loss.
            sparse_labels = tf.range(num_classes)
            loss = self.compiled_loss(sparse_labels, similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        self.compiled_metrics.update_state(sparse_labels, similarities)
        return {m.name: m.result() for m in self.metrics}
The SimilarityLogger - Image Query Search
We will use this custom Keras callback to interactively visualize images that are semantically similar. near_neighbours_per_example controls the number of semantically similar examples that will be logged along with each query image, while num_examples_to_log is the number of query images to log. One can see this callback as a naive implementation of image similarity search.
import wandb

class SimilarityLogger(tf.keras.callbacks.Callback):
    def __init__(self,
                 num_samples=1000,
                 near_neighbours_per_example=9,
                 num_examples_to_log=5):
        super(SimilarityLogger, self).__init__()
        self.samples = x_test[:num_samples]
        self.near_neighbours_per_example = near_neighbours_per_example
        self.num_examples_to_log = num_examples_to_log

    def on_epoch_end(self, epoch, logs=None):
        embeddings = self.model.predict(self.samples)
        # Pairwise cosine similarities between all embedded samples.
        gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
        near_neighbours = np.argsort(gram_matrix.T)[
            :, -(self.near_neighbours_per_example + 1):
        ]

        examples = np.empty(
            (
                self.num_examples_to_log,
                self.near_neighbours_per_example + 1,
                height_width,
                height_width,
                3,
            ),
            dtype=np.float32,
        )
        for row_idx in range(self.num_examples_to_log):
            # The first column holds the query image itself.
            examples[row_idx, 0] = x_test[row_idx]
            anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])
            for col_idx, nn_idx in enumerate(anchor_near_neighbours):
                examples[row_idx, col_idx + 1] = x_test[nn_idx]

        for i, example in enumerate(examples):
            wandb.log({
                "query_similarity_{}".format(i): [wandb.Image(img) for img in example]
            })
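Outside the callback, a one-off query against the trained model could be sketched like this; reusing the first 1000 test images as the reference set and the top_k parameter are assumptions:

# Embed the reference images once and reuse the embeddings for queries.
reference_embeddings = model.predict(x_test[:1000])

def find_similar(query_image, top_k=9):
    # The embeddings are L2-normalized, so a dot product equals the
    # cosine similarity between the query and every reference image.
    query_embedding = model.predict(query_image[np.newaxis, ...])[0]
    similarities = reference_embeddings @ query_embedding
    return np.argsort(similarities)[::-1][:top_k]

# Usage: indices of the nine images most similar to the first test image.
# find_similar(x_test[0])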
Experiments and Results
We will perform two sets of experiments:
- Train the metric learning model
  - with varying numbers of units in the linear Dense layer (projection layer).
- Linear evaluation
  - with a comparison against supervised training.
But before that, let's train a model and look at the training metrics as well as the result of SimilarityLogger.
The loss curve shown below is from a model with 64 units in the embedding layer, with similarities scaled by a temperature of 0.2.
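Here is a minimal training sketch consistent with the setup above; the learning rate, number of batches, epoch count, and project name are assumptions:

import wandb
from tensorflow import keras

wandb.init(project="metric-learning")  # hypothetical project name

model = get_model()
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    # The scaled similarity matrix is used as logits.
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)
model.fit(
    AnchorPositivePairs(num_batchs=1000),
    epochs=20,
    callbacks=[SimilarityLogger()],
)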
Effect of the Number of Units in the Projection Layer
Let's look at the effect of the number of units in the projection layer on the training loss. For this comparative study, we will use these unit counts (a sweep sketch follows the list):
[8, 16, 32, 64, 128, 256, 512]
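A sketch of this ablation could look like the following; it assumes get_model is modified to accept the projection width as an argument, and the epoch count is arbitrary:

import wandb
from tensorflow import keras

for units in [8, 16, 32, 64, 128, 256, 512]:
    # One W&B run per projection-layer width.
    run = wandb.init(project="metric-learning",
                     name=f"projection-{units}", reinit=True)
    model = get_model(units=units)  # assumes a `units` argument
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    model.fit(AnchorPositivePairs(num_batchs=1000), epochs=20)
    run.finish()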
Observations
- The training loss was the lowest for 16 units in the projection layer.
- The training loss for 8 units is higher than that for 32 and 64 units, even though 8 is closer to 16 in terms of the number of units. This may be because the number of units in the projection layer is lower than the number of classes in the training data (10 in our case).
- The training loss increases in the order 16 < 32 < 64 < 8 < 128 < 256 < 512.
- A large number of units for a mere 10 classes leads to a poorer loss.
Linear Evaluation
In linear evaluation, we freeze the feature backbone (our simple convolutional block) trained with a given framework and learn a linear classifier on top of it. We can implement this in the following way:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense

def get_linear_model(model, trainable=False):
    # Reuse the convolutional backbone up to the global average pooling layer.
    backbone_model = Model(
        inputs=model.inputs,
        outputs=model.get_layer("global_average_pooling2d").output,
    )
    backbone_model.trainable = trainable

    inputs = Input((32, 32, 3))
    x = backbone_model(inputs, training=False)
    linear_layer = Dense(10, activation="softmax")(x)
    return Model(inputs, linear_layer)
Thus,
- We train our model(embedding) using metric learning.
- We discard the projection layer, freeze the convolutional layers, and use a Dense layer after the global average pooling layer.
- We will use the entire labeled training dataset to train our linear classifier.
For comparison, we will also train the linear classifier with unfrozen convolutional layers. This is our conventional supervised image classifier; a sketch of both runs follows.
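Both runs can be sketched as follows; here `model` is the embedding model trained above, and the optimizer and epoch count are assumptions:

# Linear evaluation: frozen backbone pre-trained with metric learning.
linear_model = get_linear_model(model, trainable=False)
linear_model.compile(optimizer="adam",
                     loss="sparse_categorical_crossentropy",
                     metrics=["accuracy"])
linear_model.fit(x_train, y_train,
                 validation_data=(x_test, y_test), epochs=10)

# Baseline: the same architecture with unfrozen convolutional layers,
# starting from randomly initialized weights. Note that Keras
# auto-increments layer names, so the pooling layer in a second model
# may be named "global_average_pooling2d_1".
scratch_model = get_linear_model(get_model(), trainable=True)
scratch_model.compile(optimizer="adam",
                      loss="sparse_categorical_crossentropy",
                      metrics=["accuracy"])
scratch_model.fit(x_train, y_train,
                  validation_data=(x_test, y_test), epochs=10)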
The results are shown below.
Observations
- The model trained from scratch overfit quickly compared to the model pre-trained using metric learning.
- The linear model generalizes better when pre-trained using metric learning, as seen from the validation metrics.
Conclusion
Thank you for sticking with me to the end. Metric learning is widely used to generate rich embedding spaces that can facilitate many downstream tasks. This report explored a supervised method for metric learning to obtain embeddings that performed better on a downstream image classification task.
Recently, visual representation learning has been gaining quite a lot of attention from the research community, and contrastive loss-based learning techniques have gained considerable traction.
I would like to thank Mat Kelcey for his amazing tutorial on metric learning for image similarity search. I used it as a starting point and provided my take on metric learning, followed by an ablation study and linear evaluation. I hope you liked it. I would love to hear your thoughts in the comments. 😄