
MNIST classification with MPS

We revisit the MNIST classification problem with matrix product states (MPS), this time also tuning the augmentation procedure. We achieve a 0.39% error rate with an ensemble of 10 MPS classifiers with bond dimension of at least 80. The best single-model error rate is 0.5%.

Model and initial conditions

We used a standard multi-linear matrix product state (MPS). The model can be decomposed into two stages. The first is the embedding of a feature vector $x$ in an exponentially large Hilbert space. We chose a linear local embedding $x \rightarrow \psi(x)=\begin{bmatrix} 1-x \\ x \end{bmatrix}$ that assigns to each point $x_i \in [0,1]$ an $L_1$-normalized vector $\psi(x_i) \in \mathbb{R}^2$.
Example: let's consider the embedding of the feature vector $x=(0.6,0.2)$:
$$x=(0.6,0.2)\rightarrow \Psi(x) =\psi(0.6)\otimes\psi(0.2) = \begin{bmatrix} 0.4 \\ 0.6 \end{bmatrix}\otimes\begin{bmatrix} 0.8 \\ 0.2 \end{bmatrix} =\begin{bmatrix} 0.32 \\ 0.08 \\ 0.48 \\ 0.12 \end{bmatrix}$$
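For concreteness, here is a minimal PyTorch sketch of this embedding (the function name local_embedding is ours, not from the original code):

```python
import torch

def local_embedding(x):
    # Map each pixel value x_i in [0, 1] to the L1-normalized vector [1 - x_i, x_i].
    return torch.stack([1.0 - x, x], dim=-1)

x = torch.tensor([0.6, 0.2])
psi = local_embedding(x)          # [[0.4, 0.6], [0.8, 0.2]]
# The full embedding Psi(x) is the tensor (Kronecker) product of the local vectors.
Psi = torch.kron(psi[0], psi[1])  # [0.32, 0.08, 0.48, 0.12]
```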

After that, a linear classification is performed by taking the maximum of the scalar products of the embedded vector with the matrix product state vectors representing the different classes. We choose the following matrix product state representation for a class $c$:
$$\Psi_c=\frac{1}{D}\sum_{s_1,s_2,\ldots, s_n\in\{0,1\}}\mathrm{tr}\left(A_1^{s_1} A_2^{s_2}\cdots A_{n/2}^{s_{n/2}}\, B^c\, A_{n/2+1}^{s_{n/2+1}} \cdots A_n^{s_n}\right)\,\psi(s_1)\otimes \psi(s_2)\otimes\cdots\otimes \psi(s_n),$$

where $s_j\in \{0,1\}$ and $A^{s_j}_j, B^c$ are $D\times D$ matrices. The dimension $D$ is called the bond dimension. The logits for a given input $x\in \mathbb{R}^{n}$ are then given by
$$l_c = \Psi_c\cdot \Psi(x).$$

The initial conditions for the matrices are chosen such that for any input $x$ the logits $l_c$ are all equal to one up to corrections of order $\epsilon$, where $\epsilon\ll1$. In particular, we chose $A^{s_j}_j=I_{D}+R^{s_j}_j$ and $B^c=I_{D}+R^c$, where $I_D$ denotes the $D\times D$ identity matrix and $R$ an appropriate matrix with normally distributed elements with mean zero and standard deviation $\epsilon$.
Note that adding a bias term to the MPS does not change its representational power, since this is equivalent to having $s_j$-independent parts of the matrices $A_j^{s_j}$.
Example: for our two-dimensional example above and three classes, the MPS matrices could be given by (choosing $\epsilon = 10^{-6}$)
$$A_1^{0}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} 0.2& -0.4 \\ 0.8 & -0.1 \end{bmatrix}, \quad A_1^{1}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} -0.3& 0.1 \\ -0.6 & -0.7 \end{bmatrix}$$
$$A_2^{0}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} 0.1& -0.3 \\ -0.2 & 0.9 \end{bmatrix}, \quad A_2^{1}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} -0.4& 0.3 \\ 0.4 & -0.2 \end{bmatrix}$$
$$B^{0}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} -0.5& 0.2 \\ 0.4 & 0.1 \end{bmatrix}, \quad B^{1}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} 0.7& 0.2 \\ -0.3 & -0.1 \end{bmatrix}, \quad B^{2}=\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + 10^{-6}\begin{bmatrix} -0.1& 0.8 \\ -0.7 & 0.6 \end{bmatrix}$$

The predicted logits for the given $x$ are then
$$l_j=\frac{1}{2}\mathrm{tr}\left(\big((1-x_1)A^0_1 + x_1 A^1_1\big)\,B^j\,\big((1-x_2)A^0_2 + x_2 A^1_2\big) \right)=\frac{1}{2}\mathrm{tr}\left( \left(\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + \mathcal{O}(\epsilon)\right)\left(\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + \mathcal{O}(\epsilon)\right)\left(\begin{bmatrix} 1& 0 \\ 0 & 1 \end{bmatrix} + \mathcal{O}(\epsilon)\right) \right) = 1+\mathcal{O}(\epsilon)$$
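A small PyTorch sketch of this computation (the array layout and function name are ours; the near-identity perturbations are drawn freshly rather than using the particular numbers above):

```python
import torch

D, n, n_classes, eps = 2, 2, 3, 1e-6

# Near-identity initialization: A_j^{s_j} = I_D + eps * R,  B^c = I_D + eps * R^c.
A = torch.eye(D) + eps * torch.randn(n, 2, D, D)       # indexed as A[site, s_j]
B = torch.eye(D) + eps * torch.randn(n_classes, D, D)  # indexed as B[class]

def logits(x):
    # Contract the local embedding [1 - x_j, x_j] with the physical index s_j of site j.
    M = [(1 - x[j]) * A[j, 0] + x[j] * A[j, 1] for j in range(n)]
    # For n = 2 the class matrix B^c sits between the two site matrices; normalize by 1/D.
    return torch.stack([torch.trace(M[0] @ B[c] @ M[1]) for c in range(n_classes)]) / D

print(logits(torch.tensor([0.6, 0.2])))  # all three logits equal 1 up to O(eps)
```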

We used the standard cross-entropy loss with L2 regularization of the MPS matrices. In our final (best) setting we used a batch size of 500, the AdamW optimizer, and learning-rate reduction on plateau with $\gamma=0.5$ and patience 20. The optimizer was determined as part of the hyper-parameter tuning, together with the augmentation.
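A minimal sketch of this training setup, assuming model is the MPS classifier, train_loader yields augmented mini-batches, num_epochs is the epoch budget, and validation_loss evaluates the model on the validation set (all of these names are placeholders, not from the original code); the lr and weight-decay values are rounded from the tuned hyper-parameters listed below, and treating the L2 term as AdamW weight decay is our assumption:

```python
import torch

criterion = torch.nn.CrossEntropyLoss()
# L2 regularization of the MPS matrices, here applied via the weight_decay term (one possible choice).
optimizer = torch.optim.AdamW(model.parameters(), lr=2.6e-4, weight_decay=3.3e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=20)

for epoch in range(num_epochs):
    for images, labels in train_loader:      # mini-batches from the augmented training set
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
    # Halve the learning rate when the validation loss stops improving for 20 epochs.
    scheduler.step(validation_loss(model))
```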

Augmentation

We used torchvision.transforms to augment the dataset images. In order to obtain a good augmentation for the given model we performed hyper-parameter tuning over the various transformation parameters and probabilities. In particular, we used the following sequence of transformations, written as Transformation_name(parameter_1, parameter_2, ..., parameter_n); along with the parameters in the brackets we also tuned the probability of applying each transformation (a minimal code sketch of the pipeline is given after the list):
  1. Random pixel shift in the range [-aug_phi, aug_phi]
  2. Random color jitter: transforms.ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue)
  3. Random sharpness: transforms.functional.adjust_sharpness(sharpness_factor=sharp_fac)
  4. Random Gaussian blur: transforms.GaussianBlur(kernel_size=blur_kernel_size)
  5. Random horizontal flip: transforms.RandomHorizontalFlip()
  6. Random affine transformation: transforms.RandomAffine(rotate, translate=(txy, txy), scale=(scale_min, scale_max))
  7. Random perspective transformation: transforms.RandomPerspective(distortion_scale=perspective_scale)
  8. Resize: this transformation is always applied: transforms.Resize(resize)
  9. Random crop: transforms.RandomCrop(crop)
  10. Random elastic transformation: (here we used the elasticdeform library, imported as import elasticdeform.torch as etorch) etorch.deform_grid(displacement)
  11. Random erasing: transforms.RandomErasing(scale=(erasing_scale_min, erasing_scale_max), ratio=(0.3, 3.3))
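A minimal torchvision sketch of such a pipeline (a representative subset of the steps above, with rounded values taken from the tuned hyper-parameters listed further below; the pixel-shift and elastic-deformation steps are omitted, and the Resize target and the odd Gaussian-blur kernel size are our assumptions):

```python
from torchvision import transforms

augment = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomApply([transforms.ColorJitter(brightness=0.49, contrast=0.11)], p=0.69),
    transforms.RandomAdjustSharpness(sharpness_factor=1.56, p=0.21),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.35),
    transforms.RandomApply(
        [transforms.RandomAffine(degrees=0.93, translate=(0.10, 0.10), scale=(0.85, 1.01))],
        p=0.006),
    transforms.RandomPerspective(distortion_scale=0.49, p=0.34),
    transforms.Resize(28),      # assumed target size; 'resize' was a tuned parameter
    transforms.RandomCrop(25),  # crop = 25 from the tuned hyper-parameters
    transforms.RandomErasing(p=0.34, scale=(0.018, 0.091), ratio=(0.3, 3.3)),
])
```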
The order of the transformations is fixed and is the same as given above. During hyper-parameter tuning the training was restricted to a maximum of 30 epochs. The accuracy statistics of the best 3000 tuning runs are shown in the following plot.


The best hyper-parameters obtained by extensive 5-fold cross-validation hyper-parameter tuning with kerastuner are the following
best_hyperparameters = {'bs': 400,
'lr': 0.0002552960624312262,
'l2': 0.0033015876828498825,
'optimizer': 'adamw',
'step': 20,
'gamma': 0.5,
'crop': 25,
'aug_phi': 1.3114520152002449e-05,
'aug_color_jitter_prob': 0.6914513663570058,
'aug_brightness': 0.4878184695451711,
'aug_contrast': 0.11360159363742299,
'aug_saturation': 0.22362211935883042,
'aug_hue': 0.11969939431322466,
'aug_sharpness_prob': 0.21108948047691428,
'aug_sharp_min': 0.27663873915325154,
'aug_sharp_max': 1.5612024117031147,
'aug_gblur_prob': 0.3502871754293765,
'aug_gblur_kernel': 2,
'aug_horizontal_filp': 0,
'aug_affine_prob': 0.006432062003668701,
'aug_translate': 0.10325178139420912,
'aug_rotate': 0.9259700779843727,
'aug_scale_min': 0.8462448419228799,
'aug_scale_max': 1.0147418030794249,
'aug_perspective_prob': 0.3448470727044216,
'aug_perspective_scale': 0.493661541529104,
'aug_elastic_prob': 0.6259378322210638,
'aug_elastic_strength': 0.2718582635633056,
'aug_erasing_prob': 0.33955848871966826,
'aug_erasing_scale_min': 0.018150876613635274,
'aug_erasing_scale_max': 0.0911760700140132}
We remark that significantly simpler augmentations can lead to similar error rates.

Results



Training

Our model is implemented in both PyTorch and TensorFlow. Due to the time-consuming graph construction in our TensorFlow implementation, we ultimately used only the PyTorch version for our experiments.
Interestingly, we observe no over-fitting even for the largest models with bond dimension 90. The validation and test accuracies are about 1% higher than the training accuracy, since the image transformations are not applied during the evaluation phase. Typical training, validation and test accuracies during training are shown in the following plot.



Error rate vs bond dimension

As expected, the model performance increases with the bond dimension (see the plots below). The full MPS model is already converged at bond dimension $D=30$ ($D=40$ for randomly permuted inputs). Using an ensemble of 10 MPS classifiers, which is exactly equivalent to a single MPS with a 10-times larger bond dimension but block-diagonal matrices, we achieve a 2%-5% reduction of the error rate. Our best ensemble model has an error rate of 0.39%. The difference in model performance between the row-major ordering and a random ordering of the inputs is 2%.
Test error for the full MPS.
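A sketch of the ensemble prediction, assuming models holds the 10 independently trained MPS classifiers (the names and the choice of simply averaging the logits are our assumptions):

```python
import torch

def ensemble_logits(models, images):
    # Average the logits of the individual MPS classifiers.
    with torch.no_grad():
        return torch.stack([m(images) for m in models]).mean(dim=0)

# predicted_classes = ensemble_logits(models, test_images).argmax(dim=1)
```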
The error rate of the translationally invariant MPS model is much larger than that of the full MPS, but it decreases much faster with the bond dimension. Nevertheless, our largest models with bond dimension 90 are not yet converged. Additionally, the performance of these models drops significantly in the permuted case, but not for row-major-ordered inputs. In the former case we also observe a significant decrease of the ensemble test error when averaging models with different input permutations. In contrast, we did not observe a significant difference for the full-MPS ensembles.
Test error for the translationally invariant MPS.