
JoJoGAN: One Shot Face Stylization with W&B and Gradio

This report showcases JoJoGAN: One Shot Face Stylization, which fine-tunes a pretrained StyleGAN to map real faces to stylized art. We track experiments with W&B and demo the model live with Gradio. Try it right in your browser!

Introduction

In this report, we'll walk you through how to use W&B and Gradio together, complete with code, tables, training runs, and of course, a slick Gradio embed on JoJoGAN.
If you're unfamiliar with any of the above, we've got you covered:
  • Weights & Biases (W&B) lets ML practitioners track their experiments at every stage, from training to production.
  • Gradio is the fastest way to demo your machine learning model with a friendly web interface so that anyone can use it, anywhere! (A minimal sketch of both APIs follows this list.)
  • JoJoGAN is a recently published one-shot face stylization model. You can read the paper by clicking the link or you can try out the Gradio embed below!
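If you've never touched either library, the core APIs are tiny. Here's a minimal, illustrative sketch (the project name, metric, and greeting function are placeholders, not part of the JoJoGAN code):
import wandb
import gradio as gr

# Track a toy experiment with W&B (project name is a placeholder)
run = wandb.init(project="my-first-project")
for step in range(10):
    wandb.log({"loss": 1.0 / (step + 1)}, step=step)
run.finish()

# Turn any Python function into a shareable web demo with Gradio
def greet(name):
    return f"Hello, {name}!"

gr.Interface(fn=greet, inputs="text", outputs="text").launch()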

[Embedded W&B panel: run set (1 run)]


Now, let's walk through how to do this on your own. For the purposes of this tutorial, we'll assume you're new to both W&B and Gradio.
Let's get started!

1. Create a W&B account

Follow these quick instructions to create your free account if you don't have one already. It shouldn't take more than a couple of minutes. Once you're done (or if you already have an account), we'll move on to a quick Colab.

2. Open the Colab and Install Gradio and W&B

We'll be following along with the Colab provided in the JoJoGAN repo, with some minor modifications to use W&B and Gradio more effectively.



Install Gradio and W&B at the top of the notebook:
!pip install gradio wandb
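Then import both libraries and authenticate with W&B (a minimal sketch; the login call will prompt for the API key from your account settings):
import wandb
import gradio as gr

# Authenticate this Colab session with your W&B account
wandb.login()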

3. Follow the instructions in the Colab to set up and try out a pretrained model

The code is as follows:
plt.rcParams['figure.dpi'] = 150
pretrained = 'arcane_multi' #@param ['art', 'arcane_multi', 'supergirl', 'arcane_jinx', 'arcane_caitlyn', 'jojo_yasuho', 'jojo', 'disney']
#@markdown Preserve color tries to preserve color of original image by limiting family of allowable transformations. Otherwise, the stylized image will inherit the colors of the reference images, leading to heavier stylizations.
preserve_color = False #@param{type:"boolean"}

if preserve_color:
    ckpt = f'{pretrained}_preserve_color.pt'
else:
    ckpt = f'{pretrained}.pt'

# load base version if preserve_color version not available
try:
    downloader.download_file(ckpt)
except:
    ckpt = f'{pretrained}.pt'
    downloader.download_file(ckpt)

ckpt = torch.load(os.path.join('models', ckpt), map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)

#@title Generate results
n_sample = 5#@param {type:"number"}
seed = 3000 #@param {type:"number"}

torch.manual_seed(seed)
with torch.no_grad():
    generator.eval()
    z = torch.randn(n_sample, latent_dim, device=device)

    original_sample = original_generator([z], truncation=0.7, truncation_latent=mean_latent)
    sample = generator([z], truncation=0.7, truncation_latent=mean_latent)

    original_my_sample = original_generator(my_w, input_is_latent=True)
    my_sample = generator(my_w, input_is_latent=True)

# display reference images
if pretrained == 'arcane_multi':
    style_path = f'style_images_aligned/arcane_jinx.png'
else:
    style_path = f'style_images_aligned/{pretrained}.png'
style_image = transform(Image.open(style_path)).unsqueeze(0).to(device)
face = transform(aligned_face).unsqueeze(0).to(device)

my_output = torch.cat([style_image, face, my_sample], 0)
display_image(utils.make_grid(my_output, normalize=True, range=(-1, 1)), title='My sample')

output = torch.cat([original_sample, sample], 0)
display_image(utils.make_grid(output, normalize=True, range=(-1, 1), nrow=n_sample), title='Random samples')

4. Add style images for fine-tuning

Next, you'll upload some of your own style images. Upload them in the Colab and list their file names as shown below:
#@markdown Upload your own style images into the style_images folder and type it into the field in the following format without the directory name. Upload multiple style images to do multi-shot image translation
names = ['arcane_caitlyn.jpeg', 'arcane_jinx.jpeg', 'arcane_jayce.jpeg', 'arcane_viktor.jpeg'] #@param {type:"raw"}

targets = []
latents = []

for name in names:
    style_path = os.path.join('style_images', name)
    assert os.path.exists(style_path), f"{style_path} does not exist!"

    name = strip_path_extension(name)

    # crop and align the face
    style_aligned_path = os.path.join('style_images_aligned', f'{name}.png')
    if not os.path.exists(style_aligned_path):
        style_aligned = align_face(style_path)
        style_aligned.save(style_aligned_path)
    else:
        style_aligned = Image.open(style_aligned_path).convert('RGB')

    # GAN invert
    style_code_path = os.path.join('inversion_codes', f'{name}.pt')
    if not os.path.exists(style_code_path):
        latent = e4e_projection(style_aligned, style_code_path, device)
    else:
        latent = torch.load(style_code_path)['latent']

    targets.append(transform(style_aligned).to(device))
    latents.append(latent.to(device))

targets = torch.stack(targets, 0)
latents = torch.stack(latents, 0)

target_im = utils.make_grid(targets, normalize=True, range=(-1, 1))
display_image(target_im, title='Style References')


5. Finetune StyleGAN with W&B experiment tracking

This next step opens a W&B dashboard to track your experiments, along with a Gradio panel, embedded from a Gradio demo hosted on Hugging Face Spaces, that lets you choose a pretrained model from a drop-down menu.
#@title Finetune StyleGAN
#@markdown alpha controls the strength of the style
alpha = 1.0 #@param {type:"slider", min:0, max:1, step:0.1}
alpha = 1-alpha

#@markdown Tries to preserve color of original image by limiting family of allowable transformations. Set to false if you want to transfer color from reference image. This also leads to heavier stylization
preserve_color = True #@param{type:"boolean"}
#@markdown Number of finetuning steps. Different style reference may require different iterations. Try 200~500 iterations.
num_iter = 200 #@param {type:"number"}
#@markdown Log training on wandb and interval for image logging
use_wandb = True #@param {type:"boolean"}
log_interval = 50 #@param {type:"number"}

samples = []
column_names = ["Reference (y)", "Style Code (w)", "Real Face Image (x)"]

if use_wandb:
    wandb.init(project="JoJoGAN")
    config = wandb.config
    config.num_iter = num_iter
    config.preserve_color = preserve_color
    wandb.log(
        {"Style reference": [wandb.Image(transforms.ToPILImage()(target_im))]},
        step=0)
    wandb.log({"Gradio panel": wandb.Html('''
<link rel="stylesheet" href="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.css">
<div id="JoJoGAN-demo"></div>
<script src="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.js"></script>
<script>
launchGradioFromSpaces("akhaliq/JoJoGAN", "#JoJoGAN-demo")
</script>
<style>
/* work around a weird bug */
.gradio_app .gradio_bg[theme=huggingface] .gradio_interface .input_dropdown .dropdown:hover .dropdown_menu {
    display: block;
}
</style>
''', inject=False)})

lpips_fn = lpips.LPIPS(net='vgg').to(device)

# reset generator
del generator
generator = deepcopy(original_generator)

g_optim = optim.Adam(generator.parameters(), lr=2e-3, betas=(0, 0.99))

# Which layers to swap for generating a family of plausible real images -> fake image
if preserve_color:
    id_swap = [7,9,11,15,16,17]
else:
    id_swap = list(range(7, generator.n_latent))

for idx in tqdm(range(num_iter)):
    if preserve_color:
        random_alpha = 0
    else:
        random_alpha = np.random.uniform(alpha, 1)
    mean_w = generator.get_latent(torch.randn([latents.size(0), latent_dim]).to(device)).unsqueeze(1).repeat(1, generator.n_latent, 1)
    in_latent = latents.clone()
    in_latent[:, id_swap] = alpha*latents[:, id_swap] + (1-alpha)*mean_w[:, id_swap]

    img = generator(in_latent, input_is_latent=True)
    loss = lpips_fn(F.interpolate(img, size=(256,256), mode='area'), F.interpolate(targets, size=(256,256), mode='area')).mean()
    if use_wandb:
        wandb.log({"loss": loss}, step=idx)
        if idx % log_interval == 0:
            generator.eval()
            my_sample = generator(my_w, input_is_latent=True)
            generator.train()
            my_sample = transforms.ToPILImage()(utils.make_grid(my_sample, normalize=True, range=(-1, 1)))
            wandb.log(
                {"Current stylization": [wandb.Image(my_sample)]},
                step=idx)
            table_data = [
                wandb.Image(transforms.ToPILImage()(target_im)),
                wandb.Image(img),
                wandb.Image(my_sample),
            ]
            samples.append(table_data)

    g_optim.zero_grad()
    loss.backward()
    g_optim.step()

out_table = wandb.Table(data=samples, columns=column_names)
wandb.log({"Current Samples": out_table})
Using launchGradioFromSpaces, anyone can embed Gradio demos hosted on Hugging Face Spaces directly into their blogs, websites, documentation, and more:
launchGradioFromSpaces("akhaliq/JoJoGAN", "#JoJoGAN-demo")
Meanwhile, adding a Gradio Demo to a W&B Report takes just a few extra lines of code:
wandb.log({"Gradio panel": wandb.Html('''
<link rel="stylesheet" href="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.css">
<div id="JoJoGAN-demo"></div>
<script src="https://gradio.s3-us-west-2.amazonaws.com/2.6.2/static/bundle.js"></script>
<script>
launchGradioFromSpaces("akhaliq/JoJoGAN", "#JoJoGAN-demo")
</script>
<style>
/* work around a weird bug */
.gradio_app .gradio_bg[theme=huggingface] .gradio_interface .input_dropdown .dropdown:hover .dropdown_menu {
display: block;
}
</style>
''', inject=False)})

[Embedded W&B panel: run set (38 runs)]

Lastly, here's how to save, download, and load your model (and Gradio demo):

6. Save and Download Model

torch.save({"g": generator.state_dict()}, "your-model-name.pt")

from google.colab import files
files.download('your-model-name.pt')
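Optionally, if your W&B run is still active, you can also version the checkpoint as a W&B Artifact so it stays attached to your experiment. A minimal sketch (the artifact name is a placeholder):
# Optional: log the fine-tuned checkpoint as a W&B Artifact
artifact = wandb.Artifact("jojogan-generator", type="model")
artifact.add_file("your-model-name.pt")
wandb.log_artifact(artifact)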

7. Load Model and Gradio Demo

ckpt = torch.load('your-model-name.pt', map_location=lambda storage, loc: storage)
generator.load_state_dict(ckpt["g"], strict=False)
import gradio as gr

title = "JoJoGAN"
description = "Gradio Demo for JoJoGAN: One Shot Face Stylization. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."

article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2112.11641' target='_blank'>JoJoGAN: One Shot Face Stylization</a>| <a href='https://github.com/mchong6/JoJoGAN' target='_blank'>Github Repo Pytorch</a></p> <center><img src='https://visitor-badge.glitch.me/badge?page_id=akhaliq_jojogan' alt='visitor badge'></center>"

examples=[['mona.png','Jinx']]
gr.Interface(
    inference,
    [
        gr.inputs.Image(type="pil"),
        gr.inputs.Dropdown(choices=['JoJo', 'Disney', 'Jinx', 'Caitlyn', 'Yasuho', 'Arcane Multi', 'Art', 'Spider-Verse'], type="value", default='JoJo', label="Model"),
    ],
    gr.outputs.Image(type="file"),
    title=title,
    description=description,
    article=article,
    allow_flagging=False,
    examples=examples,
    allow_screenshot=False,
    enable_queue=True,
).launch()
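Note that the interface above calls an inference function that isn't defined in this snippet. Here's a minimal sketch of what it could look like, assuming the align_face, e4e_projection, transforms, utils, and device objects from the Colab are in scope, plus a hypothetical generators dict mapping style names to fine-tuned generators:
# Hypothetical inference function for the Gradio demo above
def inference(img, model):
    img.save('input.jpg')                             # align_face expects a file path in the colab
    aligned = align_face('input.jpg')                 # crop and align the uploaded face
    w = e4e_projection(aligned, 'input.pt', device)   # GAN-invert to a StyleGAN latent code
    generator = generators[model]                     # pick the generator fine-tuned for the chosen style
    with torch.no_grad():
        generator.eval()
        stylized = generator(w.unsqueeze(0), input_is_latent=True)
    grid = utils.make_grid(stylized, normalize=True, range=(-1, 1))
    transforms.ToPILImage()(grid).save('output.png')  # type="file" output expects a path
    return 'output.png'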


Conclusion

We hope you enjoyed this brief demo of embedding a Gradio demo in a W&B report! Thanks for making it to the end. To recap:
  • Only a single reference image is needed to fine-tune JoJoGAN, which usually takes about a minute on a Colab GPU. After training, the style can be applied to any input image. Read more in the paper.
  • W&B tracks experiments with just a few lines of code added to a Colab, and you can visualize, sort, and understand your experiments in a single, centralized dashboard.
  • Gradio, meanwhile, demos the model in a user-friendly interface that can be shared anywhere on the web.
If you have any interesting Gradio demos to share, drop a link in the comments! We'd love to check them out!
