SupCon Loss results

Trained a model on the CIFAR10 dataset. The backbone is MNASNet, and the loss function is SupCon loss, which is used for supervised contrastive learning.
Created on January 27 | Last edited on January 27
python cifar_supcon.py --max_epochs 6 --auto_scale_batch_size True --embed_sz 128 --gamma 0.1 --amp_backend native --precision 16 --steps 3 4 --gpus 1 --data_dir dataset --img_sz 224 --resize 250
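
The loss head only needs a batch of embeddings plus the matching integer class labels. A minimal sketch of how SupConLoss from pytorch-metric-learning is called (random tensors, illustrative shapes only):

import torch
from pytorch_metric_learning.losses import SupConLoss

loss_fn = SupConLoss(temperature=0.1)  # same temperature as in the training script
embeddings = torch.randn(32, 128)      # batch of 32 embeddings, embed_sz = 128
labels = torch.randint(0, 10, (32,))   # CIFAR10 class ids
loss = loss_fn(embeddings, labels)     # scalar supervised-contrastive loss

The full training script (cifar_supcon.py) follows.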

# SupCon is a supervised contrastive loss, i.e. it needs class labels during training,
# unlike self-supervised methods such as SimSiam and SimCLR.
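# Sketch of the objective (notation only; see Khosla et al., "Supervised Contrastive
# Learning", 2020): for an anchor i with embedding z_i and temperature t,
#   L_i = -1/|P(i)| * sum_{p in P(i)} log( exp(z_i . z_p / t) / sum_{a != i} exp(z_i . z_a / t) )
# where P(i) is the set of other batch samples that share i's label.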

#%%
from argparse import ArgumentParser
from torch.optim.lr_scheduler import MultiStepLR
import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
# from torch.utils.data import DataLoader, random_split
from typing import List
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
import timm
from cifar_dm import CIFAR_DataModule
from pytorch_metric_learning.losses import SupConLoss
from pytorch_lightning.callbacks import LearningRateMonitor
import wandb

# wandb.login()

class LitClassifier(pl.LightningModule):
    def __init__(self, embed_sz: int, steps: List[int], lr: float = 1e-3, gamma: float = 0.1, **kwargs):
        super().__init__()
        self.embed_sz = embed_sz
        self.lr = lr
        self.gamma = gamma
        self.steps = [int(k) for k in steps]
        # define the backbone network
        self.backbone = timm.create_model('mnasnet_100', pretrained=True)
        # put backbone in train mode
        self.backbone.train()
        # project the pooled backbone features down to the embedding size
        in_features = self.backbone.classifier.in_features
        self.project = torch.nn.Linear(in_features, self.embed_sz)

        # drop the classification head so the backbone returns pooled features
        self.backbone.classifier = torch.nn.Identity()

        self.supcon_head = SupConLoss(temperature=0.1)
        # self.activation = torch.nn.LeakyReLU(negative_slope=0.1)


        self.save_hyperparameters()

    def forward(self, x):
        x = self.backbone(x)
        x = self.project(x)

        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        # print(x.shape)
        # print(y.shape)
        embeddings = self(x)
        loss = self.supcon_head(embeddings, y)
        # for logging to the loggers
        self.log("train_loss", loss) 

        return {"loss":loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        embeds = self(x)
        loss = self.supcon_head(embeds, y)
        # for logging to the loggers
        self.log('val_loss', loss)
        return {"loss":loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        embeds = self(x)
        loss = self.supcon_head(embeds, y)
        self.log('test_loss', loss)
        return {"test_loss":loss}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # sched = MultiStepLR(opt, milestones=self.hparams.steps, gamma=self.hparams.gamma)
        
        

        # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     opt, self.hparams.max_epochs, eta_min=0
        # )
        # cosine annealing with warm restarts; T_0=50 means the first restart occurs
        # after 50 scheduler steps (stepped once per epoch by Lightning's default)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=50, T_mult=1, eta_min=0)
        return [opt], [lr_scheduler]

        # return {"optimizer": opt, 
        #         "lr_scheduler": 
        #                     {
        #                             # REQUIRED: The scheduler instance
        #                             "scheduler": sched,
        #                             # The unit of the scheduler's step size, could also be 'step'.
        #                             # 'epoch' updates the scheduler on epoch end whereas 'step'
        #                             # updates it after a optimizer update.
        #                             "interval": "epoch",
        #                             # How many epochs/steps should pass between calls to
        #                             # `scheduler.step()`. 1 corresponds to updating the learning
        #                             # rate after every epoch/step.
        #                             "frequency": 1,
                                    
        #                     }
        # }

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        
        parser.add_argument('--embed_sz', type=int, default=128)
        parser.add_argument('--lr', type=float, default=0.0001)
        parser.add_argument('--gamma', type=float, default=0.1)
        parser.add_argument('--steps', nargs='+', required=True)
        return parser


def cli_main():
    pl.seed_everything(1000)
    wandb.login()
    wandb.init()
    wandb.run.log_code(".")  # upload all Python files in the working directory

    
    checkpoint_callback = ModelCheckpoint(filename="checkpoints/cifar10-{epoch:02d}-{val_loss:.6f}", monitor='val_loss', mode='min')
    lr_callback = LearningRateMonitor(logging_interval="step")
    wandb_logger = WandbLogger(project='CIFAR-10',  # group runs in the "CIFAR-10" project
                               log_model='all',     # log all new checkpoints during training
                               name="supcon loss")
    # ------------          
    # args
    # ------------
    parser = ArgumentParser()
    
    
    #  trainer CLI args added
    parser = pl.Trainer.add_argparse_args(parser)
    # model specific args
    parser = LitClassifier.add_model_specific_args(parser)
    # dataset specific args
    parser = CIFAR_DataModule.add_model_specific_args(parser)
    args = parser.parse_args()
    
    # ------------
    # data
    # ------------
    # dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    # mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    # mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    # train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    # val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    # test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
    dm = CIFAR_DataModule(**vars(args)) # vars converts Namespace --> dict, ** converts to kwargs
    # ------------
    # model
    # ------------
    model = LitClassifier(**vars(args))
    wandb_logger.watch(model)
    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.callbacks.append(checkpoint_callback)
    trainer.callbacks.append(lr_callback)
    trainer.logger = wandb_logger
    trainer.tune(model, dm)

    # log args to wandb
    # batch size picked by trainer.tune() (auto_scale_batch_size)
    args.batch_size = model.hparams.get('batch_size')
    # dm.hparams.batch_size = args.batch_size
    # override the datamodule's batch size with a fixed value
    dm.hparams.batch_size = 512
    print(f"\n\n batch size -----> {args.batch_size}\n\n")
    wandb.config.update(vars(args))

    trainer.fit(model, dm)

    # ------------
    # testing
    # ------------
    # trainer.test()


if __name__ == '__main__':
    cli_main()

 
# ========================================================================
#                             cifar_dm.py
# ========================================================================
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from argparse import ArgumentParser
# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10
from torchvision import transforms
from typing import Optional

from PIL import Image
from pathlib import Path
# ========================================================================
#                             timm imports                                  
# ========================================================================
# import urllib
# from PIL import Image
# from timm.data import resolve_data_config
# from timm.data.transforms_factory import create_transform
# import timm

 
# model = timm.create_model('mnasnet_100', pretrained=True)
# model.eval()


# config = resolve_data_config({}, model=model)
# transform = create_transform(**config)
# print(transform)
 
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# urllib.request.urlretrieve(url, filename)
# img = Image.open(filename).convert('RGB')

# tensor = transform(img).unsqueeze(0) # transform and add batch dimension
# print(tensor.shape)


class CIFAR_DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, img_sz: int, resize: int, batch_size: int, *args, **kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.img_sz = img_sz
        self.resize = resize
        self.batch_size = batch_size
        self.transform = transforms.Compose([
                                    transforms.Resize(size=self.resize, interpolation=Image.BICUBIC),
                                    transforms.CenterCrop(size=(self.img_sz, self.img_sz)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250])
                                    # transforms.Normalize(mean=[0.4850,], std=[0.2290,])
        ])
        # create an empty dir. if not exists
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        # self.dims = (1, 28, 28)
        
        # try updating the lightning to see if it works
        self.save_hyperparameters()
        print("==================================")
        print(self.hparams)
        print("==================================")

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            print(f"\n\n dataset size --> {len(cifar_full)} \n\n")
            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])

            # Optionally...
            self.dims = tuple(self.cifar_train[0][0].shape)  # (X, y) ==> X.shape

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform)

            # Optionally...
            # self.dims = tuple(self.cifar_test[0][0].shape)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.hparams.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.hparams.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--data_dir', type=str, default='./dataset')
        parser.add_argument('--batch_size', type=int, default=8)
        parser.add_argument('--img_sz', type=int, default=224, help='size of image')
        parser.add_argument('--resize', type=int, default=250, help='resize the shorter image side to this size before center-cropping to --img_sz')
        return parser
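
# Minimal standalone usage sketch (illustrative values; cifar_supcon.py builds the
# datamodule from CLI args instead):
if __name__ == "__main__":
    dm = CIFAR_DataModule(data_dir="dataset", img_sz=224, resize=250, batch_size=8)
    dm.prepare_data()                           # downloads CIFAR10 into data_dir
    dm.setup("fit")                             # 45000/5000 train/val split
    xb, yb = next(iter(dm.train_dataloader()))
    print(xb.shape, yb.shape)                   # expect [8, 3, 224, 224] and [8]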


Run set (1 run shown): charts of EPOCH vs. Step and SupCon Loss vs. Step.
