Conditional Variational Autoencoders
Histopathological imaging
Introduction
Running some experiments with a Conditional Variational Autoencoder (CVAE) on CIFAR10 in order to better understand the method.
- Implementation of a CVAE in Keras for RGB image inputs
- Implementation of a function for activation-layer visualization (troubleshooting + explainability), sketched under the Activations section below. For example, this was useful in identifying the so-called "dying ReLU" problem, which was evident from the plots.
- Implementation of a function for filter visualization, obtained by gradient ascent starting from a noise image (sketched under the Filters section below). It should be useful for assessing the patterns learned by the layers.
NB: the sliders allow you to look at different layers.
Filters
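As a reference for the bullet point above, here is a minimal sketch of how such a filter visualization can be obtained by gradient ascent starting from random noise. The function and its defaults (the layer name in the example, the 13-channel input shape, the step size) are illustrative placeholders, not the exact code behind the panels.

import numpy as np
import tensorflow as tf
from tensorflow import keras

def visualize_filter(model, layer_name, filter_index,
                     input_shape=(32, 32, 13), steps=100, step_size=10.0):
    # Sub-model mapping the model input to the activations of the target layer.
    layer_output = model.get_layer(layer_name).output
    feature_extractor = keras.Model(inputs=model.input, outputs=layer_output)

    # Start from a low-contrast random noise image (batch of one).
    image = tf.Variable(
        np.random.uniform(0.4, 0.6, size=(1,) + input_shape).astype("float32"))

    for _ in range(steps):
        with tf.GradientTape() as tape:
            activation = feature_extractor(image)
            # Maximize the mean activation of the chosen filter,
            # ignoring the borders to reduce edge artifacts.
            loss = tf.reduce_mean(activation[:, 2:-2, 2:-2, filter_index])
        grads = tape.gradient(loss, image)
        image.assign_add(step_size * tf.math.l2_normalize(grads))

    # Rescale to [0, 1] for plotting.
    result = image.numpy()[0]
    return (result - result.min()) / (result.max() - result.min() + 1e-8)

For example, visualize_filter(encoder, "block3_conv1", filter_index=0) returns a 32x32x13 array whose first three channels can be displayed as an RGB image; the remaining channels correspond to the label maps in the encoder input.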
Activations
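And a sketch of the activation visualization from the bullet above: it builds a sub-model that returns the feature maps of the requested layers for a single input image, which is how an all-zero channel (a dying-ReLU symptom) shows up in the plots. The layer names and plotting layout are illustrative, and the image passed in must already include the 10 label channels expected by the encoder input.

import matplotlib.pyplot as plt
from tensorflow import keras

def plot_activations(model, image, layer_names, max_channels=8):
    # Sub-model returning the outputs of the requested layers.
    outputs = [model.get_layer(name).output for name in layer_names]
    activation_model = keras.Model(inputs=model.input, outputs=outputs)

    activations = activation_model(image[None, ...])  # add a batch dimension
    if not isinstance(activations, (list, tuple)):
        activations = [activations]

    for name, activation in zip(layer_names, activations):
        n = min(max_channels, activation.shape[-1])
        fig, axes = plt.subplots(1, n, figsize=(2 * n, 2), squeeze=False)
        for i in range(n):
            # A feature map that stays at zero for many inputs is a
            # symptom of the dying-ReLU problem.
            axes[0, i].imshow(activation[0, :, :, i], cmap="viridis")
            axes[0, i].axis("off")
        fig.suptitle(name)
    plt.show()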
Architecture and example run
- Architecture (a sketch of the conditioning and sampling wiring follows this list):
Model: "encoder"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
Input (InputLayer) [(None, 32, 32, 13) 0 []
]
block1_conv1 (Conv2D) (None, 32, 32, 16) 1888 ['Input[0][0]']
block1_conv2 (Conv2D) (None, 32, 32, 16) 2320 ['block1_conv1[0][0]']
batch_normalization (BatchNorm (None, 32, 32, 16) 64 ['block1_conv2[0][0]']
alization)
leaky_re_lu (LeakyReLU) (None, 32, 32, 16) 0 ['batch_normalization[0][0]']
block2_conv1 (Conv2D) (None, 32, 32, 32) 4640 ['leaky_re_lu[0][0]']
block2_conv2 (Conv2D) (None, 32, 32, 32) 9248 ['block2_conv1[0][0]']
batch_normalization_1 (BatchNo (None, 32, 32, 32) 128 ['block2_conv2[0][0]']
rmalization)
leaky_re_lu_1 (LeakyReLU) (None, 32, 32, 32) 0 ['batch_normalization_1[0][0]']
S4 (MaxPooling2D) (None, 16, 16, 32) 0 ['leaky_re_lu_1[0][0]']
block3_conv1 (Conv2D) (None, 16, 16, 64) 18496 ['S4[0][0]']
block3_conv2 (Conv2D) (None, 16, 16, 64) 36928 ['block3_conv1[0][0]']
batch_normalization_2 (BatchNo (None, 16, 16, 64) 256 ['block3_conv2[0][0]']
rmalization)
leaky_re_lu_2 (LeakyReLU) (None, 16, 16, 64) 0 ['batch_normalization_2[0][0]']
flatten (Flatten) (None, 16384) 0 ['leaky_re_lu_2[0][0]']
dense (Dense) (None, 1024) 16778240 ['flatten[0][0]']
mu (Dense) (None, 512) 524800 ['dense[0][0]']
log_var (Dense) (None, 512) 524800 ['dense[0][0]']
==================================================================================================
Total params: 17,901,808
Trainable params: 17,901,584
Non-trainable params: 224
Model: "decoder"
_________________________________________________________________
Layer (type)                                 Output Shape          Param #
=================================================================
decoder_input (InputLayer)                   [(None, 522)]         0
dense_3 (Dense)                              (None, 16384)         8568832
reshape (Reshape)                            (None, 16, 16, 64)    0
batch_normalization_3 (BatchNormalization)   (None, 16, 16, 64)    256
leaky_re_lu_3 (LeakyReLU)                    (None, 16, 16, 64)    0
up_block4_conv1 (Conv2DTranspose)            (None, 16, 16, 64)    36928
up_block4_conv2 (Conv2DTranspose)            (None, 16, 16, 64)    36928
batch_normalization_4 (BatchNormalization)   (None, 16, 16, 64)    256
leaky_re_lu_4 (LeakyReLU)                    (None, 16, 16, 64)    0
up_block5_conv1 (Conv2DTranspose)            (None, 16, 16, 32)    18464
up_block5_conv2 (Conv2DTranspose)            (None, 16, 16, 32)    9248
batch_normalization_5 (BatchNormalization)   (None, 16, 16, 32)    128
leaky_re_lu_5 (LeakyReLU)                    (None, 16, 16, 32)    0
up_sampling2d (UpSampling2D)                 (None, 32, 32, 32)    0
up_block6_conv1 (Conv2DTranspose)            (None, 32, 32, 16)    4624
up_block6_conv2 (Conv2DTranspose)            (None, 32, 32, 16)    2320
batch_normalization_6 (BatchNormalization)   (None, 32, 32, 16)    64
leaky_re_lu_6 (LeakyReLU)                    (None, 32, 32, 16)    0
conv2d_transpose (Conv2DTranspose)           (None, 32, 32, 3)     195
=================================================================
Total params: 8,678,243
Trainable params: 8,677,891
Non-trainable params: 352
- Training: some signs of overfitting in this case.
- Visualize the (projected) embedding and conditional embedding for each class: this run resulted in a poor embedding. Note that the kl_coefficient is equal to 0.001, so this is expected, as we are approaching a plain autoencoder.
- Visualize reconstructions and generations:
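The shapes in the summaries above suggest how the conditioning is wired: the encoder input has 13 channels (3 RGB channels plus the 10-class one-hot label broadcast to constant channels), and the decoder input has 522 dimensions (the 512-dimensional latent vector concatenated with the one-hot label). Below is a minimal sketch of that wiring together with the reparameterization sampling; the helper names are hypothetical and encoder/decoder stand for the models summarized above.

import tensorflow as tf

NUM_CLASSES = 10
LATENT_DIM = 512
IMG_SIZE = 32

def encode_inputs(images, labels):
    # (B, 32, 32, 3) images + integer labels -> (B, 32, 32, 13):
    # the one-hot label is broadcast to 10 constant channels.
    one_hot = tf.one_hot(labels, NUM_CLASSES)
    label_maps = tf.tile(one_hot[:, None, None, :], [1, IMG_SIZE, IMG_SIZE, 1])
    return tf.concat([tf.cast(images, tf.float32), label_maps], axis=-1)

def sample_latent(mu, log_var):
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I).
    eps = tf.random.normal(tf.shape(mu))
    return mu + tf.exp(0.5 * log_var) * eps

def reconstruct(encoder, decoder, images, labels):
    mu, log_var = encoder(encode_inputs(images, labels))
    z = sample_latent(mu, log_var)                                         # (B, 512)
    decoder_in = tf.concat([z, tf.one_hot(labels, NUM_CLASSES)], axis=-1)  # (B, 522)
    return decoder(decoder_in), mu, log_var

def generate(decoder, labels):
    # Conditional generation: sample z from the prior, decode with the label.
    one_hot = tf.one_hot(labels, NUM_CLASSES)
    z = tf.random.normal([tf.shape(one_hot)[0], LATENT_DIM])
    return decoder(tf.concat([z, one_hot], axis=-1))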
Grid search
Clearly, in the example above the reconstruction error is low enough, but the embedding is very poor, and therefore so are the generations.
I performed a grid search over the encoding dimension (128, 256, 512) and the beta coefficient for the KL loss (0.001, 0.01, 0.1, 0.5, 0.8).
NB: I'm not sure it makes much sense to compare the loss values of runs with different kl_coefficients, as the coefficient directly rescales part of the loss (i.e. runs with a kl_coefficient close to zero will naturally have a lower loss value).
Nonetheless, we can assess the impact of the parameter controlling the encoding dimension, which might be smaller than I expected. I report the metrics and some visualizations.
Unfortunately, none of the runs resulted in successful generations. Maybe a more careful grid search could help find an appropriate balance between reconstruction and regularization, but I think it might be a more fundamental issue.
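For reference, this is the kind of beta-weighted objective assumed in the NB above: a reconstruction term plus kl_coefficient times the KL divergence between the approximate posterior and the standard normal prior, so a small coefficient directly shrinks one of the two terms and makes total-loss values incomparable across runs. The mean-squared-error reconstruction term is an assumption here, not necessarily the one used in these runs.

import tensorflow as tf

def cvae_loss(x, x_hat, mu, log_var, kl_coefficient):
    # Reconstruction: per-image sum of squared errors, averaged over the batch.
    reconstruction = tf.reduce_mean(
        tf.reduce_sum(tf.square(x - x_hat), axis=[1, 2, 3]))
    # Closed-form KL divergence between N(mu, exp(log_var)) and N(0, I),
    # summed over the latent dimensions and averaged over the batch.
    kl = tf.reduce_mean(
        -0.5 * tf.reduce_sum(
            1.0 + log_var - tf.square(mu) - tf.exp(log_var), axis=1))
    return reconstruction + kl_coefficient * kl

The grid search above then simply trains one model per (encoding dimension, kl_coefficient) combination and compares reconstructions, embeddings, and generations rather than raw loss values.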
TO DO:
- Explore the parameter space of "kl_coefficient" more thoroughly, with a fixed encoding dimension
- Add more layers
- Add residual blocks, as in the ResNet architecture
- Consider Hierarchical Variational Autoencoders
- Consider Variational Autoencoders + Probabilistic Denoising Diffusion Models (https://arxiv.org/abs/2106.05931)