
Concrete Autoencoder (CAE)

Basic idea & toy example
The concrete autoencoder (CAE), proposed in https://arxiv.org/pdf/1901.09346, is comprised of two main components: a concrete selector layer, which acts as the encoder, and a subsequent neural network that serves as the decoder.
The concrete selector layer extracts the $k$ most important features from a $d$-dimensional input $\boldsymbol{x}$, with $k < d$. This layer is modelled using $k$ independent $d$-dimensional Gumbel-Softmax distributed random variables, where the $j$-th element of each random variable is defined as

$$m_j = \frac{\exp((\log\alpha_j + g_j)/T)}{\sum_{i=1}^d \exp((\log \alpha_i + g_i)/T)},$$

where $\log \boldsymbol{\alpha} \in \mathbb{R}^d$ represents the distribution parameters (logits), $\boldsymbol{g}$ is a $d$-dimensional vector of i.i.d. samples from a Gumbel distribution, and $T \in \mathbb{R}_{+}$ is a temperature parameter that is annealed during training.
As $T \rightarrow 0$, the value of $m_j$ converges to $1$ for the category with the highest logit, and to $0$ for all other categories. The advantage of using the Gumbel-Softmax distribution is that it allows differentiation with respect to $\boldsymbol{\alpha}$, enabling gradient-based optimization.
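To make the effect of the temperature concrete, here is a minimal illustrative sketch (not from the paper; the logits are arbitrary) that draws one Gumbel-Softmax sample at a few temperatures. As $T$ shrinks, the sample concentrates on the feature with the largest logit.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
log_alpha = torch.tensor([0.5, 2.0, 1.0, 0.1])   # d = 4 illustrative logits

for T in [5.0, 1.0, 0.1]:
    g = -torch.log(-torch.log(torch.rand(4)))     # i.i.d. Gumbel(0, 1) samples
    m = F.softmax((log_alpha + g) / T, dim=0)     # one Gumbel-Softmax sample
    print(f"T = {T}: m = {m}")                    # approaches a one-hot vector as T -> 0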
By stacking all the random variables, we form the matrix $\boldsymbol{M} \in \mathbb{R}^{k \times d}$, allowing us to express the latent vector as:

$$\boldsymbol{z}_{\text{CAE}} = \boldsymbol{M} \boldsymbol{x}.$$

The latent vector $\boldsymbol{z}_{\text{CAE}}$ then serves as the input to an arbitrary decoder neural network, whose output is used to compute a reconstruction loss. During training, each latent unit selects a linear combination of input features which, as $T \rightarrow 0$, collapses onto a single feature, so that by the end of training a discrete set of $k$ features has been selected.
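As a toy illustration (a sketch, not part of the original code), once each row of $\boldsymbol{M}$ is essentially one-hot, multiplying by $\boldsymbol{M}$ simply gathers $k$ entries of $\boldsymbol{x}$:

import torch

d, k = 6, 2
x = torch.arange(1.0, d + 1)    # x = [1, 2, 3, 4, 5, 6]
M = torch.zeros(k, d)
M[0, 4] = 1.0                   # first latent unit selects feature 4
M[1, 1] = 1.0                   # second latent unit selects feature 1
z = M @ x
print(z)                        # tensor([5., 2.])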

Essential Code

We first import PyTorch and then create the class ConcreteLinear: in its constructor, we initialize the logits parameter together with a dropout probability.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConcreteLinear(nn.Module):
    def __init__(
        self,
        num_categories,        # e.g., 28*28 for MNIST input dimension
        num_distributions=1,   # e.g., 50 for number of concrete variables
        pi_dropout=0.0         # dropout probability applied to the logits
    ):
        super(ConcreteLinear, self).__init__()

        # Store parameters
        self.num_categories = num_categories
        self.num_distributions = num_distributions

        # Initialize logits and normalize each row to sum to 1
        logits_init = torch.rand(num_distributions, num_categories)
        logits_init /= logits_init.sum(dim=1, keepdim=True)

        # Define logits as a learnable parameter
        self.logits = nn.Parameter(logits_init, requires_grad=True)

        # Initialize dropout layer applied to the logits
        self.pi_dropout = nn.Dropout(pi_dropout)

    def get_pi(self):
        logits = self.pi_dropout(self.logits)
        pi = F.softmax(logits, dim=1)
        return pi, logits

    def sample_matrix(self, temperature, random, hard=False):
        # Retrieve distribution parameters
        pi, logits = self.get_pi()

        # Calculate GJS divergence (method defined below)
        gjs = self.GJS_divergence(logits)

        if not random:
            # Deterministic selection: one-hot on the argmax of each distribution
            observed_inds = torch.argmax(pi, dim=1)
            pi_deterministic = torch.zeros_like(pi)
            pi_deterministic[torch.arange(pi.shape[0]), observed_inds] = 1
            selector_matrix = pi_deterministic
        else:
            # Stochastic sampling using Gumbel-Softmax
            selector_matrix = F.gumbel_softmax(logits, tau=temperature, hard=hard)

        return selector_matrix, gjs

    def forward(self, x, random, temperature, hard=False):
        assert x.dim() == 2, f"Expected 2D tensor, but got {x.dim()}D tensor with shape {x.shape}"

        selector, gjs = self.sample_matrix(temperature=temperature, random=random, hard=hard)
        x = F.linear(x, selector)  # (batch, d) -> (batch, k)

        outputs = {
            "latent": x,
            "gjs": gjs,
            "idx_on": torch.argmax(selector, dim=-1).data.detach().cpu().numpy(),
        }
        return outputs


During training, the random argument is True: the method first computes the logits and then applies gumbel_softmax to obtain the selector_matrix.
During testing, instead, random is False. We compute the probabilities pi from the logits with a softmax and take the argmax of each distribution; pi_deterministic is 1 at those indices and 0 everywhere else.
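For example, here is a hypothetical usage sketch of the layer on a random batch (it assumes the full ConcreteLinear class above, including the GJS_divergence method shown in the next section):

import torch

layer = ConcreteLinear(num_categories=28 * 28, num_distributions=50)
x = torch.randn(8, 28 * 28)                             # a batch of 8 flattened images

train_out = layer(x, random=True, temperature=10.0)     # stochastic Gumbel-Softmax sampling
test_out = layer(x, random=False, temperature=10.0)     # deterministic argmax selection

print(train_out["latent"].shape)   # torch.Size([8, 50])
print(test_out["idx_on"][:5])      # indices of the first few selected features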

Another technique that can improve learning, proposed in https://arxiv.org/pdf/2403.00563, consists in computing the Generalized Jensen–Shannon Divergence (GJSD).
The GJSD for $K$ categorical distributions $\{ \boldsymbol{p}_i \}_{i=1}^K$ and weights $\boldsymbol{w}$ is given by:

$$D_{GJS}(\{\boldsymbol{p}_i\}_{i=1}^K) = \sum_{i=1}^K w_i D_{KL}\!\left(\boldsymbol{p}_i \,\Big\|\, \sum_{j=1}^K w_j \boldsymbol{p}_j\right).$$

As the Gumbel-Softmax distributions can be approximated as categorical, we compute $D_{GJS}$ for the probabilities obtained after applying the softmax to the logits. Intuitively, maximizing the GJSD encourages the Gumbel-Softmax distributions to be diverse, so that they converge to distinct features.
For the implementation, we set $w_i = \frac{1}{K}$ for $i=1,\dots,K$. With uniform weights, the GJSD measures how much the distributions $\boldsymbol{p}_1, \dots, \boldsymbol{p}_K$ collectively deviate from their average $\frac{1}{K}\sum_{j=1}^K \boldsymbol{p}_j$.
    # Method of ConcreteLinear: GJSD of the K categorical distributions
    # defined by the logits, with uniform weights w_i = 1/K.
    def GJS_divergence(self, logits):
        K = self.num_distributions
        w = torch.ones([K, 1]).type_as(logits) / K   # uniform weights

        pi = F.softmax(logits, dim=1)                # (K, d) probabilities
        log_pi = F.log_softmax(logits, dim=1)

        # sum_i w_i * KL(p_i || sum_j w_j p_j)
        d_gjs = torch.sum(
            w
            * torch.sum(
                pi * (log_pi - torch.log(torch.sum(w * pi, dim=0).repeat([K, 1]))),
                dim=1,
                keepdim=True,
            ),
            dim=0,
        )
        return d_gjs
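As a quick sanity check (illustrative, assuming GJS_divergence has been added as a method of ConcreteLinear as above), the divergence is close to 0 when all K distributions are identical and approaches log K when they are one-hot on distinct categories:

import torch

layer = ConcreteLinear(num_categories=4, num_distributions=3)

identical = torch.zeros(3, 4)        # three identical (uniform) distributions
distinct = torch.eye(3, 4) * 20.0    # three near-one-hot, distinct distributions

print(layer.GJS_divergence(identical))   # ~0
print(layer.GJS_divergence(distinct))    # close to log(3) ≈ 1.10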


CAE

This class implements a simple CAE with fully connected layers.
class CAE(nn.Module):
    def __init__(self, input_dim=77, decoder_hiddens=[], dropout=0.0, k=50):
        super(CAE, self).__init__()

        # Concrete selector layer acting as the encoder
        self.encoder = ConcreteLinear(num_categories=input_dim, num_distributions=k)

        # Fully connected decoder: k -> decoder_hiddens -> input_dim
        decoder_hiddens = [k] + decoder_hiddens + [input_dim]
        nets_dec = []
        for i in range(len(decoder_hiddens) - 1):
            nets_dec += [nn.Linear(decoder_hiddens[i], decoder_hiddens[i + 1])]
            if i < len(decoder_hiddens) - 2:
                nets_dec += [nn.LeakyReLU(0.2)]
                nets_dec += [nn.Dropout(dropout)]
        self.decoder = nn.Sequential(*nets_dec)

    def forward(self, x, temperature, random):
        outputs = self.encoder(x, temperature=temperature, random=random)

        x = self.decoder(outputs["latent"])

        returns = {"X_rec": x, "GJS": outputs["gjs"], "mask": outputs["idx_on"]}
        return returns
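A quick shape check on random data (an illustrative sketch; the hidden size is an arbitrary choice):

import torch

model = CAE(input_dim=28 * 28, decoder_hiddens=[128], dropout=0.1, k=50)
x = torch.randn(16, 28 * 28)

out = model(x, temperature=10.0, random=True)
print(out["X_rec"].shape)   # torch.Size([16, 784]) - reconstruction
print(out["GJS"])           # GJS divergence term used as a regularizer
print(out["mask"].shape)    # (50,) - selected feature index per latent unit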

Temperature Callback

This class anneals the temperature of the Gumbel-Softmax during training.
class TemperatureCallback:
    def __init__(self, num_epochs, temp_base, temp_min, stop_anneal=0, warmup_epochs=0, mode="exp"):
        self.num_epochs = num_epochs
        self.temp_base = temp_base        # initial temperature
        self.temp_min = temp_min          # final (minimum) temperature
        self.stop_anneal = stop_anneal    # epoch after which the temperature is frozen (0 = never)
        self.warmup_epochs = warmup_epochs
        self.mode = mode                  # "exp" or "linear" annealing

    def get_current_value(self, epoch):
        if epoch < self.warmup_epochs:
            temp = self.temp_base
        else:
            adj_epoch = epoch - (self.warmup_epochs + 1)
            adj_num_epochs = self.num_epochs - (self.warmup_epochs + 1)

            if self.mode == "exp":
                if self.stop_anneal > 0 and adj_epoch > self.stop_anneal:
                    return self.temp_base * (self.temp_min / self.temp_base) ** (
                        self.stop_anneal / adj_num_epochs
                    )
                temp = self.temp_base * (self.temp_min / self.temp_base) ** (
                    adj_epoch / adj_num_epochs
                )
            elif self.mode == "linear":
                k = (self.temp_min - self.temp_base) / (
                    self.num_epochs - self.warmup_epochs
                )
                if self.stop_anneal > 0 and adj_epoch > self.stop_anneal:
                    return k * self.stop_anneal + self.temp_base
                temp = k * adj_epoch + self.temp_base

        temp = max(temp, self.temp_min)
        return temp

    def on_epoch_start(self, epoch):
        return self.get_current_value(epoch)
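For instance (an illustrative check using the exponential mode and the hyperparameter ranges used later in the post), the temperature decays smoothly from temp_base to temp_min over training:

callback = TemperatureCallback(num_epochs=200, temp_base=10.0, temp_min=0.1, mode="exp")

for epoch in [1, 50, 100, 150, 200]:
    print(epoch, round(callback.on_epoch_start(epoch), 4))
    # prints roughly 10.0, 3.22, 1.01, 0.32, 0.10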



Training loop

for epoch in range(1, config.num_epochs + 1):
    model.train()  # Ensure model is in training mode

    temp = temperature_callback.on_epoch_start(epoch)  # Anneal the temperature once per epoch

    # Training phase
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(-1, 28 * 28).to(device)  # Flatten images and move to device

        optimizer.zero_grad()
        returns = model(data, random=True, temperature=temp)

        gjs = returns["GJS"]
        X_rec = returns["X_rec"]
        # Reconstruction loss minus the GJS diversity term (which we want to maximize)
        loss = F.mse_loss(X_rec, data, reduction="mean") - config.gjs_factor * gjs
        loss.backward()
        optimizer.step()
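After training, one can switch to deterministic selection to read off the chosen pixels. The following is a hypothetical sketch that assumes a test_loader defined analogously to train_loader and reuses model, device, and the final temp from the loop above:

model.eval()
with torch.no_grad():
    data, _ = next(iter(test_loader))          # assumed test_loader, analogous to train_loader
    data = data.view(-1, 28 * 28).to(device)
    returns = model(data, random=False, temperature=temp)

selected_pixels = returns["mask"]              # k pixel indices chosen by the selector
print(sorted(set(selected_pixels.tolist())))   # unique selected pixels (ideally k of them)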


Experiment with MNIST


We trained the CAE on the MNIST dataset for 200 epochs, with a GJSD factor equal to 0.005 and $k$ equal to 50. Here, we show the selected features and the reconstructed image i) before training, ii) after 50 epochs, and iii) after 150 epochs. Notably, after the first 50 epochs the CAE has already learnt the most important features, which are only slightly refined by epoch 150.

