Metric Learning for Image Search

In this report, we will explore supervised metric learning and extend it to image similarity search. Made by Ayush Thakur using Weights & Biases
Ayush Thakur


Metric learning is a broad field with many definitions. Primarily, it aims to measure the similarity among data samples and to learn embedding models. In a familiar classification setting, we give our model some $X$ and train it to predict its class.

In the context of learning embedding models, metric learning aims to embed the inputs $X$ in an embedding space such that similar inputs are close together while dissimilar ones are far apart. We are often not interested in what the embedding space looks like, as long as the inputs we want to be close together (similar) form a cluster in that space.

Euclidean distance is a popular distance metric. One could argue that, given images, we can represent them as vectors (abstract features) using a pre-trained image classifier and then compare those features with Euclidean distance. However, most practical data has non-linear structure and requires a task- and dataset-specific distance metric. Metric learning therefore aims to construct task-specific distance metrics automatically.
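The paragraph above contrasts a fixed Euclidean distance with a learned, task-specific metric. A classical way to make that concrete (not the method used in this report) is a Mahalanobis-style distance $d_M(x, y) = \sqrt{(x - y)^\top M (x - y)}$ with $M = L^\top L$. The sketch below assumes NumPy and uses a random, purely hypothetical matrix `L` in place of a learned one. It shows that such a metric is just Euclidean distance after a linear projection.

```python
import numpy as np

def euclidean(x, y):
    """Plain Euclidean distance between two feature vectors."""
    return float(np.linalg.norm(x - y))

def mahalanobis(x, y, L):
    """Distance under the metric M = L^T L: sqrt((x - y)^T M (x - y))."""
    M = L.T @ L  # positive semi-definite by construction
    diff = x - y
    return float(np.sqrt(diff @ M @ diff))

rng = np.random.default_rng(0)
x, y = rng.normal(size=8), rng.normal(size=8)  # toy 8-dimensional "features"

# Hypothetical projection matrix; metric learning would fit L to labelled data.
L = rng.normal(size=(4, 8))

print(euclidean(x, y))           # fixed, task-agnostic metric
print(mahalanobis(x, y, L))      # metric parameterized by L
print(euclidean(L @ x, L @ y))   # same value: Euclidean distance after projecting with L
```

In this linear case, learning the metric $M$ and learning the embedding $x \mapsto Lx$ are the same problem. The neural network in this report follows the same idea, but with a non-linear embedding.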

Metric learning is important and useful because the learned distance metric or embeddings can benefit many downstream tasks. In the literature, metric learning can also be tied to model pre-training.

Metric learning falls under three categories:

In this report, we explore supervised metric learning and apply it to image search.

Try out metric learning in Google Colab $\rightarrow$

The Dataset

For simplicity, we will use the CIFAR-10 dataset. It has 10 classes, listed in CLASS_NAMES.

import numpy as np
from tensorflow.keras.datasets import cifar10

# Download CIFAR-10: 50,000 training and 10,000 test RGB images of size 32x32.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

CLASS_NAMES = ["airplane", "automobile", "bird", "cat", "deer",
               "dog", "frog", "horse", "ship", "truck"]

# Scale pixel values to [0, 1] and flatten labels from shape (N, 1) to (N,).
x_train = x_train.astype("float32") / 255.0
y_train = np.squeeze(y_train)
x_test = x_test.astype("float32") / 255.0
y_test = np.squeeze(y_test)


-> Figure 1: Samples from the CIFAR-10 dataset. <-

Dataset Preparation

In our supervised metric learning setting, we do not train on explicit $(X, y)$ pairs, where $y$ is the label of the corresponding $X$. Instead, we want multiple related instances (pairs) of $X$ that express semantic similarity. Thus, one training instance is a pair of images rather than a single image, and both images in the pair belong to the same class, as given by $y$.

When referring to the two images in a pair, we use the common terms anchor (an image randomly sampled from the dataset) and positive (another randomly sampled image of the same class). Thus, each training sample consists of an anchor and its positive.

The code snippet below builds lookup tables that map each class to the indices of its instances, one table for the training set and one for the test set.

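from collections import defaultdict

# Map each class index to the list of dataset indices belonging to that class.
class_idx_to_train_idxs = defaultdict(list)
for y_train_idx, y in enumerate(y_train):
    class_idx_to_train_idxs[y].append(y_train_idx)

class_idx_to_test_idxs = defaultdict(list)
for y_test_idx, y in enumerate(y_test):
    class_idx_to_test_idxs[y].append(y_test_idx)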

The dataloader produces batches of (anchor, positive) pairs spread across the classes (10 classes in our case).

Since the goal is to learn a metric that brings images of the same class closer together in the embedding space and pushes images of different classes apart, each batch contains one (anchor, positive) pair from each class. Our batch size is therefore 10 for CIFAR-10.

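import random

import numpy as np
from tensorflow import keras

num_classes = 10
height_width = 32

# Ref:
class AnchorPositivePairs(keras.utils.Sequence):
    """Yields batches of shape (2, num_classes, 32, 32, 3):
    one anchor and one positive image per class."""

    def __init__(self, num_batchs):
        self.num_batchs = num_batchs

    def __len__(self):
        # Number of batches per epoch.
        return self.num_batchs

    def __getitem__(self, _idx):
        # x[0] holds the anchors, x[1] the positives; the second axis is the class.
        x = np.empty((2, num_classes, height_width, height_width, 3), dtype=np.float32)
        for class_idx in range(num_classes):
            examples_for_class = class_idx_to_train_idxs[class_idx]
            anchor_idx = random.choice(examples_for_class)
            positive_idx = random.choice(examples_for_class)
            # Resample until the positive is a different image from the anchor.
            while positive_idx == anchor_idx:
                positive_idx = random.choice(examples_for_class)
            x[0, class_idx] = x_train[anchor_idx]
            x[1, class_idx] = x_train[positive_idx]
        return x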


-> Figure 2: One batch generated by AnchorPositivePairs. The first row shows the anchor images and the second row shows the randomly chosen positive images. <-

The Model

Now that our input pipeline is ready, let's build a model architecture suited to the task. Since the objective is to learn a metric that brings images of the same class together in the embedding space, we first pass the anchor image and its positive image through the convolutional block (one at a time) to get their feature representations. These features are then projected by a linear layer (without activation) and L2-normalized so that a simple dot product measures their similarity.
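Normalization matters because, for unit-length vectors $u$ and $v$, the dot product equals the cosine similarity. The squared Euclidean distance is then a decreasing function of it: $\lVert u - v \rVert^2 = 2 - 2\,u \cdot v$. Ranking neighbours by dot product is therefore equivalent to ranking them by Euclidean distance. Below is a quick NumPy check. Random vectors stand in for real embeddings, and the helper `l2_normalize` is a hypothetical NumPy analogue of `tf.nn.l2_normalize`.

```python
import numpy as np

def l2_normalize(v, axis=-1, eps=1e-12):
    """Scale vectors to unit L2 norm (guarding against division by zero)."""
    norm = np.linalg.norm(v, axis=axis, keepdims=True)
    return v / np.maximum(norm, eps)

rng = np.random.default_rng(42)
a = l2_normalize(rng.normal(size=64))  # stand-in for an anchor embedding
p = l2_normalize(rng.normal(size=64))  # stand-in for a positive embedding

dot = float(a @ p)
cosine = float(a @ p / (np.linalg.norm(a) * np.linalg.norm(p)))
sq_dist = float(np.sum((a - p) ** 2))

print(dot, cosine)            # equal, because both vectors have unit norm
print(sq_dist, 2 - 2 * dot)   # equal: ||a - p||^2 = 2 - 2 a.p
```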

For simplicity, we use a small convolutional feature extractor.

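import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Dense, GlobalAveragePooling2D, Input

def get_model():
    inputs = Input(shape=(height_width, height_width, 3))
    # Three strided convolutions: 32x32 -> 15x15 -> 7x7 -> 3x3 feature maps.
    x = Conv2D(filters=32, kernel_size=3, strides=2, activation="relu")(inputs)
    x = Conv2D(filters=64, kernel_size=3, strides=2, activation="relu")(x)
    x = Conv2D(filters=128, kernel_size=3, strides=2, activation="relu")(x)
    x = GlobalAveragePooling2D()(x)
    # Linear projection into the embedding space; the number of units is a hyperparameter.
    embeddings = Dense(units=64, activation=None)(x)
    # Unit-normalize so that a dot product equals cosine similarity.
    embeddings = tf.nn.l2_normalize(embeddings, axis=-1)

    # EmbeddingModel (defined below) must be defined before get_model() is called.
    return EmbeddingModel(inputs, embeddings)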

EmbeddingModel encapsulates the training logic for our metric learning task. It subclasses keras.Model and overrides train_step, the method Keras calls on each training batch inside fit().

# ref:
class EmbeddingModel(keras.Model):
    def train_step(self, data):
        # AnchorPositivePairs yields a single array; Keras may wrap it in a tuple.
        if isinstance(data, tuple):
            data = data[0]
        anchors, positives = data[0], data[1]

        with tf.GradientTape() as tape:
            # Run both anchors and positives through the model.
            anchor_embeddings = self(anchors, training=True)
            positive_embeddings = self(positives, training=True)

            # Calculate cosine similarity between every anchor and every positive
            # (the embeddings are unit-normalized, so a dot product suffices).
            similarities = tf.einsum(
                "ae,pe->ap", anchor_embeddings, positive_embeddings
            )

            # Scale using temperature. Temperature is a hyperparameter.
            temperature = 0.2
            similarities /= temperature

            # Compute loss. Anchor i's positive sits in column i,
            # so the target labels are simply 0, 1, ..., num_classes - 1.
            sparse_labels = tf.range(num_classes)
            loss = self.compiled_loss(sparse_labels, similarities)

        # Calculate gradients and apply via optimizer.
        gradients = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        # Update and return metrics (specifically the one for the loss value).
        self.compiled_metrics.update_state(sparse_labels, similarities)
        return {m.name: m.result() for m in self.metrics}
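To see what train_step optimizes, here is a NumPy sketch of the same computation for one batch. Row $i$ of the similarity matrix compares anchor $i$ with every positive in the batch. Because the batch holds one pair per class, the correct positive for anchor $i$ sits in column $i$, so the loss is softmax cross-entropy over each row with the diagonal as the target. This assumes the model is compiled with a sparse categorical cross-entropy on logits, e.g. `keras.losses.SparseCategoricalCrossentropy(from_logits=True)`; that compile step is an assumption here. The function name `anchor_positive_loss` and the random embeddings are hypothetical.

```python
import numpy as np

def l2_normalize(v):
    """Scale each row to unit L2 norm."""
    return v / np.linalg.norm(v, axis=-1, keepdims=True)

def anchor_positive_loss(anchor_emb, positive_emb, temperature=0.2):
    """Softmax cross-entropy over temperature-scaled anchor-positive similarities.

    anchor_emb, positive_emb: (num_classes, dim) unit-norm embeddings in which
    row i of both arrays comes from class i.
    """
    # Row a, column p: similarity of anchor a to positive p (as in the tf.einsum call).
    logits = np.einsum("ae,pe->ap", anchor_emb, positive_emb) / temperature
    # Numerically stable log-softmax over each row.
    logits = logits - logits.max(axis=1, keepdims=True)
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    # The target for anchor i is positive i, i.e. the diagonal.
    labels = np.arange(anchor_emb.shape[0])
    return float(-log_probs[labels, labels].mean())

rng = np.random.default_rng(0)
num_classes, dim = 10, 64
anchors = l2_normalize(rng.normal(size=(num_classes, dim)))

# Positives that are slight perturbations of their anchors: low loss.
close_pos = l2_normalize(anchors + 0.05 * rng.normal(size=(num_classes, dim)))
# The same positives shifted by one row, so each anchor is paired with another class: high loss.
shuffled_pos = np.roll(close_pos, 1, axis=0)

print(anchor_positive_loss(anchors, close_pos))
print(anchor_positive_loss(anchors, shuffled_pos))
```

Dividing by a temperature below 1 magnifies the differences between similarities before the softmax, which makes the distribution over positives sharper.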

The SimilarityLogger - Image query search

We use this custom Keras callback to interactively visualize images that are semantically similar. near_neighbours_per_example controls how many nearest neighbours are logged alongside each query image, and num_examples_to_log sets how many query images are logged.

One can see this callback as a naive implementation of image similarity search.

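import numpy as np
import tensorflow as tf
import wandb

class SimilarityLogger(tf.keras.callbacks.Callback):
    def __init__(self, num_samples, near_neighbours_per_example, num_examples_to_log):
        super(SimilarityLogger, self).__init__()
        # Pool of query/candidate images: the first num_samples test images.
        self.samples = x_test[:num_samples]
        self.near_neighbours_per_example = near_neighbours_per_example
        self.num_examples_to_log = num_examples_to_log

    def on_epoch_end(self, epoch, logs=None):
        # Embed the pool and compute all pairwise dot products (cosine similarities).
        embeddings = self.model.predict(self.samples, verbose=0)
        gram_matrix = np.einsum("ae,be->ab", embeddings, embeddings)
        # Keep the k + 1 highest-scoring indices per row, in ascending order;
        # the last one is normally the query itself.
        near_neighbours = np.argsort(gram_matrix.T)[:, -(self.near_neighbours_per_example + 1) :]

        # One row per query: the query image followed by its nearest neighbours.
        examples = np.empty(
            (self.num_examples_to_log, self.near_neighbours_per_example + 1,
             height_width, height_width, 3),
            dtype=np.float32,
        )
        for row_idx in range(self.num_examples_to_log):
            examples[row_idx, 0] = x_test[row_idx]
            # Drop the query itself and order neighbours from most to least similar.
            anchor_near_neighbours = reversed(near_neighbours[row_idx][:-1])
            for col_idx, nn_idx in enumerate(anchor_near_neighbours):
                examples[row_idx, col_idx + 1] = x_test[nn_idx]

        # Log each row (query + neighbours) as a list of images.
        for i, example in enumerate(examples):
            wandb.log({'query_similarity_{}'.format(i): [wandb.Image(img) for img in example]})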
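Stripped of the logging, the retrieval step in on_epoch_end works as follows. It computes all pairwise dot products between the unit-norm embeddings, sorts each row, and keeps the highest-scoring indices other than the query itself. Below is a NumPy sketch of that step. The function `nearest_neighbours` and the toy 2-D embeddings are hypothetical stand-ins for the callback's logic and for `model.predict` output.

```python
import numpy as np

def nearest_neighbours(embeddings, k):
    """For each row, return the indices of its k most similar other rows.

    Assumes L2-normalized embeddings, so the dot product is cosine similarity.
    Neighbours are ordered from most to least similar.
    """
    gram = embeddings @ embeddings.T
    # Exclude each item from its own neighbour list.
    np.fill_diagonal(gram, -np.inf)
    # argsort is ascending: take the last k columns, then reverse them.
    return np.argsort(gram, axis=1)[:, -k:][:, ::-1]

toy = np.array([
    [1.0, 0.0],
    [0.9, 0.1],
    [0.0, 1.0],
    [0.1, 0.9],
])
toy = toy / np.linalg.norm(toy, axis=1, keepdims=True)

print(nearest_neighbours(toy, k=1).ravel().tolist())  # [1, 0, 3, 2]
```

Masking the diagonal, instead of dropping the top hit as SimilarityLogger does, keeps the query out of its own list even when the sample contains exact duplicates that tie with it.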

Experiments and Results

We will perform two sets of experiments:

But before that, let's train a model and look at the training metrics as well as the output of SimilarityLogger.

The loss curve shown below is for a model with 64 units in the embedding layer, with similarities scaled by a temperature of 0.2.


Effect of the number of units in the projection layer

Let's look at how the number of units in the projection layer affects the training loss. For this comparative study, we use the following values:

[8, 16, 32, 64, 128, 256, 512]



Linear Evaluation

In linear evaluation, we freeze the feature backbone (our simple convolutional block) trained with a given framework and learn a linear classifier on top of it. We can implement this as follows:

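from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model

def get_linear_model(model, trainable=False):
    # Backbone, assumed here to be the convolutional block: the trained model's
    # layers up to global average pooling, without the projection head.
    gap_layer = next(l for l in model.layers if isinstance(l, GlobalAveragePooling2D))
    backbone_model = Model(inputs=model.inputs, outputs=gap_layer.output)
    # Frozen (trainable=False) for linear evaluation; unfrozen for the supervised baseline.
    backbone_model.trainable = trainable

    inputs = Input((height_width, height_width, 3))
    x = backbone_model(inputs, training=False)

    # A single softmax layer: the linear classifier.
    linear_layer = Dense(num_classes, activation='softmax')(x)
    return Model(inputs, linear_layer)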


For comparison, we also train the same linear classifier with the convolutional layers unfrozen (trainable=True). This is our conventional supervised image classifier.

The results are shown below.




Thank you for sticking around to the end. Metric learning is widely used to learn rich embedding spaces that facilitate many downstream tasks. This report explored a supervised metric learning method whose embeddings performed better on the downstream image classification task.

Recently, visual representation learning has attracted a lot of attention from the research community, and contrastive-loss-based learning techniques have gained a lot of traction. The reports linked below will help you keep up with the fast-paced developments in this area:

These are some more resources on metric learning:

I would like to thank Mat Kelcey for his excellent tutorial on metric learning for image similarity search. I used it as a starting point and added my own take on metric learning, followed by an ablation study and linear evaluation. I hope you liked it. I would love to hear your thoughts in the comment section. :smile: