SupCon Loss results
Trained a model on the CIFAR-10 dataset. The backbone is MnasNet (timm's mnasnet_100) with a linear projection head, and the loss function is SupCon loss, the supervised contrastive learning objective.
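For reference, the heart of the training step is just the SupCon loss call from pytorch-metric-learning: the model maps each image to an embedding, and SupConLoss uses the class labels to pull same-class embeddings together and push different-class embeddings apart (unlike SimSiam or SimCLR, which need no labels). A minimal sketch with dummy tensors, mirroring the 128-d embeddings used in this run:

import torch
from pytorch_metric_learning.losses import SupConLoss

embeddings = torch.randn(8, 128)         # dummy batch of 8 embeddings of size 128
labels = torch.randint(0, 10, (8,))      # dummy CIFAR-10 class labels
loss_fn = SupConLoss(temperature=0.1)    # same temperature as in the model below
loss = loss_fn(embeddings, labels)

The run itself was launched with the command below.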
python cifar_supcon.py --max_epochs 6 --auto_scale_batch_size True --embed_sz 128 --gamma 0.1 --amp_backend native --precision 16 --steps 3 4 --gpus 1 --data_dir dataset --img_sz 224 --resize 250
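The Trainer-level flags in that command roughly correspond to the following configuration, which pl.Trainer.from_argparse_args(args) assembles from the parsed CLI arguments (a sketch for PyTorch Lightning 1.x):

import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=6,
    gpus=1,
    precision=16,
    amp_backend="native",
    auto_scale_batch_size=True,   # trainer.tune() will then search for a workable batch size
)

The remaining flags (--embed_sz, --gamma, --steps, --data_dir, --img_sz, --resize) are consumed by the model- and datamodule-specific argument parsers defined in the two files below.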
project("tusuf", "SupCon_loss-project").artifactVersion("source-2v8n2uqy", "c90cdafa0e25ac8245e2").file("cifar_supcon.py")
# SupCon loss is a supervised contrastive learning loss, i.e. it needs labels to learn, unlike SimSiam and SimCLR.
#%%
from argparse import ArgumentParser
from torch.optim.lr_scheduler import MultiStepLR
import torch
import pytorch_lightning as pl
from torch.nn import functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
# from torch.utils.data import DataLoader, random_split
from typing import List
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets.mnist import MNIST
from torchvision import transforms
import timm
from cifar_dm import CIFAR_DataModule
from pytorch_metric_learning.losses import SupConLoss
from pytorch_lightning.callbacks import LearningRateMonitor
import wandb
# wandb.login()
class LitClassifier(pl.LightningModule):
    def __init__(self, embed_sz: int, steps: List, lr: float = 1e-3, gamma: float = 0.1, **kwargs):
        super().__init__()
        self.embed_sz = embed_sz
        self.lr = lr
        self.gamma = gamma
        self.steps = [int(k) for k in steps]
        # define the backbone network
        self.backbone = timm.create_model('mnasnet_100', pretrained=True)
        # put backbone in train mode
        self.backbone.train()
        # replace the classifier with an identity and project features to the embedding size
        in_features = self.backbone.classifier.in_features
        self.project = torch.nn.Linear(in_features, self.embed_sz)
        self.backbone.classifier = torch.nn.Identity()
        self.supcon_head = SupConLoss(temperature=0.1)
        # self.activation = torch.nn.LeakyReLU(negative_slope=0.1)
        self.save_hyperparameters()

    def forward(self, x):
        x = self.backbone(x)
        x = self.project(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        # print(x.shape)
        # print(y.shape)
        embeddings = self(x)
        loss = self.supcon_head(embeddings, y)
        # for logging to the loggers
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        embeds = self(x)
        loss = self.supcon_head(embeds, y)
        # for logging to the loggers
        self.log('val_loss', loss)
        return {"loss": loss}

    def test_step(self, batch, batch_idx):
        x, y = batch
        embeds = self(x)
        loss = self.supcon_head(embeds, y)
        self.log('test_loss', loss)
        return {"test_loss": loss}

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        # sched = MultiStepLR(opt, milestones=self.hparams.steps, gamma=self.hparams.gamma)
        # lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        #     opt, self.hparams.max_epochs, eta_min=0
        # )
        # cosine annealing with warm restarts; first restart after 50 epochs (T_0=50)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, 50, T_mult=1, eta_min=0, last_epoch=-1, verbose=False)
        return [opt], [lr_scheduler]
        # return {"optimizer": opt,
        #         "lr_scheduler": {
        #             # REQUIRED: The scheduler instance
        #             "scheduler": sched,
        #             # The unit of the scheduler's step size, could also be 'step'.
        #             # 'epoch' updates the scheduler on epoch end whereas 'step'
        #             # updates it after an optimizer update.
        #             "interval": "epoch",
        #             # How many epochs/steps should pass between calls to
        #             # `scheduler.step()`. 1 corresponds to updating the learning
        #             # rate after every epoch/step.
        #             "frequency": 1,
        #         }}

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--embed_sz', type=int, default=128)
        parser.add_argument('--lr', type=float, default=0.0001)
        parser.add_argument('--gamma', type=float, default=0.1)
        parser.add_argument('--steps', nargs='+', required=True)
        return parser
def cli_main():
    pl.seed_everything(1000)
    wandb.init()
    wandb.run.log_code(".")  # all python files uploaded
    wandb.login()
    checkpoint_callback = ModelCheckpoint(filename="checkpoints/cifar10-{epoch:02d}-{val_loss:.6f}",
                                          monitor='val_loss', mode='min')
    lr_callback = LearningRateMonitor(logging_interval="step")
    wandb_logger = WandbLogger(project='CIFAR-10',  # group runs in the "CIFAR-10" project
                               log_model='all',     # log all new checkpoints during training
                               name="supcon loss")
    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    # trainer CLI args added
    parser = pl.Trainer.add_argparse_args(parser)
    # model specific args
    parser = LitClassifier.add_model_specific_args(parser)
    # dataset specific args
    parser = CIFAR_DataModule.add_model_specific_args(parser)
    args = parser.parse_args()
    # ------------
    # data
    # ------------
    # dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    # mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    # mnist_train, mnist_val = random_split(dataset, [55000, 5000])
    # train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    # val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    # test_loader = DataLoader(mnist_test, batch_size=args.batch_size)
    dm = CIFAR_DataModule(**vars(args))  # vars converts Namespace --> dict, ** converts to kwargs
    # ------------
    # model
    # ------------
    model = LitClassifier(**vars(args))
    wandb_logger.watch(model)
    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.callbacks.append(checkpoint_callback)
    trainer.callbacks.append(lr_callback)
    trainer.logger = wandb_logger
    # trainer.tune() runs the batch-size finder when --auto_scale_batch_size is set
    trainer.tune(model, dm)
    # log args to wandb
    args.batch_size = model.hparams.get('batch_size')
    # dm.hparams.batch_size = args.batch_size
    dm.hparams.batch_size = 512
    print(f"\n\n batch size -----> {args.batch_size}\n\n")
    wandb.config.update(vars(args))
    trainer.fit(model, dm)
    # ------------
    # testing
    # ------------
    # trainer.test()


if __name__ == '__main__':
    cli_main()
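Once training has produced a checkpoint, the projection head can be reused to embed new images. A rough sketch (the checkpoint path is hypothetical; LitClassifier is the class defined above):

import torch
from cifar_supcon import LitClassifier

model = LitClassifier.load_from_checkpoint("path/to/cifar10-checkpoint.ckpt")  # hypothetical path
model.eval()

with torch.no_grad():
    x = torch.randn(4, 3, 224, 224)   # dummy batch shaped like the CIFAR pipeline output
    z = model(x)                      # MnasNet features projected to the embedding size
print(z.shape)                        # torch.Size([4, 128]) for --embed_sz 128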
project("tusuf", "SupCon_loss-project").artifactVersion("source-2v8n2uqy", "c90cdafa0e25ac8245e2").file("cifar_dm.py")
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
from argparse import ArgumentParser
# Note - you must have torchvision installed for this example
from torchvision.datasets import CIFAR10
from torchvision import transforms
from typing import Optional
from PIL import Image
from pathlib import Path
# ========================================================================
# timm imports
# ========================================================================
# import urllib
# from PIL import Image
# from timm.data import resolve_data_config
# from timm.data.transforms_factory import create_transform
# import timm
# model = timm.create_model('mnasnet_100', pretrained=True)
# model.eval()
# config = resolve_data_config({}, model=model)
# transform = create_transform(**config)
# print(transform)
# url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
# urllib.request.urlretrieve(url, filename)
# img = Image.open(filename).convert('RGB')
# tensor = transform(img).unsqueeze(0) # transform and add batch dimension
# print(tensor.shape)
class CIFAR_DataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str, img_sz: int, resize: int, batch_size: int, *args, **kwargs):
        super().__init__()
        self.data_dir = data_dir
        self.img_sz = img_sz
        self.resize = resize
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.Resize(size=self.resize, interpolation=Image.BICUBIC),
            transforms.CenterCrop(size=(self.img_sz, self.img_sz)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250])
            # transforms.Normalize(mean=[0.4850,], std=[0.2290,])
        ])
        # create an empty dir. if it does not exist
        Path(self.data_dir).mkdir(parents=True, exist_ok=True)
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        # self.dims = (1, 28, 28)
        # try updating the lightning version to see if it works
        self.save_hyperparameters()
        print("==================================")
        print(self.hparams)
        print("==================================")

    def prepare_data(self):
        # download
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        # (the mnist_* names are leftovers from the Lightning MNIST template; the data is CIFAR-10)
        if stage == "fit" or stage is None:
            mnist_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            print(f"\n\n dataset size --> {len(mnist_full)} \n\n")
            self.mnist_train, self.mnist_val = random_split(mnist_full, [45000, 5000])
            # Optionally...
            self.dims = tuple(self.mnist_train[0][0].shape)  # X, Y ===> X.shape
        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = CIFAR10(self.data_dir, train=False, transform=self.transform)
            # Optionally...
            # self.dims = tuple(self.mnist_test[0][0].shape)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.hparams.batch_size, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.hparams.batch_size)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--data_dir', type=str, default='./dataset')
        parser.add_argument('--batch_size', type=int, default=8)
        parser.add_argument('--img_sz', type=int, default=224, help='size of image')
        parser.add_argument('--resize', type=int, default=250,
                            help='resize the image to this size, after which a center crop is taken at --img_sz')
        return parser
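To sanity-check the data pipeline on its own, the datamodule can be driven directly; a minimal sketch using the same argument values as the command above:

from cifar_dm import CIFAR_DataModule

dm = CIFAR_DataModule(data_dir="dataset", img_sz=224, resize=250, batch_size=8)
dm.prepare_data()                 # downloads CIFAR-10 into ./dataset if needed
dm.setup("fit")                   # builds the 45k/5k train/val split
x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)           # torch.Size([8, 3, 224, 224]) torch.Size([8])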
Run set: 17 runs
https://wandb.ai/tusuf/SupCon_loss-project/reports/Supcon-Loss-results--VmlldzoxNDk0MDYw