Bad Global Minima Exist and SGD Can Reach Them

Reproducibility Challenge 2021 Weights & Biases submission for the paper "Bad Global Minima Exist and SGD Can Reach Them" by Liu et al.

Bruno Gomes Coelho - bruno.coelho@nyu.edu - New York University

Gustavo Sutter P. Carvalho - gsutterp@uwaterloo.ca - University of Waterloo

João Araújo - joaogui1@cohere.ai - Cohere

João Marcos Cardoso da Silva - joaomarcoscsilva@usp.br - Universidade de São Paulo

João Pedro Rodrigues Mattos - joao_pedro_mattos@usp.br - Universidade de São Paulo

(Ordering is alphabetical, and does not indicate the amount of contribution)

Original Paper | Original code | Our code


Reproducibility Summary

Scope of Reproducibility

The authors suggest, through empirical evidence, that it is possible to obtain a model initialization that leads to a perfect fit of the training data while generalizing poorly. The authors call such a point a bad minimum throughout the paper, and it is constructed by training on the data with randomly assigned labels. In this report, we reproduce the core claims of the paper and use Weights and Biases' tracking tools to gain insight into the robustness of the original experiments.

Methodology

We implement the main claims of the paper from scratch using Jax, with the authors' original PyTorch implementation as a guide. Since the claims rely on weight initialization and randomness, and not all code was available in the original codebase, our implementation is intended to check not only the reproducibility of the results but also their underlying assumptions and robustness.

Results

We verify the main claim of the paper: there are bad minima that can severely impact model generalization, and these points can be reached by SGD when training starts from an adversarial initialization.

What was easy

Understanding the concepts involved and coding the models and experiments, while time-consuming, was fairly straightforward and was helped by the general availability of the authors' code.
Filtering, analyzing, and plotting the results was made easier by logging everything to Weights & Biases and using their built-in report functionality.

What was difficult

One of our main difficulties was finding a subset of the experiments that accurately represents the main claims of the paper while keeping runtime costs manageable.
The original paper presents a wide range of experiments, related to both major and minor claims, across 4 datasets and 4 models. Even after restricting ourselves to a subset of experiments on a single dataset and model, the runtime remained very costly, since it requires 35 different trained models (7 models across 5 seeds). Luckily, because our codebase was written in Jax rather than PyTorch, we were able to switch from GPU to TPU support without any hassle and finish all computations in time.
Minor difficulties included the weight-initialization rabbit hole, since every framework initializes weights with different defaults, and general project management across 5 collaborators in different parts of the world.

Communication with the original authors

We did not contact the authors during the Reproducibility Challenge 2021 period, although we did find their source code very helpful for resolving specific implementation details.

Report

1. Introduction

This report is an attempt to validate some of the experiments from the paper "Bad Global Minima Exist and SGD Can Reach Them" (Liu et al. 2020), from NeurIPS 2020. From the various pieces of evidence the authors collect, we reproduce one of the main experiments, which investigates model initializations that are able to perfectly fit the training data while generalizing poorly, a phenomenon the authors call a "bad minimum". Throughout this report, we compare our results with those presented in the paper and conclude that, apart from a single plot, all results are consistent, even though our implementation is based on a different library, Jax.
In Section 2 we discuss the scope of our work and how it connects to the original paper, and in Section 3 we describe the methodology used. Section 4 presents our results, while Section 5 discusses various implementation details and other considerations pertaining to this project. Finally, we conclude this report in Section 6.

2. Scope of Reproducibility

The main procedure behind all the experiments in this report involves two phases: first, fitting a model to a randomly labeled version of the dataset, which yields what the authors call an "adversarial initialization"; second, using these weights as the initialization for two new training sessions with the correct labels, one using only a plain SGD optimizer and the other combining the same optimizer with data augmentation and l_2 regularization.
To further extend the scope of reproducibility, all experiments were reimplemented in Jax instead of running the authors' original PyTorch code, confirming the robustness of the results across distinct deep learning library implementations.
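The sketch below is a minimal, self-contained illustration of this two-phase procedure in Jax with optax, using a toy MLP and synthetic data rather than the ResNet18/CIFAR10 setup we actually ran; the architecture, hyperparameters, and helper names are our own placeholders, not taken from the paper.

```python
import jax
import jax.numpy as jnp
import optax

def init_params(key, sizes):
    # He-style initialization for a small MLP.
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) * jnp.sqrt(2.0 / m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def forward(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss_fn(params, x, y):
    logits = forward(params, x)
    return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

def train(params, x, y, optimizer, steps):
    opt_state = optimizer.init(params)

    @jax.jit
    def step(params, opt_state):
        grads = jax.grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state

    for _ in range(steps):
        params, opt_state = step(params, opt_state)
    return params

key = jax.random.PRNGKey(0)
k_x, k_y, k_init, k_shuffle = jax.random.split(key, 4)
x = jax.random.normal(k_x, (256, 32))           # toy inputs
y = jax.random.randint(k_y, (256,), 0, 10)      # toy "true" labels

# Phase 1: fit randomly assigned labels to obtain the adversarial initialization.
random_labels = jax.random.permutation(k_shuffle, y)
params = init_params(k_init, [32, 64, 10])
adv_init = train(params, x, random_labels, optax.sgd(1e-2), steps=2000)

# Phase 2: start from the adversarial initialization and train on true labels,
# either with vanilla SGD (as here) or with SGD + momentum + l2 + data augmentation.
final_params = train(adv_init, x, y, optax.sgd(1e-2), steps=2000)
```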
Figure 1 from the original paper - shows the four different settings the authors investigate throughout their work. We reproduced all of these scenarios using a ResNet18 model and the CIFAR10 dataset.

2.1 Report limitations

In this report, due to hardware limitations, we reproduced only one of the various model-dataset combinations presented in the paper. We ran a ResNet18 with the CIFAR10 dataset in all four settings presented by the authors in Figure 1 of the original paper and obtained similar results, supporting the main claims of this work. Originally, the authors tested the same experiment on four different architectures (VGG 16, ResNet18, ResNet50, and DenseNet40) and four datasets (CIFAR10, CIFAR100, CINIC10, and Restricted ImageNet).
Still, due to runtime cost, we were unable to explore all eight combinations of the three heuristics displayed in Table 2 of the original paper, which assesses the contribution of each heuristic to the generalization improvement. In our experiments, we tested the "Vanilla SGD" and the "DA + l_2 + Momentum (SOTA)" combinations.
Table 2 from the original paper - the authors tested their claims on eight different combinations of the three heuristics (DA, l_2, and Momentum). Unfortunately, we could not reproduce all of these experiments due to hardware constraints; only the "Vanilla" and "DA + l_2 + Momentum (SOTA)" combinations were reproduced.
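For concreteness, the two settings we reproduce can be expressed with optax roughly as below; the learning-rate and weight-decay values are placeholders, not necessarily those used in the paper or in our runs.

```python
import optax

# "Vanilla SGD": plain SGD with no momentum, no weight decay, and no data augmentation.
vanilla_sgd = optax.sgd(learning_rate=0.01)

# "DA + l2 + Momentum (SOTA)": SGD with momentum plus l2 regularization implemented
# as weight decay added to the gradients; data augmentation is applied in the input
# pipeline separately. Hyperparameter values here are placeholders.
sota = optax.chain(
    optax.add_decayed_weights(weight_decay=5e-4),
    optax.sgd(learning_rate=0.01, momentum=0.9),
)
```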

3. Methodology

For this paper, we found the authors' original PyTorch code on GitHub and reimplemented their experiments using Jax. This choice is motivated by the fact that the work relies almost entirely on weight initialization and randomness. These basic procedures are implemented differently in each deep learning library, so our goal was to have as much control as possible over those factors, not only to reproduce the results but also to check their robustness.
The main claims of the paper rely on plots of different metrics tracked during the training sessions of each of the many experiments conducted. In this sense, this work is an excellent showcase of the advantages of using a tool such as Weights and Biases for experiment tracking, weight storage, and many other features. Our experiments were originally conducted on a Tesla P100 16GB provided by Google Colaboratory, which incurred a high runtime cost; towards the end of the project, a TPUv3 on GCP was used to allow us to finish all experiments.
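A minimal sketch of how per-epoch metrics can be logged to Weights & Biases is shown below; the project name, config keys, and dummy metric values are illustrative placeholders rather than the ones used in our actual runs.

```python
import wandb

# Illustrative per-epoch logging; in the real experiments the metric values
# come from the training and evaluation loops.
run = wandb.init(
    project="bad-global-minima-repro",        # placeholder project name
    config={"model": "ResNet18", "dataset": "CIFAR10",
            "setting": "adversarial-init + SOTA", "seed": 0},
)

num_epochs = 3
for epoch in range(num_epochs):
    metrics = {"train_acc": 0.0, "test_acc": 0.0, "distance_from_init": 0.0}
    run.log({"epoch": epoch, **metrics})

run.finish()
```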

4. Results

We reimplemented and executed their experiments on training accuracy, test accuracy, and the distance traveled during training (Sections 3.1, 3.2, and 3.3 of the original paper, respectively). More precisely, we verified their results using the ResNet18 architecture and the CIFAR10 dataset.
All of our experiments follow the five-trial scheme used in the paper, with each trial using a different random seed.
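In Jax, this per-trial seeding is made explicit through PRNG keys; a small sketch of how the five trials might be set up is shown below (the experiment body itself is elided).

```python
import jax

# One PRNG key per trial; each key deterministically drives weight initialization,
# label shuffling, and data augmentation for that trial.
for seed in range(5):
    key = jax.random.PRNGKey(seed)
    init_key, shuffle_key, augment_key = jax.random.split(key, 3)
    # ... run one full experiment (all four settings) with these keys ...
```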

4.1 Training and test accuracy

This experiment consists of tracking metrics at each epoch of training, a task well suited to the WandB library. The following image is taken from the paper, showing the results reported by the authors. The panel below it presents our results on the same dataset, with trials grouped by their type.

Our experimental results closely follow the original, supporting the authors' claim that bad minima exist and that they can be reached by training from an adversarial initialization.

5. Discussion

Below we discuss various implementation details and observations relevant to our experiments and findings.

5.1 Frobenius Norm

In the original paper, the Frobenius norm (alongside the L1 and L2 path norms) is used as a proxy for model complexity. The Frobenius norm is the extension of the Euclidean norm to an arbitrary matrix A \in \mathbb{R}^{m \times n}, defined as \|A\|_\text{F} = \sqrt{\sum_{i=1}^m \sum_{j=1}^n |a_{ij}|^2}.
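A short sketch of how such a norm can be computed over a pytree of model parameters in Jax is given below; aggregating all parameter arrays into a single sum of squares is one possible convention, and the paper's exact aggregation may differ.

```python
import jax
import jax.numpy as jnp

def frobenius_norm(params):
    # Square root of the sum of squared entries over all parameter arrays.
    leaves = jax.tree_util.tree_leaves(params)
    return jnp.sqrt(sum(jnp.sum(jnp.square(p)) for p in leaves))

# Example on a toy pytree of parameters.
params = {"w1": jnp.ones((3, 3)), "b1": jnp.zeros(3)}
print(frobenius_norm(params))  # sqrt(9) = 3.0
```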
During our experiments, while we observe a general trend similar to the one presented by the authors, our measurements are on a vastly different scale: we find values in the range of 10^5, while the authors report values on the order of 10^{50} for a ResNet18 trained on CIFAR10 (Figure 10). Even removing the square root from our definition does not yield values that large, and we are unsure how numerical libraries could even support them. Below we present both our values and the authors' original plot. We edit the original to include the scale on the left side, remove the other metrics, and follow their axis-range convention (a maximum of approximately 3x the observed value) for easier comparison.
For the L1 and L2 path norms, we did not observe results consistent with the original paper. This may be due to multiple factors, including implementation details, the non-linear scale used in the original plots, or Jax-specific implementation differences. We show our findings below.
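For reference, a common way to compute path norms for a plain, bias-free ReLU feedforward network is to propagate an all-ones input through element-wise transformed weights; a toy sketch under that assumption is shown below. Handling ResNet18's skip connections and batch normalization is considerably more involved, which is one possible source of our discrepancy.

```python
import jax.numpy as jnp

def path_norm(weights, power):
    # Path norm of a bias-free ReLU MLP given as a list of weight matrices:
    # raise each weight entry to `power`, propagate an all-ones input, and take
    # the 1/power root of the summed output (power=1 -> l1, power=2 -> l2).
    x = jnp.ones(weights[0].shape[0])
    for w in weights:
        x = x @ (jnp.abs(w) ** power)
    return jnp.sum(x) ** (1.0 / power)

weights = [jnp.ones((4, 8)), jnp.ones((8, 2))]  # toy two-layer network
print(path_norm(weights, power=1))  # l1 path norm
print(path_norm(weights, power=2))  # l2 path norm
```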

5.2 Global vs local minimum

The term "global" minimum is generally used to represent the lowest possible value of our loss over its landscape - although our adversarial methods do overfit to achieve 100% train accuracy, we believe these "minimums" are still unknown if they're global minimum in regards to our loss and therefore "local minimum" would be a better representation.

5.3 Data augmentation

Algorithm 1 in the paper describes a data augmentation step in which random pixels of each image are zeroed out. While following the authors' implementation, we noticed that a random RGB channel of each selected pixel is zeroed out, rather than the pixel itself. This tends to create images with red/blue/green patches instead of the black patches originally intended. The code in the authors' implementation responsible for zeroing out the pixels can be found here.
Although we noticed this small discrepancy, we followed the authors' implementation.
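To illustrate the difference, the sketch below is our own Jax illustration, not the authors' code, and the fraction of affected pixels is a placeholder. It contrasts zeroing a single random channel per selected pixel (colored patches) with zeroing the whole pixel (black patches).

```python
import jax
import jax.numpy as jnp

def zero_random_channels(key, image, fraction=0.1):
    # What the implementation effectively does: zero one random RGB channel at
    # each selected pixel, producing colored rather than black patches.
    h, w, _ = image.shape
    k_mask, k_chan = jax.random.split(key)
    selected = jax.random.uniform(k_mask, (h, w)) < fraction
    channel = jax.random.randint(k_chan, (h, w), 0, 3)
    channel_mask = jax.nn.one_hot(channel, 3, dtype=bool) & selected[..., None]
    return jnp.where(channel_mask, 0.0, image)

def zero_random_pixels(key, image, fraction=0.1):
    # What Algorithm 1 describes: zero all three channels of each selected
    # pixel, producing black patches.
    h, w, _ = image.shape
    selected = jax.random.uniform(key, (h, w)) < fraction
    return jnp.where(selected[..., None], 0.0, image)

key = jax.random.PRNGKey(0)
k_img, k_aug = jax.random.split(key)
image = jax.random.uniform(k_img, (32, 32, 3))   # toy CIFAR-sized image
colored_patches = zero_random_channels(k_aug, image)
black_patches = zero_random_pixels(k_aug, image)
```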

6. Conclusion

Our experiments verify the main claim of the paper, namely the effect of adversarial initialization on model generalization. Re-implementing the experiments from scratch in a different framework yielded results and conclusions analogous to those of the original paper, demonstrating the robustness of its findings.