Building a Machine Learning Pipeline with TensorFlow Extended and W&B
This article gives an overview of how to integrate W&B's experiment tracking and model registry into TensorFlow Extended to build a machine learning pipeline.
Introduction
TensorFlow Extended, or TFX, is a framework that extends TensorFlow's capabilities beyond modeling. It helps us automate every part of a machine learning pipeline, from data ingestion, data validation, and data transformation to hyperparameter tuning, model training, and model deployment. It also comes with built-in standard components, ExampleGen, ExampleValidator, Transform, Tuner, Trainer, and Pusher, for each of these purposes respectively.
In this article, we'll explore how W&B's experiment tracking and model registry can be integrated into TFX. For that purpose, we'll focus on the Tuner, Trainer, and Pusher components, but you can find out how the other components are handled in the project repository.
- Experiment tracking is a popular go-to tool among machine learning practitioners and organizations for understanding how the optimal combination of hyperparameters is selected to train the final model. We'll explore how it can be integrated into TFX's Tuner and Trainer components.
- A trained model from each pipeline run should be managed in a version control system, allowing us to easily compare different versions of the model and roll back whenever we need to. We'll see how to build a custom TFX component that extends the standard Pusher to save models to W&B's Model Registry. Additionally, we'll showcase how to publish a Gradio application which automatically points to the version of the model from the current run.
We'll go through the entire design of the TFX pipeline with brief descriptions of the role of each TFX component. Finally, we'll focus on the actual implementation details of the Tuner, Trainer, and WandBPusher components.
Let's get started:
Overview of the TFX pipeline
To demonstrate TFX and W&B integration, we'll use the beans dataset from the Hugging Face 🤗 Dataset Hub and the TFViT (Vision Transformer) model from the Hugging Face transformers library. Our goal is simple: to correctly classify images of bean leaves into three categories: angular_leaf_spot, bean_rust, and healthy.

The figure above shows that the TFX pipeline consists of five components.
- Dataset: It's important to unify the data format inside any machine learning pipeline for consistency between components. In particular, a binary, serializable format is preferred for efficiency. Because TFX uses TFRecord as its standard data format, the input data should be prepared as TFRecord beforehand. We provide a notebook showing how to convert datasets from the Hugging Face 🤗 Dataset Hub into TFRecord. The prepared TFRecord files can then be easily consumed by the ExampleGen component.
- Transform/Shaping: The serialized binary dataset doesn't carry information about the shape of the data, which is needed in order to transform it. That information is supplied by the SchemaGen component when the Transform component transforms the input data. In this project, the Transform component does two things: 1) pixel value normalization by dividing by 255.0, and 2) transposing image tensors from BCWH to BWHC (Batch size, Width, Height, Channel). A minimal sketch of this logic follows this list.
- NOTE: Most TensorFlow models assume input images are shaped BWHC, while datasets prepared with PyTorch in mind are often shaped BCWH. 💡
- Hyperparameter tuning: Hyperparameter tuning is a process that conducts and compares a number of experiments to find an optimal combination of hyperparameters, usually with a subset of the entire dataset. This is handled by the Tuner component, which is simply a wrapper around KerasTuner. We provide a custom Python script to the Tuner component to tell it which hyperparameters to tune and with what values. Tuner runs a number of experiments, in parallel if possible, by running the same script with different values. We just tell W&B to record that an experiment has begun and with what values.
- Full training: Tuner emits the optimal combination of hyperparameters as its output, and it is passed down to Trainer for full training. Here, full training means training on the entire dataset, since the output of Tuner is expected to yield the best model performance. Trainer saves the fully trained model in SavedModel format, and its location is passed to the downstream task. As you can see from the chart below, the full training run (highlighted line) achieved the highest validation accuracy and the lowest validation loss.
[Chart: validation accuracy and loss for the tuning runs and the full-training run]
- Pushing: Pusher, as you might guess, pushes a model somewhere. By default, the target is the local file system, but that isn't very manageable. Hence, we custom-designed the WandBPusher component by extending the basic Pusher component. WandBPusher pushes the fully trained model to W&B's Model Registry, and it optionally publishes a demo application to the Hugging Face 🤗 Space Hub. It's important to note that the model version is determined at runtime, so it should be dynamically injected into the published application at runtime as well.
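Before moving on, here is a minimal sketch of the Transform logic described above, assuming TFX Transform's standard preprocessing_fn interface; the feature names ("image", "label") are illustrative, and the project's actual implementation may differ.

```python
import tensorflow as tf

def preprocessing_fn(inputs):
    # 1) Normalize pixel values from [0, 255] to [0, 1].
    image = tf.cast(inputs["image"], tf.float32) / 255.0
    # 2) Transpose from BCWH to BWHC so standard TensorFlow models can consume it.
    image = tf.transpose(image, perm=[0, 2, 3, 1])
    return {"image": image, "label": inputs["label"]}
```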
From the SavedModel saved in W&B's Model Registry, we can confirm that the final model was produced by the run named full-training-8OUV5U. We can also see that the model artifact has three aliases: latest, 1688113391, and v0. All of these aliases are generated automatically. The first and third are generated by W&B, while the second is generated by WandBPusher and determined by int(time.time()) at runtime.
final_model — Version overview
- Full Name: chansung18/tfx-vit-pipeline/final_model:v0
- Aliases: 1688113391, latest, v0
- Digest: a957af88fda7ae3d08c8135509c3270e
- Created At: June 30th, 2023 08:23:42
- Num Consumers: 0
- Num Files: 1
- Size: 319.2MB
- TTL Remaining: Inactive
Subsequently, WandBPusher can also publish a demo application to the Hugging Face 🤗 Space Hub. In this project, we chose Gradio as the application engine. All the application code has to be written as a template containing special strings such as $MODEL_PROJECT, $MODEL_RUN, $MODEL_NAME, and $MODEL_VERSION. These special strings are replaced with the real values determined by WandBPusher at runtime. You can find the working application hosted on the Hugging Face 🤗 Space Hub here.

Demo application deployed on Hugging Face 🤗 Space Hub
At this point, we've covered MLOps, the problem we're trying to solve, the role of each component, and how those components are interconnected. If you want to learn the basics of TFX, please take a look at the official tutorials. If you already know the basics and want to learn more about this project, please take a look at the project repository. In the sections below, we'll focus on the implementation details of the three most important components: Tuner, Trainer, and WandBPusher.
Reusability
Before diving into the actual code implementation, we need to understand one more very important concept: reusability.
In MLOps, it's common to run each component in a separate, isolated Docker container. This execution scheme allows us to allocate hardware resources differently based on each component's requirements. For example, we might need GPUs for Trainer and Tuner, while data processing by ExampleGen and Transform might require high RAM capacity. It also becomes much easier to scale, allocating more or fewer GPUs and more or less RAM depending on the size of the data and model involved.
With containerization, however, it is non-trivial to define common interfaces between containers for passing intermediate results such as the transformed dataset. Instead, this is solved by having a centralized Artifact Store (i.e. Google Cloud Storage, S3, etc.) and a SQL DBMS (Database Management System).
In each run of the pipeline, all intermediate results and the status of each component are recorded in the Artifact Store and the SQL DBMS respectively. With these, it becomes easier to track the lineage of pipeline runs. Furthermore, we can even query the intermediate results from the Artifact Store and reuse them. For instance, if we restart from tuning or training, we can reuse the transformed intermediate datasets, so we don't need to re-run the data transformation tasks.

So how does TFX reuse the cached results? There are many factors to consider, but in a nutshell, it reuses them if the Docker container is instantiated from the same Docker image as the last run and if all the input parameters passed to the components are the same as before. It's important to keep in mind that we should avoid building a new Docker image whenever possible. That means we shouldn't hardcode anything that we want to change flexibly across runs.
All this brings us to Tuner and Trainer. It's common for the currently chosen hyperparameter search space to fail, in which case we need to try a different search space in the next run. However, if we embed the search space inside the Trainer component, we need to build a new Docker image, which makes every component of the previous steps (i.e. ExampleGen, Transform, ...) run again. That means we lose the benefit of reusability. Instead, we can decouple the search space and inject it into the Trainer at runtime.
Experiment tracking (Tuner)
A TFX pipeline simply defines the relationships between components: how each should be executed, which artifacts should be taken as inputs, and which artifacts should be output.
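As a rough illustration of what such a definition looks like, here is a minimal sketch using the standard tfx.v1 API; the component arguments are abbreviated, and the paths and module names are placeholders rather than this project's actual values.

```python
from tfx import v1 as tfx

# Ingest pre-built TFRecord files.
example_gen = tfx.components.ImportExampleGen(input_base="tfrecords/")

# Infer statistics and a schema so Transform knows the shape of the data.
statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs["examples"])
schema_gen = tfx.components.SchemaGen(statistics=statistics_gen.outputs["statistics"])

# Apply the preprocessing function defined in a separate module.
transform = tfx.components.Transform(
    examples=example_gen.outputs["examples"],
    schema=schema_gen.outputs["schema"],
    preprocessing_fn="modules.preprocessing.preprocessing_fn",
)

pipeline = tfx.dsl.Pipeline(
    pipeline_name="tfx-vit-pipeline",
    pipeline_root="gs://my-bucket/pipeline-root",  # placeholder path
    components=[example_gen, statistics_gen, schema_gen, transform],
)
```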
In the sections that follow, we'll take a deeper look inside Tuner, Trainer, and WandBPusher, and the components they depend on.
NOTE: Tuner is essentially a wrapper around KerasTuner, and its usage is almost identical. If you don't know much about KerasTuner yet, please take a look at its basic tutorials.
💡
The Tuner component can be created via tfx.components.Tuner for a local environment or via tfx.extensions.google_cloud_ai_platform.tuner.component.Tuner for the Google Cloud Vertex AI environment; the usage is identical either way. When instantiating an instance, as shown below (link), tuner_fn and custom_config are the important arguments to focus on.
```python
tuner = Tuner(
    tuner_fn="modules.tuning.tuner_fn",
    custom_config=TUNER_ARGS,
    ....
)
```
The tuner_fn points to the function which runs the tuning process, and custom_config delivers additional information to that function. For instance, TUNER_ARGS, as defined below (link), carries information about which hyperparameters to tune (HYPER_PARAMETERS), how to adjust the hyperparameter tuning process (TUNER_CONFIGS), and how to use W&B to log the tuning process (WANDB_CONFIGS).
```python
WANDB_CONFIGS = {
    "API_KEY": "WANDB_ACCESS_KEY",
    "PROJECT": PIPELINE_NAME,
}

HYPER_PARAMETERS = {
    ....
    "optimizer_type": {
        "type": "choice",
        "values": ["Adam", "AdamW"],
    },
    "learning_rate": {
        "type": "float",
        "min_value": 0.00001,
        "max_value": 0.1,
        "sampling": "log",
        "step": 10,
    },
    ....
}

TUNER_CONFIGS = {"num_trials": 15}

TUNER_ARGS = {
    ....
    "hyperparameters": HYPER_PARAMETERS,
    "tuner": TUNER_CONFIGS,
    "wandb": WANDB_CONFIGS,
}
```
With this information, the modules.tuning.tuner_fn function can access them through the custom_config attribute of the fn_args parameter, as shown below (link).
The function starts by logging into W&B with the given API key. Then, it instantiates MyTuner class, a subclass of keras_tuner.RandomSearch that essentially defines the model, which hyperparameters to try, and how to log the tuning process in W&B.
```python
def tuner_fn(fn_args: FnArgs) -> TunerFnResult:
    wandb_configs = fn_args.custom_config["wandb"]
    wandb.login(key=wandb_configs["API_KEY"])

    tuner = MyTuner(
        wandb_configs["PROJECT"],
        MyHyperModel(),
        max_trials=fn_args.custom_config["tuner"],
        hyperparameters=get_hyperparameters(fn_args.custom_config["hyperparameters"]),
        objective=keras_tuner.Objective("val_accuracy", "max"),
        ....
    )

    train_dataset = input_fn(fn_args.train_files, ...)
    eval_dataset = input_fn(fn_args.eval_files, ...)

    return TunerFnResult(
        tuner=tuner,
        fit_kwargs={
            "x": train_dataset,
            "validation_data": eval_dataset,
        },
    )
```
The instantiated MyTuner object is used to create a TunerFnResult. TunerFnResult is the return value of tuner_fn, and it will eventually trigger a tuning process with the training and validation datasets. We won't go through how the training and validation datasets are obtained; just assume they can be accessed via fn_args.train_files and fn_args.eval_files, which refer to the file locations for each purpose.
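The get_hyperparameters helper used above isn't shown here. A hypothetical sketch of it, translating the HYPER_PARAMETERS dict from custom_config into a keras_tuner search space, might look like this; the project's actual helper may differ.

```python
import keras_tuner

def get_hyperparameters(config: dict) -> keras_tuner.HyperParameters:
    hp = keras_tuner.HyperParameters()
    for name, spec in config.items():
        if spec["type"] == "choice":
            # e.g. optimizer_type: one of ["Adam", "AdamW"]
            hp.Choice(name, values=spec["values"])
        elif spec["type"] == "float":
            # e.g. learning_rate: log-sampled between 1e-5 and 0.1
            hp.Float(
                name,
                min_value=spec["min_value"],
                max_value=spec["max_value"],
                sampling=spec.get("sampling", "linear"),
                step=spec.get("step"),
            )
    return hp
```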
NOTE: If you are familiar with KerasTuner, you know we don't need the custom MyTuner class just to tune a model. However, since we log information to W&B whenever a tuning trial is triggered, this has to be done in the run_trial method. KerasTuner may run our tuning process in a distributed manner, so we don't know beforehand which hyperparameters are handled by which VM instance. Instead, we should grab the hyperparameters assigned to the current VM instance via the trial parameter.
💡
```python
class MyTuner(keras_tuner.RandomSearch):
    def run_trial(self, trial, *args, **kwargs):
        hp = trial.hyperparameters
        model = self.hypermodel.build(hp)

        optimizer_type = hp.get("optimizer_type")
        learning_rate = hp.get("learning_rate")
        weight_decay = hp.get("weight_decay")
        epochs = hp.get("finetune_epochs")

        log_name = f"tuning@opt-{optimizer_type}@lr-{learning_rate}@wd-{weight_decay}"
        wandb.init(
            project=self.wandb_project,
            config=hp.values,
            name=log_name,
        )
        wandb.log({"optimizer_type": optimizer_type})
        wandb.log({"learning_rate": learning_rate})
        wandb.log({"weight_decay": weight_decay})

        result = self.hypermodel.fit(
            hp,
            model,
            epochs=epochs,
            callbacks=[wandb.keras.WandbMetricsLogger(log_freq='epoch')],
            *args,
            **kwargs,
        )
        wandb.finish()
        return result
```
MyTuner overrides run_trial, which essentially builds a model instance with the given hyperparameters and calls the keras.Model.fit method (link). Pretty standard Keras! It just does some additional work to initialize a W&B run, log the assigned hyperparameters to that run, and finish up the logging process.
Experiment tracking (Trainer)
After the tuning process, the Tuner component emits best_hyperparameters as an output artifact, which you can access via outputs["best_hyperparameters"]. This should be passed to the Trainer component so that it can do full training with the best hyperparameters (link).
The Trainer component can be created via tfx.components.Trainer for a local environment or via tfx.extensions.google_cloud_ai_platform.trainer.component.Trainer for the Google Cloud Vertex AI environment. Just like Tuner, the usage is identical across environments, and it is also very similar to Tuner's.
When creating an instance, run_fn and custom_config are passed along with the hyperparameters, and their purpose is exactly the same as tuner_fn and custom_config in our Tuner component.
```python
trainer = Trainer(
    run_fn="modules.train.run_fn",
    hyperparameters=tuner.outputs["best_hyperparameters"],
    custom_config=TRAINING_ARGS,
    ....
)

TRAINING_ARGS = {
    ...
    "wandb": WANDB_CONFIGS,
}
```
The only differences are that run_fn runs the full training process with the best hyperparameters instead of a hyperparameter tuning process, and that custom_config only contains information about W&B, since the best hyperparameters are carried in the hyperparameters argument.
So, let's understand what happens inside the run_fn function. Just like with Tuner, we can access any given information via the fn_args parameter. For instance, the best hyperparameters are stored in fn_args.hyperparameters, and fn_args.custom_config["wandb"] holds the W&B related information, as shown below (link).
```python
def run_fn(fn_args: FnArgs):
    hp = keras_tuner.HyperParameters.from_config(fn_args.hyperparameters)

    wandb_configs = fn_args.custom_config["wandb"]
    wandb.login(key=wandb_configs["API_KEY"])
    wandb_project = wandb_configs["PROJECT"]
    ....
    optimizer_type = hp.get("optimizer_type")
    learning_rate = hp.get("learning_rate")
    weight_decay = hp.get("weight_decay")
    epochs = hp.get("fulltrain_epochs")

    unique_id = wandb_configs["FINAL_RUN_ID"]
    wandb.init(project=wandb_project, config=hp.values, name=unique_id)
    wandb.log({"optimizer_type": optimizer_type})
    wandb.log({"learning_rate": learning_rate})
    wandb.log({"weight_decay": weight_decay})

    model = MyHyperModel().build(hp)
    model.fit(
        ....
        callbacks=[..., wandb.keras.WandbMetricsLogger(log_freq='epoch')],
    )
    wandb.finish()

    model.save(....)
```
As usual, it initializes W&B, logs which hyperparameters the full training process depends on, builds the model with the chosen hyperparameters, and runs the training. This is a pretty standard Keras approach except for the W&B part, so the code should be straightforward to understand.
WandBPusher component
Finally, let's move on to the last component: WandBPusher.
It might look complicated, but its purpose is simply to push the fully trained model to our W&B Model Registry (with an optional feature to publish a working application to Hugging Face Space).
With this in mind, if you look at the code snippet below (link), you will see that WANDB_PUSHER_ARGS can be split into two parts: one for the W&B Model Registry and the other for Hugging Face Space!
```python
WANDB_PUSHER_ARGS = {
    "access_token": "$WANDB_ACCESS_TOKEN",
    "project_name": PIPELINE_NAME,
    "run_name": WANDB_RUN_ID,
    "model": trainer.outputs["model"],
    "model_name": "final_model",
    "aliases": ["test_aliases"],
    "space_config": {
        "app_path": "huggingface.apps.gradio",
        "hf_username": "chansung",
        "hf_repo_name": PIPELINE_NAME,
        "hf_access_token": "$HF_ACCESS_TOKEN",
    },
}

pusher = WandBPusher(**WANDB_PUSHER_ARGS)
```
You can take a look at the actual implementation of the WandBPusher component (link), but just going through the arguments and their roles should be sufficient to understand how the component works at a high level.
Let's go over the W&B Model Registry part first (a sketch of the pushing step follows this list):
- access_token: This is a W&B access token. You need it to interact with W&B programmatically. We used the same access token in the Tuner and Trainer components.
- project_name: This is our W&B project name, and it should be the same as the project name used in the Tuner and Trainer components. Otherwise, we cannot associate the model with the project it belongs to.
- run_name: This specifies the run name within the designated project. Since the final model is the output of the Trainer component, we need to associate the model with the run name used to log Trainer's process.
- model: This points to the SavedModel artifact saved as the output of the Trainer component.
- model_name: This is the identifier of the model in the W&B Model Registry.
- aliases: These are secondary identifiers of the model in the W&B Model Registry. They are particularly useful for identifying a model by a version name or latest.
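To give a feel for what the pushing step does with these arguments, here is a minimal sketch using the public wandb Artifact API. It illustrates the idea under assumed names (push_model, model_dir); it is not the component's actual implementation.

```python
import wandb

def push_model(access_token, project_name, run_name, model_dir, model_name, aliases):
    wandb.login(key=access_token)
    # Resume the run that Trainer logged to, so the artifact is associated with it.
    run = wandb.init(project=project_name, id=run_name, resume=True)
    artifact = wandb.Artifact(model_name, type="model")
    artifact.add_dir(model_dir)  # the SavedModel directory emitted by Trainer
    run.log_artifact(artifact, aliases=aliases)
    run.finish()
```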
Now, if you look back at the final_model entry from the Model Registry shown earlier, everything should make a lot more sense.
Okay. Now let's go over the Hugging Face Space part (a publishing sketch follows this list):
- space_config.app_path: This is the module path where the application files and code live. All the files to be hosted on Hugging Face Space, such as README.md, app.py, and requirements.txt, should be stored under this path.
- space_config.hf_username: Any repository on the Hugging Face Hub is identified as username/reponame, so this is the first part of the repository identifier.
- space_config.hf_repo_name: This is the second part of the repository identifier.
- space_config.hf_access_token: This is our Hugging Face access token, which we need to interact with the Hugging Face Hub programmatically.
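As a rough sketch of how these space_config values could be used to publish the application, here is what it might look like with the official huggingface_hub library; this illustrates the idea rather than reproducing WandBPusher's actual code.

```python
from huggingface_hub import HfApi

api = HfApi(token="$HF_ACCESS_TOKEN")  # space_config.hf_access_token
repo_id = "chansung/tfx-vit-pipeline"  # hf_username/hf_repo_name

# Create the Space if it doesn't exist, then upload the prepared app files.
api.create_repo(repo_id=repo_id, repo_type="space", space_sdk="gradio", exist_ok=True)
api.upload_folder(folder_path="huggingface/apps/gradio", repo_id=repo_id, repo_type="space")
```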
The current version of the model is dynamically determined by the WandBPusher component, and the published or updated application needs to know it, because it has to use the model of that exact version. To this end, WandBPusher replaces the special strings (placeholders) in all files under space_config.app_path with the real values.
For instance:
- $MODEL_PROJECT is replaced by the project_name in the WANDB_PUSHER_ARGS.
- $MODEL_RUN is replaced by the run_name in the WANDB_PUSHER_ARGS.
- $MODEL_NAME is replaced by the model_name in the WANDB_PUSHER_ARGS.
- $MODEL_VERSION is replaced by the internally decided model version.
```python
wandb.login(key=wb_token)
wandb.init(project="$MODEL_PROJECT", id="$MODEL_RUN", resume=True)

path = wandb.use_artifact('tfx-vit-pipeline/$MODEL_NAME:$MODEL_VERSION', type='model').download()

tar = tarfile.open(f"{path}/$MODEL_FILENAME")
tar.extractall(path=".")

MODEL = tf.keras.models.load_model("./model")
```
With these replacements in place, the application can download the right version of the model. Below is an example of the same snippet with the actual values filled in.
```python
wandb.login(key=wb_token)
wandb.init(project="tfx-vit-pipeline", id="avi9pkis", resume=True)

path = wandb.use_artifact('tfx-vit-pipeline/final_model:1683870912', type='model').download()

tar = tarfile.open(f"{path}/model.tar.gz")
tar.extractall(path=".")

MODEL = tf.keras.models.load_model("./model")
```
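The substitution itself can be as simple as walking the files under space_config.app_path and replacing each placeholder. Here is a hedged sketch of that idea, with illustrative names (replace_placeholders) rather than the component's actual code.

```python
from pathlib import Path

def replace_placeholders(app_path: str, replacements: dict) -> None:
    # Rewrite every file under app_path with the runtime-determined values.
    for file in Path(app_path).rglob("*"):
        if file.is_file():
            text = file.read_text()
            for placeholder, value in replacements.items():
                text = text.replace(placeholder, value)
            file.write_text(text)

replace_placeholders("huggingface/apps/gradio", {
    "$MODEL_PROJECT": "tfx-vit-pipeline",
    "$MODEL_RUN": "avi9pkis",
    "$MODEL_NAME": "final_model",
    "$MODEL_VERSION": "1683870912",
})
```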
You can find the fully working application from this project on Hugging Face Space (link). If you are curious about the process, take a look at the implementation of WandBPusher (link), the code snippets with placeholders (link), and the code snippets with the replaced values (link).
Conclusion
That wraps up our tutorial on TFX alongside W&B. By combining these two solutions, we get more options for where to log experiments and model artifacts. We can use both experiment tracking and the model registry, or just one of them. You might visualize the experiments from many different perspectives while keeping the model artifacts on Google Cloud Platform, or you might do both on the W&B platform, since it makes it easy to track which model was produced by which pipeline and run. Happy modeling!