Concrete Autoencoder (CAE)
Basic idea & toy example
The concrete autoencoder (CAE), proposed in https://arxiv.org/pdf/1901.09346, is composed of two main components: a concrete selector layer, which acts as the encoder, and a subsequent neural network that serves as the decoder.
The concrete selector layer extracts the $k$ most important features from a $d$-dimensional input $x$, with $k < d$. This layer is modelled using $k$ independent $d$-dimensional Gumbel-Softmax distributed variables, where the $j$-th element of the $i$-th random variable $m^{(i)}$ is defined as

$$m^{(i)}_j = \frac{\exp\left((\log \alpha^{(i)}_j + g_j)/T\right)}{\sum_{l=1}^{d} \exp\left((\log \alpha^{(i)}_l + g_l)/T\right)},$$

where $\alpha^{(i)}$ represents the distribution parameters (logits), $g$ is a $d$-dimensional vector of i.i.d. samples from a Gumbel(0, 1) distribution, and $T$ is a temperature parameter that is annealed during training.
As $T \to 0$, the value of $m^{(i)}_j$ converges to $1$ for the category with the highest logit, and to $0$ for all other categories. The advantage of using the Gumbel-Softmax distribution is that it allows differentiation with respect to $\alpha$, enabling gradient-based optimization.
By stacking all the random variables, we form the matrix $M \in \mathbb{R}^{k \times d}$, whose $i$-th row is $m^{(i)}$, allowing us to express the latent vector as:

$$z = M x.$$

The latent vector $z$ then serves as the input to an arbitrary decoder neural network, whose output is used to compute a reconstruction loss. During training, a linear combination of input features is selected, which, as $T \to 0$, converges to a discrete set of $k$ features by the end of training.
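Before looking at the full implementation, here is a minimal sketch of the encoder math above. The sizes ($d = 5$ input features, $k = 2$ selectors) and the temperature are arbitrary toy choices for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, k, T = 5, 2, 0.5  # toy sizes: input dim d, number of selectors k, temperature T

alpha = torch.rand(k, d)                          # positive distribution parameters (one row per selector)
g = -torch.log(-torch.log(torch.rand(k, d)))      # i.i.d. Gumbel(0, 1) samples

M = F.softmax((torch.log(alpha) + g) / T, dim=1)  # selector matrix: each row is a Gumbel-Softmax sample
x = torch.randn(d)                                # a toy input
z = M @ x                                         # latent vector z = M x

print(M.sum(dim=1))  # each row sums to 1 (a soft selection of one feature)
print(z.shape)       # torch.Size([2]) -> k (soft) selected features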
Essential Code
Let's create the ConcreteLinear class: here, we initialize the logits parameter together with the dropout probability.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConcreteLinear(nn.Module):
    def __init__(
        self,
        num_categories,       # e.g., 28*28 for MNIST input dimension
        num_distributions=1,  # e.g., 50 for number of concrete variables
        pi_dropout=0.0,       # dropout probability
    ):
        super(ConcreteLinear, self).__init__()
        # Store parameters
        self.num_categories = num_categories
        self.num_distributions = num_distributions

        # Initialize logits and normalize to sum to 1
        logits_init = torch.rand(num_distributions, num_categories)
        logits_init /= logits_init.sum(dim=1, keepdim=True)  # Normalize to sum to 1

        # Define logits as a learnable parameter with initial normalization
        self.logits = nn.Parameter(logits_init, requires_grad=True)

        # Initialize dropout layer
        self.pi_dropout = nn.Dropout(pi_dropout)

    def get_pi(self):
        logits = self.pi_dropout(self.logits)
        pi = F.softmax(logits, dim=1)
        return pi, logits

    def sample_matrix(self, temperature, random, hard=False):
        # Retrieve distribution parameters
        pi, logits = self.get_pi()

        # Calculate GJS divergence
        gjs = self.GJS_divergence(logits)

        if not random:
            # Deterministic sampling: pick the most likely category per distribution
            observed_inds = torch.argmax(pi, dim=1)
            pi_deterministic = torch.zeros_like(pi)
            pi_deterministic[torch.arange(pi.shape[0]), observed_inds] = 1
            selector_matrix = pi_deterministic
        else:
            # Stochastic sampling using Gumbel-Softmax
            selector_matrix = F.gumbel_softmax(logits, tau=temperature, hard=hard)

        return selector_matrix, gjs

    def forward(self, x, random, temperature, hard=False):
        assert x.dim() == 2, f"Expected 2D tensor, but got {x.dim()}D tensor with shape {x.shape}"
        selector, gjs = self.sample_matrix(temperature=temperature, random=random, hard=hard)
        x = F.linear(x, selector)
        outputs = {
            "latent": x,
            "gjs": gjs,
            "idx_on": torch.argmax(selector, dim=-1).data.detach().cpu().numpy(),
        }
        return outputs
During training, the random parameter is True; therefore, this method first computes the logits and then uses the gumbel_softmax function to obtain the selector_matrix.
During testing, instead, random is False. We compute the probabilities pi from the logits with a softmax and take the argmax for each distribution: pi_deterministic is equal to 1 at the observed indices and 0 everywhere else.
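As a quick sanity check, here is a small usage sketch of ConcreteLinear; the batch size, input dimension (28*28), and number of selectors are arbitrary choices for illustration:

layer = ConcreteLinear(num_categories=28 * 28, num_distributions=50)
x = torch.randn(8, 28 * 28)  # a batch of 8 flattened images

# Training-time call: stochastic Gumbel-Softmax sampling
out_train = layer(x, random=True, temperature=1.0)
print(out_train["latent"].shape)  # torch.Size([8, 50])

# Test-time call: deterministic argmax selection (temperature is ignored in this branch)
out_test = layer(x, random=False, temperature=1.0)
print(out_test["idx_on"][:5])  # input features selected by the first 5 selectors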
Another method that can be used to improve learning has been proposed in https://arxiv.org/pdf/2403.00563 and consists of computing the Generalized Jensen–Shannon Divergence (GJSD).
The GJSD for categorical distributions $p^{(1)}, \dots, p^{(K)}$ and weights $w_1, \dots, w_K$ (with $\sum_{i=1}^{K} w_i = 1$) is given by:

$$D_{\mathrm{GJS}}\left(p^{(1)}, \dots, p^{(K)}\right) = \sum_{i=1}^{K} w_i \, D_{\mathrm{KL}}\!\left(p^{(i)} \,\middle\|\, \sum_{j=1}^{K} w_j \, p^{(j)}\right).$$

As the Gumbel-Softmax distributions can be approximated as categorical, we want to compute the GJSD for the probabilities $\pi^{(i)}$ obtained after applying the softmax to the logits. Intuitively, maximizing the GJSD facilitates learning diverse Gumbel-Softmax distributions that converge to distinct features.
For the implementation, we set $w_i = 1/K$ for $i = 1, \dots, K$. Essentially, by setting uniform weights, the GJSD measures how the $K$ distributions collectively deviate from their average $\frac{1}{K}\sum_{i=1}^{K} \pi^{(i)}$.
def GJS_divergence(self, logits):
    K = self.num_distributions
    # Uniform weights w_i = 1 / K
    w = torch.ones([K, 1]).type_as(logits) / K
    pi = F.softmax(logits, dim=1)
    log_pi = F.log_softmax(logits, dim=1)
    # sum_i w_i * KL(pi^(i) || sum_j w_j * pi^(j))
    d_gjs = torch.sum(
        w
        * torch.sum(
            pi * (log_pi - torch.log(torch.sum(w * pi, dim=0).repeat([K, 1]))),
            dim=1,
            keepdim=True,
        ),
        dim=0,
    )
    return d_gjs
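To build intuition, the following sketch (with hand-picked toy logits) shows that the divergence is close to zero when the two selector distributions coincide and strictly positive when they concentrate on different features; this is why the training loss below subtracts a scaled GJSD term.

layer = ConcreteLinear(num_categories=4, num_distributions=2)

identical_logits = torch.tensor([[5.0, 0.0, 0.0, 0.0],
                                 [5.0, 0.0, 0.0, 0.0]])
distinct_logits = torch.tensor([[5.0, 0.0, 0.0, 0.0],
                                [0.0, 5.0, 0.0, 0.0]])

print(layer.GJS_divergence(identical_logits))  # ~0: both selectors focus on the same feature
print(layer.GJS_divergence(distinct_logits))   # > 0: the selectors focus on different features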
CAE
This class implements a simple CAE with fully connected layers.
class CAE(nn.Module):
    def __init__(self, input_dim=77, decoder_hiddens=[], dropout=0.0, k=50):
        super(CAE, self).__init__()
        # Concrete selector layer acting as the encoder
        self.encoder = ConcreteLinear(num_categories=input_dim, num_distributions=k)

        # Fully connected decoder: k -> hidden layers -> input_dim
        decoder_hiddens = [k] + decoder_hiddens + [input_dim]
        nets_dec = []
        for i in range(len(decoder_hiddens) - 1):
            nets_dec += [nn.Linear(decoder_hiddens[i], decoder_hiddens[i + 1])]
            if i < len(decoder_hiddens) - 2:
                nets_dec += [nn.LeakyReLU(0.2)]
                nets_dec += [nn.Dropout(dropout)]
        self.decoder = nn.Sequential(*nets_dec)

    def forward(self, x, temperature, random):
        outputs = self.encoder(x, temperature=temperature, random=random)
        x = self.decoder(outputs["latent"])
        returns = {"X_rec": x, "GJS": outputs["gjs"], "mask": outputs["idx_on"]}
        return returns
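A minimal usage sketch for flattened MNIST images; the hidden sizes, dropout, and batch size are arbitrary choices for illustration:

model = CAE(input_dim=28 * 28, decoder_hiddens=[128, 128], dropout=0.1, k=50)

x = torch.randn(8, 28 * 28)
out = model(x, temperature=1.0, random=True)
print(out["X_rec"].shape)  # torch.Size([8, 784]) -> reconstruction of the input
print(out["mask"].shape)   # (50,) -> the input feature chosen by each of the k selectors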
Temperature Callback
This class adjusts the temperature of the Gumbel-Softmax during training, annealing it from temp_base down to temp_min.
class TemperatureCallback:
    def __init__(self, num_epochs, temp_base, temp_min, stop_anneal=0, warmup_epochs=0, mode="exp"):
        self.num_epochs = num_epochs
        self.temp_base = temp_base
        self.temp_min = temp_min
        self.stop_anneal = stop_anneal
        self.warmup_epochs = warmup_epochs
        self.mode = mode

    def get_current_value(self, epoch):
        if epoch < self.warmup_epochs:
            temp = self.temp_base
        else:
            adj_epoch = epoch - (self.warmup_epochs + 1)
            adj_num_epochs = self.num_epochs - (self.warmup_epochs + 1)
            if self.mode == "exp":
                if self.stop_anneal > 0 and adj_epoch > self.stop_anneal:
                    return self.temp_base * (self.temp_min / self.temp_base) ** (self.stop_anneal / adj_num_epochs)
                temp = self.temp_base * (self.temp_min / self.temp_base) ** (adj_epoch / adj_num_epochs)
            elif self.mode == "linear":
                k = (self.temp_min - self.temp_base) / (self.num_epochs - self.warmup_epochs)
                if self.stop_anneal > 0 and adj_epoch > self.stop_anneal:
                    return k * self.stop_anneal + self.temp_base
                temp = k * adj_epoch + self.temp_base
        temp = max(temp, self.temp_min)
        return temp

    def on_epoch_start(self, epoch):
        temp = self.get_current_value(epoch)
        return temp
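For example (with arbitrary values), the exponential schedule anneals the temperature from temp_base towards temp_min over the course of training:

callback = TemperatureCallback(num_epochs=100, temp_base=10.0, temp_min=0.1, mode="exp")

for epoch in [1, 25, 50, 75, 100]:
    print(epoch, callback.on_epoch_start(epoch))
# The temperature decays geometrically from temp_base to temp_min, so the selector
# rows start as soft mixtures of features and gradually become (almost) one-hot.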
Training loop
for epoch in range(1, config.num_epochs + 1):
    model.train()  # Ensure model is in training mode

    # Training phase
    for batch_idx, (data, _) in enumerate(train_loader):
        temp = temperature_callback.on_epoch_start(epoch)  # Update temperature
        data = data.view(-1, 28 * 28).to(device)  # Move data to device

        optimizer.zero_grad()
        returns = model(data, random=True, temperature=temp)
        gjs = returns['GJS']
        X_rec = returns['X_rec']
        loss = F.mse_loss(X_rec, data, reduction="mean") - config.gjs_factor * gjs
        loss.backward()
        optimizer.step()
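After training, we can switch to the deterministic branch (random=False) to read off the selected features and check the reconstruction error. This sketch assumes a test_loader built like train_loader:

# Evaluation sketch: deterministic feature selection on a held-out batch
model.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))            # assumes a test_loader analogous to train_loader
    data = data.view(-1, 28 * 28).to(device)
    returns = model(data, random=False, temperature=temp)

selected = returns["mask"]                        # one input (pixel) index per concrete selector
print(sorted(set(selected.tolist())))             # distinct pixels picked by the CAE
print(F.mse_loss(returns["X_rec"], data).item())  # reconstruction error on this batch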
Experiment with MNIST
We trained the CAE on the MNIST dataset; the number of epochs, the GJSD factor, and the number of selected features k are set in the training configuration (config.num_epochs, config.gjs_factor, and the k argument of CAE). Here, we show the selected features and the reconstructed image (i) before training, (ii) after the first few epochs, and (iii) at the end of training. After the first few epochs, the CAE has already learnt the most important features, which are only slightly refined by the final epoch.