Fine-Tuning Stable Diffusion Using Dreambooth in Keras
In this article, we quickly teach Stable Diffusion new visual concepts using Dreambooth in Keras, to produce fully novel photorealistic images of a given subject.
Created on February 24|Last edited on March 17
Large text-to-image models, like Stable Diffusion, DALL-E 2, and Midjourney, have revolutionized machine learning, enabling high-quality and diverse synthesis of images from a given text prompt. However, it is difficult for such models to mimic the appearance of subjects in a given reference set and synthesize novel renditions of them in different contexts.
In this article, we discuss Dreambooth, an approach for personalizing text-to-image diffusion models (specializing them to users' needs). Using this approach, we can fine-tune a pre-trained text-to-image model (such as Stable Diffusion) with just a few input images of a subject so that it learns to bind a unique identifier to that specific subject. Once the subject is embedded in the output domain of the model, the unique identifier can be used to synthesize fully novel photorealistic images of the subject contextualized in different scenes! 😱
Using Dreambooth, you can personalize Stable Diffusion on novel visual concepts!!!
What Does Dreambooth Do?
With a few images (usually 3-5) of a subject or novel visual concept, DreamBooth can generate a myriad of images in different contexts using a text prompt. Our objective is to implant the subject into the output domain of Stable Diffusion such that it can be synthesized with a unique identifier.
Dreambooth helps expand the language-vision dictionary of Stable Diffusion so that it binds new words to a specific unique class or subject the user wants to generate.
Dreambooth uses a technique called prior preservation for fine-tuning text-to-image diffusion models in a few-shot setting while preserving the model’s semantic knowledge of the class of the subject.
Once the new dictionary is embedded in the Dreamboothed model, it can use these words to synthesize novel photorealistic images of the subject, contextualized in different scenes, while preserving its key identifying features.
The effect is similar to a magic photo booth; once a few images of the subject are taken, the booth generates photos of the subject in different conditions and scenes, guided by simple and intuitive text prompts.
Note that a unique class is the new word corresponding to the new visual concept we wish to introduce into Stable Diffusion's language-vision dictionary. A unique identifier is a string that is prepended to the unique class to uniquely identify it in the instance prompts. For example, in the generated images above, gkrxt and sks are the unique identifiers, while mom and monkey are their respective classes.
Using KerasCV Dreambooth
In this report, we explore the Keras implementation of Dreambooth developed by Sayak Paul and Chansung Park for fine-tuning Stable Diffusion. Dreambooth-Keras is a modular and user-friendly implementation of Dreambooth that uses the pre-trained Stable Diffusion checkpoints from KerasCV, and it comes with the goodness of Weights & Biases built in!
Steps for using Dreambooth-Keras
First, clone the repository and install the prerequisites using pip install -r requirements.txt.
For the scope of this report, we use soumik12345/dreambooth-keras, a fork of the original dreambooth-keras, which can be cloned and installed using pip install -e dreambooth-keras. We wholeheartedly recommend that you ⭐️ the original repository.
You can run the following Colab notebook to fine-tune Stable Diffusion on your own images using Dreambooth-Keras.
Then, we need to choose a class and a unique identifier to prepend to it. For example, if you use sks as the unique identifier and dog as the class, two types of prompts are generated:
- Instance prompt: "a photo of {unique_id} {class_category}".
- Class prompt: "a photo of {class_category}".
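The two prompt templates above amount to simple string formatting. A minimal sketch, with build_prompts as a hypothetical helper name (it is not part of dreambooth-keras):

```python
def build_prompts(unique_id: str, class_category: str) -> tuple[str, str]:
    """Return (instance_prompt, class_prompt) following the two templates."""
    # The instance prompt binds the unique identifier to the class word.
    instance_prompt = f"a photo of {unique_id} {class_category}"
    # The class prompt describes a generic member of the same class.
    class_prompt = f"a photo of {class_category}"
    return instance_prompt, class_prompt


instance_prompt, class_prompt = build_prompts("sks", "dog")
print(instance_prompt)  # a photo of sks dog
print(class_prompt)     # a photo of dog
```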
Next, we need to collect a few images (ideally 20-25) called instance images that are representative of the concept the model is going to be fine-tuned with. These images are then associated with the respective instance prompt.
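Associating instance images with the instance prompt can be sketched as pairing every image file in a directory with the same prompt. This is a minimal illustration assuming a flat local directory and the file extensions listed in the code; pair_images_with_prompt is a hypothetical name, not the repository's data pipeline. If your images are hosted as a W&B artifact, wandb.Api().artifact(...).download() returns the local directory they were downloaded to, which could be passed in here.

```python
from pathlib import Path

# Extensions treated as images (an assumption; adjust to your dataset).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}


def pair_images_with_prompt(image_dir, prompt):
    """Return a sorted list of (image_path, prompt) pairs for images in image_dir."""
    paths = sorted(
        p for p in Path(image_dir).iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )
    # Every instance image shares the same instance prompt.
    return [(str(p), prompt) for p in paths]


# Example (hypothetical directory name):
# pairs = pair_images_with_prompt("instance-images", "a photo of sks dog")
```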
You can host your instance image datasets as online archives or as Weights & Biases artifacts. The following Weave panel shows the Artifacts hosting the instance images.
Artifacts hosting instance images
Next, we need a dataset of class images. DreamBooth uses a prior-preservation loss to regularize training, which helps the model adapt slowly to the new visual concept, starting from whatever prior knowledge it already has about the class. To use the prior-preservation loss, we need the class prompt mentioned previously: it is used to generate a predefined number of class images, which are then used in computing the final DreamBooth training loss.
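In the DreamBooth paper, the training objective is a denoising loss on the instance images plus a class-specific prior-preservation term computed on the class images, weighted by a factor λ. A minimal NumPy sketch, assuming mean squared error on predicted versus actual noise (as in typical Stable Diffusion fine-tuning) and a hypothetical prior_loss_weight argument standing in for λ; it is not taken from dreambooth-keras:

```python
import numpy as np


def dreambooth_loss(instance_pred, instance_target, class_pred, class_target,
                    prior_loss_weight=1.0):
    """Instance denoising loss plus a weighted prior-preservation loss.

    Each *_pred / *_target pair is, for example, the predicted vs. actual
    noise for a batch of latents. Both terms are plain mean squared errors.
    """
    instance_loss = np.mean((np.asarray(instance_pred) - np.asarray(instance_target)) ** 2)
    # Prior-preservation term: keeps predictions on generic class images accurate.
    prior_loss = np.mean((np.asarray(class_pred) - np.asarray(class_target)) ** 2)
    return instance_loss + prior_loss_weight * prior_loss
```

With prior_loss_weight=0 the class term vanishes and this reduces to ordinary fine-tuning on the instance images alone.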
You can use the following Colab notebook to generate your own priors and host them as Weights & Biases artifacts.
Artifacts hosting class images
In the context of Dreambooth, class images represent the broader visual concept that the personalized subject belongs to, while instance images represent the specific subject itself. For this reason, you can use either a set of images generated by Stable Diffusion itself as class images or a dataset that contains similar real-world images. For example, while fine-tuning on a specific species of monkey, you can either generate 200-300 images with Stable Diffusion using the prompt "a photo of monkey" or use a dataset like 10 Monkey Species.
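Generating a fixed number of class images from the class prompt can be sketched as repeated batched calls to a text-to-image function. Here the generator is passed in as generate_fn so the logic stays independent of any particular model; the commented usage assumes KerasCV's StableDiffusion.text_to_image, and generate_class_images itself is a hypothetical helper:

```python
import math

import numpy as np


def generate_class_images(generate_fn, class_prompt, num_images, batch_size=4):
    """Generate exactly num_images class images by calling generate_fn in batches.

    generate_fn(prompt, batch_size=...) is assumed to return an array of shape
    (batch_size, H, W, 3); the final batch is trimmed to num_images in total.
    """
    if num_images <= 0 or batch_size <= 0:
        raise ValueError("num_images and batch_size must be positive")
    num_batches = math.ceil(num_images / batch_size)
    batches = [generate_fn(class_prompt, batch_size=batch_size) for _ in range(num_batches)]
    return np.concatenate(batches, axis=0)[:num_images]


# Possible usage with KerasCV (downloads weights; a GPU is strongly recommended):
# import keras_cv
# model = keras_cv.models.StableDiffusion(img_width=512, img_height=512)
# class_images = generate_class_images(model.text_to_image, "a photo of monkey", 200)
```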
Now that we have the instance and class images, we can launch training using the training script train_dreambooth.py. The training script comes with lots of options for you to try:
- You can launch mixed-precision training using python train_dreambooth.py --mp.
- You can fine-tune the text encoder by specifying the --train_text_encoder option.
- Additionally, the script comes integrated with Weights & Biases via the --log_wandb option:
  - using this flag automatically logs the training metrics to your W&B dashboard using the WandbMetricsLogger callback.
  - it also uploads your model checkpoints to your W&B project as an artifact at the end of each epoch for model versioning. This is done using the DreamBoothCheckpointCallback, which was built on the WandbModelCheckpoint callback.
  - it also performs inference with the Dreamboothed model parameters at the end of each epoch and logs the generated images to a W&B Table in your W&B dashboard. This is done using the QualitativeValidationCallback, which also logs the generated images to a media panel on your W&B dashboard at the end of training.
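The end-of-epoch inference performed by QualitativeValidationCallback follows the standard Keras callback pattern. The sketch below is hypothetical and dependency-free: QualitativeValidationSketch is our own name, it only mirrors the on_epoch_end(epoch, logs=None) hook of keras.callbacks.Callback, and it collects rows in a list where the real callback logs to W&B. In practice you would subclass keras.callbacks.Callback and add each row to a wandb.Table with add_data.

```python
class QualitativeValidationSketch:
    """Dependency-free stand-in for an end-of-epoch validation callback."""

    def __init__(self, generate_fn, prompts):
        # generate_fn maps a prompt to an image (an assumption for this sketch).
        self.generate_fn = generate_fn
        self.prompts = list(prompts)
        # Collected (epoch, prompt, image) rows; a real callback would log these to W&B.
        self.rows = []

    def on_epoch_end(self, epoch, logs=None):
        # Run inference on every validation prompt with the current weights.
        for prompt in self.prompts:
            image = self.generate_fn(prompt)
            self.rows.append((epoch, prompt, image))
```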
Here's a command that launches training and logs training metrics and generated images to your Weights & Biases workspace:
python train_dreambooth.py \
  --log_wandb \
  --validation_prompts \
    "a photo of sks dog with a cat" \
    "a photo of sks dog riding a bicycle" \
    "a photo of sks dog peeing" \
    "a photo of sks dog playing cricket" \
    "a photo of sks dog as an astronaut"
Additionally, you can store the instance and class image datasets as Artifacts to version them and track the lineage of your workflow. You can then pass the artifact addresses of your datasets to the corresponding flags, as in the following example:
python train_dreambooth.py \
  --instance_images_url "geekyrakshit/dreambooth-keras/monkey-instance-images:v0" \
  --class_images_url "geekyrakshit/dreambooth-keras/monkey-class-images:v0" \
  --class_category "monkey" \
  --mp \
  --log_wandb \
  --lr 5e-06 \
  --max_train_steps 2000 \
  --validation_prompts \
    "a photo of sks monkey with a cat" \
    "a photo of sks monkey riding a bicycle" \
    "a photo of sks monkey as an astronaut" \
    "a photo of sks monkey in front of the taj mahal" \
    "a photo of sks monkey wearing sunglasses and drinking beer"
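For illustration, here is a hypothetical argparse parser that mirrors the flags used in the two commands above. It is not the actual parser in train_dreambooth.py, and the argument types are our assumptions. It mainly shows why each validation prompt is a separate quoted argument: with nargs="+", a single --validation_prompts flag collects all of the strings that follow it.

```python
import argparse


def build_parser():
    """Hypothetical re-creation of the training CLI; types are assumptions."""
    parser = argparse.ArgumentParser(description="DreamBooth fine-tuning (sketch)")
    parser.add_argument("--instance_images_url", type=str)
    parser.add_argument("--class_images_url", type=str)
    parser.add_argument("--class_category", type=str)
    parser.add_argument("--lr", type=float)
    parser.add_argument("--max_train_steps", type=int)
    # Boolean switches: present on the command line means True.
    parser.add_argument("--mp", action="store_true", help="mixed-precision training")
    parser.add_argument("--train_text_encoder", action="store_true")
    parser.add_argument("--log_wandb", action="store_true")
    # nargs="+" lets one flag collect several quoted prompts.
    parser.add_argument("--validation_prompts", type=str, nargs="+")
    return parser
```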
Experiments With Dreambooth
Let's now look at some attempts at fine-tuning Stable Diffusion using Dreambooth. All the fine-tuning experiments were performed on a single NVIDIA Ampere A100 GPU with a batch size of 1 and an image resolution of 512x512.
Fine-tuning On Monkeys
Fine-tuning Stable Diffusion on Images of Monkeys
Fine-Tuning On Photos of My Mom
Fine-tuning Stable Diffusion on images of my Mom!
Conclusion
- In this article, we explored Dreambooth, an approach for fine-tuning text-to-image diffusion models on new visual concepts hitherto undreamt of by the model.
- We explored how we can fine-tune the Keras implementation of Stable Diffusion using Dreambooth-Keras developed by Sayak Paul and Chansung Park.
- We walked through the steps to fine-tune Stable Diffusion using Dreambooth-Keras and analyzed the results on Weights & Biases.
- We also explored how we can perform inference on the Dreamboothed checkpoints using our own prompts and log the results on Weights & Biases.
- You can refer to this article by the developers of Dreambooth Keras that explores the implementation details.
Tags: Keras, Articles, Stable Diffusion, GenAI, Image Generation, Has Colab, Experiment, Fine-tuning