
Flyte and Weights & Biases integration

Discover how Flyte’s latest plugin for Weights & Biases allows you to run and integrate ML/AI workflows on Union seamlessly.
Created on June 27 | Last edited on July 2
With Flyte’s latest plugin for Weights & Biases, you can now run machine learning or AI workflows on Union and integrate them with Weights & Biases capabilities. Union provides scalability, declarative infrastructure, and data lineage, allowing you to quickly iterate on and productionize AI or ML workflows. Weights & Biases helps customers build models faster, fine-tune LLMs, and develop GenAI applications with confidence, all in one system of record. In this blog post, we dive deep into how the plugin works.
"This integration is particularly valuable to us as in our ecosystem many members who jointly work on large-scale ML projects can now have much broader traceability of dataset curation, training and evaluation all in one place. This allows our teams to progress faster by using a single source of truth to keep up with relevant developments from other members and pin point issues easier along the way."
- Alborz Alavian, Sr. Engineer Manager at Woven by Toyota

Flytekit’s Weights & Biases plugin

Union considers both data and compute to be fundamental building blocks. You can train models using machine learning or AI libraries such as XGBoost or PyTorch and track those models with Union artifacts. Union's reactive workflows are triggered when the underlying data changes and scale up to train many models.
In this initial example, flytekit's wandb_init configures the run in Weights & Biases and the XGBoost callback automatically tracks the model's progress. Once the function is decorated, its body is ordinary code of the kind you'll find in the Weights & Biases documentation:
from flytekit import Secret, task
from flytekitplugins.wandb import wandb_init
import pandas as pd
from xgboost import XGBClassifier

# Secret that holds the Weights & Biases API key
wandb_secret = Secret(key="wandb-api-key")

# image, WANDB_PROJECT, and WANDB_ENTITY are defined elsewhere (not shown)
@task(container_image=image, secret_requests=[wandb_secret])
@wandb_init(
    project=WANDB_PROJECT, entity=WANDB_ENTITY, secret=wandb_secret,
)
def train(data: pd.DataFrame) -> float:
    # Normal usage of wandb
    from wandb.integration.xgboost import WandbCallback
    import wandb

    bst = XGBClassifier(..., callbacks=[WandbCallback(log_model=True)])
    # ... fit bst and compute test_score (not shown)

    wandb.run.log({"test_score": test_score})
    return test_score

The wandb_secret Secret object refers to a Weights & Biases API key, which was created with Union’s CLI: unionai create secret wandb-api-key. The wandb_init decorator starts the run and configures Union's UI to show a link to the run.

Clicking the link takes us to Weights & Biases, which shows all the tracking information about our model training execution. On Weights & Biases, the Flyte execution is linked back in the run’s description.
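To make the decorator's role more concrete, here is a minimal sketch in plain Python of how a decorator in the style of wandb_init might wrap a function: it starts a run before the body executes, lets the body log to it, and finishes the run afterwards. This is not the plugin's actual implementation. FakeRun, tracked_run, current_run, and completed_runs are hypothetical names chosen so the example runs without a Weights & Biases account.

```python
import functools
from typing import Any, Callable, Dict, List, Optional


class FakeRun:
    """Hypothetical stand-in for a tracking run; this is not the wandb API."""

    def __init__(self, project: str, entity: str) -> None:
        self.project = project
        self.entity = entity
        self.logged: List[Dict[str, Any]] = []
        self.finished = False

    def log(self, metrics: Dict[str, Any]) -> None:
        # Record a copy of the metrics for this step.
        self.logged.append(dict(metrics))

    def finish(self) -> None:
        self.finished = True


# Module-level handle to the active run, so the decorated body can reach it
# without receiving it as an argument.
current_run: Optional[FakeRun] = None
completed_runs: List[FakeRun] = []


def tracked_run(project: str, entity: str) -> Callable:
    """Hypothetical decorator: start a run, call the function, finish the run."""

    def decorator(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            global current_run
            current_run = FakeRun(project, entity)
            try:
                return fn(*args, **kwargs)
            finally:
                # Always close the run, even if the body raises.
                current_run.finish()
                completed_runs.append(current_run)
                current_run = None

        return wrapper

    return decorator


@tracked_run(project="demo-project", entity="demo-team")
def train(scores: List[float]) -> float:
    test_score = sum(scores) / len(scores)
    current_run.log({"test_score": test_score})
    return test_score


result = train([0.5, 1.0])
```

The try/finally block is the main design choice here: the run is closed whether the body succeeds or fails, so a crashed training function does not leave a run open.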


Reactive workflows

With Union's artifacts, you can write workflows that automatically trigger when the data gets updated by another workflow. This enables workflows to be modular, where one team focuses on extracting data, and another focuses on modeling. You can declare an artifact with a Python typing annotation:
from flytekit import LaunchPlan, task, workflow
from flytekit.core.artifact import Artifact
from typing_extensions import Annotated
# OnArtifact is provided by Union's SDK (import not shown)

MyDataset = Artifact(name="my_dataset")

@task(...)
@wandb_init(...)
def train(data: pd.DataFrame) -> float:
    ...

# train_workflow will trigger when "my_dataset" gets updated
@workflow
def train_workflow(data: pd.DataFrame = MyDataset.query()):
    train(data=data)

trigger = LaunchPlan.create(
    "trigger_train_workflow",
    train_workflow,
    trigger=OnArtifact(trigger_on=MyDataset),
)
The train_workflow workflow takes the "my_dataset" artifact, which represents an upstream dataset, and passes it to the train task. With the wandb_init decorator, Weights & Biases tracks the metrics and results of each new training run on the updated dataset. You can observe changes in the model's performance as the dataset changes over time.
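The trigger pattern can be pictured with a small, self-contained sketch: a registry in which publishing a new version of a named artifact calls every workflow subscribed to that name. This is only a conceptual illustration in plain Python, not how Union implements artifacts or triggers. ArtifactRegistry, on_artifact, publish, and query are hypothetical names.

```python
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, List


class ArtifactRegistry:
    """Hypothetical in-memory registry illustrating artifact-driven triggers."""

    def __init__(self) -> None:
        self._latest: Dict[str, Any] = {}
        self._subscribers: DefaultDict[str, List[Callable[[Any], Any]]] = defaultdict(list)

    def on_artifact(self, name: str) -> Callable:
        """Decorator that subscribes a workflow to updates of the named artifact."""

        def register(workflow: Callable[[Any], Any]) -> Callable[[Any], Any]:
            self._subscribers[name].append(workflow)
            return workflow

        return register

    def publish(self, name: str, value: Any) -> List[Any]:
        """Store the new artifact version and run every subscribed workflow on it."""
        self._latest[name] = value
        return [workflow(value) for workflow in self._subscribers[name]]

    def query(self, name: str) -> Any:
        """Return the most recently published version of the named artifact."""
        return self._latest[name]


registry = ArtifactRegistry()
training_runs: List[float] = []


@registry.on_artifact("my_dataset")
def train_workflow(data: List[float]) -> float:
    # Toy "training": the score is just the mean of the data.
    score = sum(data) / len(data)
    training_runs.append(score)
    return score


# An upstream team publishes two dataset versions; each one triggers training.
registry.publish("my_dataset", [1.0, 2.0, 3.0])
registry.publish("my_dataset", [2.0, 4.0])
```

The data-producing side only calls publish and never references train_workflow directly, which is the modularity the section describes.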

Scaling out experiments

With Flyte's dynamic workflows, you can quickly scale up to multiple training tasks, each with its own resources. In this example, you see how to use Flyte’s declarative infrastructure to train various models using PyTorch Lightning on GPUs. The function’s body consists of regular PyTorch Lightning code that you’ll find in their documentation.
from flytekit import Resources, dynamic, task
from flytekit.extras.accelerators import T4
from flytekitplugins.wandb import wandb_init
from lightning.pytorch import Trainer
from lightning.pytorch.loggers import WandbLogger

@task(
    container_image=image,
    requests=Resources(gpu="1", cpu="2", mem="8Gi"),
    accelerator=T4,
)
@wandb_init(...)
# The dataset parameter matches the call in main; its DataFrame type is an assumption
def train_lightning_model(dataset: pd.DataFrame, n_layer: int) -> dict:
    # training_loader and validation_loader are built from dataset (not shown)
    wandb_logger = WandbLogger(log_model="all")
    model = MyLightningModule(n_layer_1=n_layer, n_layer_2=n_layer)
    trainer = Trainer(max_epochs=5, logger=wandb_logger)
    trainer.fit(model, training_loader, validation_loader)
    ...

@dynamic(container_image=image)
def main(n_layers: list[int]):
    dataset = get_dataset()
    # Each iteration launches a separate GPU task
    for n_layer in n_layers:
        train_lightning_model(dataset=dataset, n_layer=n_layer)
In the Union UI, the workflow dynamically scales out to multiple GPU-powered tasks.

PyTorch Lightning's WandbLogger automatically logs the metrics, hyperparameters, and checkpoints during model training. From the Weights & Biases platform, you can compare the different runs and evaluate your model’s performance.
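The fan-out-then-compare pattern can be sketched locally in plain Python. The example below runs one independent "training" call per configuration and then picks the run with the lowest validation loss. It is a toy illustration only: fake_train, fan_out, and best_run are hypothetical names, the loss formula is made up, and a thread pool stands in for the separate GPU tasks that Flyte would launch.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List


def fake_train(n_layer: int) -> Dict[str, float]:
    """Hypothetical stand-in for a training task; returns a made-up metric."""
    # Toy validation loss that happens to be lowest at n_layer == 64.
    val_loss = abs(n_layer - 64) / 64
    return {"n_layer": n_layer, "val_loss": val_loss}


def fan_out(n_layers: List[int]) -> List[Dict[str, float]]:
    """Run one independent training call per configuration, concurrently."""
    with ThreadPoolExecutor(max_workers=max(1, len(n_layers))) as pool:
        # pool.map returns results in the same order as the inputs.
        return list(pool.map(fake_train, n_layers))


def best_run(results: List[Dict[str, float]]) -> Dict[str, float]:
    """Pick the run with the lowest validation loss."""
    return min(results, key=lambda run: run["val_loss"])


results = fan_out([32, 64, 128])
best = best_run(results)
```

In the real workflow the comparison happens in the Weights & Biases UI rather than in code; the sketch only shows why independent runs with a shared metric are easy to rank.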

Wrapping up

Union's declarative infrastructure and scalable orchestration platform make it simple to scale up your machine learning or AI workflows and put them in production. With flytekit's Weights & Biases plugin, you can easily track your experiments, visualize results, and debug your models. Use the plugin by installing it with pip install flytekitplugins-wandb.
If you want to learn more about Union, get in touch with us at www.union.ai/demo.