Optimizing image classification with Weights & Biases
Learn how to build and track image classification models using Weights & Biases for experiment tracking and visualization.
Created on March 31 | Last edited on February 18
Image classification plays a critical role in various AI applications, from medical diagnostics to autonomous driving. However, training and optimizing deep learning models for classification tasks can be complex, requiring careful experiment tracking, hyperparameter tuning, and performance monitoring. This is where Weights & Biases comes in, providing a powerful platform for logging, visualizing, and optimizing machine learning experiments.
In this guide, we’ll walk through how to build an image classification model and track its progress using W&B Models. As an example use case, we’ll apply this to wildlife conservation, where AI-powered classification models can help researchers analyze camera trap images and satellite data to monitor endangered species and environmental changes. By leveraging Weights & Biases, we can efficiently track model performance, compare training runs, and fine-tune hyperparameters to achieve better accuracy.
Let’s dive into the process of setting up an image classification pipeline, integrating Weights & Biases for experiment tracking, and optimizing model performance step by step.

Source: Author
Table of contents
- Use case: Image classification in wildlife conservation
- Key technologies behind image classification
  - How CNNs process images
- Tools and frameworks for image classification
  - Deep learning frameworks
  - Object detection APIs
  - Weights & Biases
- Building your first image classifier
  - Step 1: Install dependencies
  - Step 2: Initialize Weights & Biases
  - Step 3: Check if a GPU is available
  - Step 4: Define directory paths and other parameters
  - Step 5: Data preparation
  - Step 6: Model setup
  - Step 7: Performance tracking
  - Step 8: Training
  - Step 9: Model evaluation
  - Step 10: Model testing
  - Step 11: Predicting and printing results
  - Step 12: Logging sample predictions
- Conclusion
Use case: Image classification in wildlife conservation
One real-world application of image classification is wildlife conservation, where AI models help researchers analyze vast amounts of image data from satellite imagery and camera traps. These models assist in identifying species, monitoring habitat changes, and detecting environmental threats like wildfires and deforestation.
However, developing high-performing classification models in such applications requires more than just a well-architected deep learning pipeline; it also demands efficient experiment tracking to compare training runs, monitor model improvements, and fine-tune hyperparameters. By leveraging experiment tracking and visualization tools, machine learning practitioners can optimize their workflows and improve model accuracy.


The images above are sample satellite images from the Wildfire Prediction Dataset (Satellite Images). The image on the left depicts unaffected land, while the one on the right shows an active wildfire spreading. Datasets like this are crucial for training AI models to detect and classify wildfire events from satellite imagery.
Beyond satellite imaging, researchers and conservationists are also leveraging camera traps as a cost-effective and efficient method for monitoring wildlife and habitats. With advancements in machine learning, AI models trained on labeled datasets can now analyze camera trap footage and automatically identify and categorize the species it captures.
For example, Conservation AI has deployed camera traps and drones to track endangered species. Their system has processed over 12.5 million images, detecting more than 4 million individual animal appearances across 68 species, including endangered pangolins in Uganda, gorillas in Gabon, and orangutans in Malaysia. These AI-driven approaches provide valuable data for conservation efforts, improving species monitoring and habitat protection.

Key technologies behind image classification

Deep learning-based image classification models rely on Convolutional Neural Networks (CNNs), a class of neural networks specifically designed to process visual data. CNNs are widely used for tasks such as object recognition, facial recognition, and medical image analysis, and they are increasingly used in wildlife conservation applications.
How CNNs process images
CNNs consist of multiple layers that extract features from an input image:
- Convolutional Layers – Apply filters to detect edges, textures, and patterns.
- Pooling Layers – Reduce spatial dimensions, retaining essential features while improving computational efficiency.
- Fully Connected Layers – Map extracted features to class probabilities, making the final classification decision.
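To make the three layer types concrete, here is a minimal NumPy sketch of a single forward pass: one hand-picked filter followed by ReLU, 2x2 max pooling, and a fully connected layer with softmax. It is illustrative only and assumes a single-channel 8x8 input, a fixed Prewitt-style edge filter, and random dense weights; real CNNs learn many filters per layer, and frameworks like TensorFlow implement these operations far more efficiently. The helper names (conv2d, max_pool, dense_softmax) are hypothetical.

```python
import numpy as np

def conv2d(image, kernel):
    """Valid (no padding) 2D cross-correlation of a single-channel image with one filter."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def max_pool(feature_map, size=2):
    """Non-overlapping max pooling; rows/columns that don't fill a full window are dropped."""
    h = feature_map.shape[0] // size
    w = feature_map.shape[1] // size
    trimmed = feature_map[:h * size, :w * size]
    return trimmed.reshape(h, size, w, size).max(axis=(1, 3))

def dense_softmax(features, weights, bias):
    """Fully connected layer followed by a numerically stable softmax."""
    logits = features @ weights + bias
    exp = np.exp(logits - logits.max())
    return exp / exp.sum()

rng = np.random.default_rng(0)
image = rng.random((8, 8))                              # fake grayscale image
edge_kernel = np.array([[1.0, 0.0, -1.0]] * 3)          # Prewitt-style vertical-edge filter
feature_map = np.maximum(conv2d(image, edge_kernel), 0.0)  # convolution + ReLU -> 6x6
pooled = max_pool(feature_map)                          # pooling -> 3x3
flat = pooled.reshape(-1)                               # flatten -> 9 features
weights = rng.normal(size=(flat.size, 2))               # random weights for 2 classes
bias = np.zeros(2)
probs = dense_softmax(flat, weights, bias)              # class probabilities
print(feature_map.shape, pooled.shape, probs)
```

Each stage shrinks the spatial size while the final layer turns the remaining features into class probabilities, which is the same pattern a trained CNN follows at much larger scale.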
By training CNNs on large labeled datasets, models can learn to classify images with high accuracy. In wildlife conservation, CNNs are trained on datasets of animals, plants, and environmental conditions to automate species recognition and detect anomalies like illegal logging or wildfires.
To keep improving model accuracy, it’s essential to track training metrics, compare different architectures, and fine-tune hyperparameters, a process that W&B’s experiment tracking tools make much easier.
Tools and frameworks for image classification

Developing image classification models requires powerful tools and frameworks, many of which integrate seamlessly with W&B for tracking and optimization.
Deep learning frameworks

- OpenCV – Facilitates image preprocessing, such as resizing and feature extraction.

Object detection APIs
- TensorFlow Object Detection API – Simplifies training object detection models for wildlife monitoring.

Weights & Biases

Source: Author
Building your first image classifier
Now, let’s walk through the process of building a wildlife image classifier while tracking experiments using Weights & Biases.
Step 1: Install dependencies
First, we set up the environment by importing the essential libraries: TensorFlow and Keras for the model, Pillow's ImageFile for handling image files, and Weights & Biases for experiment tracking. If any of them are missing from your environment, install them with pip install tensorflow wandb pillow.
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras.models import Model
from PIL import ImageFile  # used in Step 5 to tolerate truncated image files
import wandb
Step 2: Initialize Weights & Biases
wandb.login()
wandb.init(
    project='wildfire_detection',
    config={
        "learning_rate": 0.001,
        "epochs": 5,
        "batch_size": 8,
        "target_size": (350, 350),
    },
)
config = wandb.config
A Weights & Biases run is initialized to log and monitor model training metrics, configurations, and outputs in a project named 'wildfire_detection'.
Step 3: Check if a GPU is available
Next, we check whether a GPU is available, since it makes training much faster, and print a message either way.
if tf.test.gpu_device_name():
    print('GPU device found:', tf.test.gpu_device_name())
else:
    print("No GPU found. Please ensure that GPU is enabled in the runtime settings.")
Step 4: Define directory paths and other parameters
Next, we specify the paths to the training and validation data directories and read the image size and batch size from the W&B config.
train_dir_wildfire = '/kaggle/input/wildfire-prediction-dataset/train'
valid_dir = '/kaggle/input/wildfire-prediction-dataset/valid'
target_size = tuple(config.target_size)  # tuple() guards against the config returning a list
batch_size = config.batch_size
seed = 42
Step 5: Data preparation
We start by allowing Pillow's ImageFile module to load truncated image files instead of raising an error, so a few partially corrupted images in the dataset don't interrupt training.
# Handle truncated images
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Create data generators
# ImageDataGenerator with data augmentation for training data
train_datagen_wildfire = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,  # Randomly rotate images in the range (degrees, 0 to 180)
    width_shift_range=0.2,  # Randomly shift images horizontally (fraction of total width)
    height_shift_range=0.2,  # Randomly shift images vertically (fraction of total height)
    shear_range=0.2,  # Shear intensity (shear angle in counter-clockwise direction, in degrees)
    zoom_range=0.2,  # Randomly zoom images
    horizontal_flip=True,  # Randomly flip images horizontally
    fill_mode='nearest'  # Strategy used for filling in newly created pixels
)
train_generator_wildfire = train_datagen_wildfire.flow_from_directory(
    directory=train_dir_wildfire,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='binary',
    shuffle=True,
    seed=seed
)

# ImageDataGenerator for validation data (rescaling only, no augmentation)
validation_datagen = ImageDataGenerator(rescale=1./255)
validation_generator = validation_datagen.flow_from_directory(
    directory=valid_dir,
    target_size=target_size,
    batch_size=batch_size,
    class_mode='binary',
    shuffle=False
)
ImageDataGenerator preprocesses and augments the training images on the fly during training, which helps the model generalize better. It is also memory efficient, because it loads images from disk in batches rather than loading the whole dataset into memory at once.
For train_datagen_wildfire, we apply augmentations such as rotation, shifts, shear, zoom, and horizontal flipping. These help the model learn general features rather than features tied to the specific conditions under which the original training images were captured. The validation data is only rescaled: augmentation is generally applied to training data only, so that validation and test metrics reflect unmodified images.
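As a rough illustration of what two of these settings do, the NumPy sketch below applies the equivalent of rescale=1./255 and a random horizontal flip to a tiny fake image. The function name augment and the fake image are hypothetical, and this is not how ImageDataGenerator is implemented internally; the Keras generator additionally applies the rotation, shift, shear, and zoom transforms configured above.

```python
import numpy as np

def augment(image, rng, flip_prob=0.5):
    """Rescale pixel values to [0, 1] and flip left-right with probability flip_prob."""
    out = image.astype(np.float32) / 255.0  # same effect as rescale=1./255
    if rng.random() < flip_prob:
        out = out[:, ::-1, :]  # reverse the width axis (height, width, channels layout)
    return out

# A fake 2x3 RGB image whose red channel increases from left to right
fake = np.zeros((2, 3, 3), dtype=np.uint8)
fake[:, :, 0] = [0, 128, 255]

rng = np.random.default_rng(42)
for _ in range(3):
    # First row of the red channel: either [0, 0.5, 1] or its mirror image
    print(augment(fake, rng)[0, :, 0])
```

Because each call draws fresh random numbers, the same source image can reach the model in different forms across epochs, which is the point of on-the-fly augmentation.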
The flow_from_directory() method then prepares data loaders (train_generator_wildfire and validation_generator) that will automatically feed images to the model during training and validation, respectively.
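flow_from_directory() infers labels from the subdirectory names: each subdirectory becomes a class, and class indices follow the alphanumeric order of those names. The hypothetical helper infer_class_indices below mimics that rule on a throwaway directory tree; the folder names nowildfire and wildfire are assumed to match the dataset's layout. In the real pipeline you can simply inspect train_generator_wildfire.class_indices.

```python
import os
import tempfile

def infer_class_indices(directory):
    """Map each subdirectory name to an index in sorted (alphanumeric) order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: index for index, name in enumerate(classes)}

with tempfile.TemporaryDirectory() as root:
    # Create the two class folders in an arbitrary order
    for name in ["wildfire", "nowildfire"]:
        os.makedirs(os.path.join(root, name))
    class_indices = infer_class_indices(root)

print(class_indices)  # {'nowildfire': 0, 'wildfire': 1}
```

With class_mode='binary', this ordering means a sigmoid output close to 1 corresponds to the wildfire class.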
Step 6: Model setup
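# Load VGG16 pre-trained on ImageNet, without its fully connected top layers
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(target_size[0], target_size[1], 3))

# Add a new classification head on top of the convolutional base
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.5)(x)
predictions = Dense(1, activation='sigmoid')(x)  # probability of the positive class
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze the pre-trained layers so only the new head is trained
for layer in base_model.layers:
    layer.trainable = False

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=config.learning_rate),  # 0.001, set in Step 2
    loss='binary_crossentropy',
    metrics=['accuracy'],
)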
base_model = VGG16(...) loads VGG16, a convolutional neural network pre-trained on ImageNet, as the base. The include_top=False option excludes its fully connected top layers, leaving the convolutional layers as a feature extractor for the new classification task. Setting layer.trainable = False on every base layer freezes these pre-trained weights, so only the newly added layers are updated during training.
On top of VGG16, we add a Flatten layer, a 256-unit Dense layer with ReLU activation, a Dropout layer with a rate of 0.5 to reduce overfitting, and a final single-unit Dense layer with sigmoid activation for binary classification (wildfire vs. no wildfire). Finally, the model is compiled with the Adam optimizer and binary cross-entropy loss, the standard loss for binary classification with a sigmoid output.
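A quick back-of-the-envelope calculation shows where the trainable parameters live. VGG16's convolutions use 'same' padding, and each of its five 2x2, stride-2 max-pooling layers halves the spatial size (rounding down), so a 350x350 input leaves a 10x10x512 feature map. The plain-Python sketch below, with hypothetical helper names, counts the weights in the new head only; the frozen VGG16 base is not included.

```python
def pooled_size(size, num_pools=5):
    """Spatial size after VGG16's five 2x2, stride-2 'valid' max-pooling layers."""
    for _ in range(num_pools):
        size //= 2  # pooling halves the size, rounding down
    return size

def dense_params(inputs, units):
    """Number of weights plus biases in a Dense layer."""
    return inputs * units + units

h = pooled_size(350)                 # 350 -> 175 -> 87 -> 43 -> 21 -> 10
flat_features = h * h * 512          # VGG16's last block has 512 filters
head_params = dense_params(flat_features, 256) + dense_params(256, 1)  # Dropout adds none
print(h, flat_features, head_params)  # 10 51200 13107713
```

Almost all of these roughly 13.1 million parameters sit in the Dense(256) layer fed by Flatten. A common alternative head design, not the one used in this tutorial, replaces Flatten with GlobalAveragePooling2D, which would shrink that layer to dense_params(512, 256) = 131,328 parameters.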
Step 7: Performance tracking
from tensorflow.keras.callbacks import Callback

class CustomWandbCallback(Callback):
    """Logs training and validation loss/accuracy to W&B at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        wandb.log({
            'epoch': epoch,
            'loss': logs.get('loss'),
            'accuracy': logs.get('accuracy'),
            'val_loss': logs.get('val_loss'),
            'val_accuracy': logs.get('val_accuracy')
        })
CustomWandbCallback is a custom Keras callback that logs the training and validation loss and accuracy to Weights & Biases at the end of each epoch, so you can follow the model's progress in the W&B dashboard while training is still running.
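One refinement you might consider: logs.get() returns None for any metric Keras did not compute in a given epoch (for example, validation metrics in epochs where validation does not run), so the callback would send None values to W&B. The hypothetical helper metrics_from_logs below builds the dictionary only from metrics that are present. It is plain Python, so you can check it without a live run, and then call wandb.log(metrics_from_logs(epoch, logs)) inside on_epoch_end.

```python
def metrics_from_logs(epoch, logs, keys=("loss", "accuracy", "val_loss", "val_accuracy")):
    """Build a W&B-ready metrics dict, skipping metrics missing from the Keras logs."""
    logs = logs or {}
    metrics = {"epoch": epoch}
    for key in keys:
        value = logs.get(key)
        if value is not None:
            metrics[key] = float(value)  # convert NumPy scalars to plain floats
    return metrics

# Example: an epoch in which no validation metrics were produced
print(metrics_from_logs(0, {"loss": 0.41, "accuracy": 0.83}))
```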
# Number of full batches per epoch for training and validation
steps_per_epoch = train_generator_wildfire.samples // batch_size
validation_steps = validation_generator.samples // batch_size
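These lines use floor division, so any final partial batch is skipped in each epoch. The worked arithmetic below uses a hypothetical sample count (not the dataset's actual size) to show the difference between floor and ceiling. If you want every image seen each epoch, math.ceil is one option; alternatively, if steps_per_epoch is omitted, Keras uses the generator's own length, which for flow_from_directory iterators is ceil(samples / batch_size).

```python
import math

samples = 1003   # hypothetical number of training images
batch_size = 8   # matches config["batch_size"] in Step 2

floor_steps = samples // batch_size           # batches per epoch as computed above
ceil_steps = math.ceil(samples / batch_size)  # also covers the final partial batch
skipped = samples - floor_steps * batch_size  # images left out of each epoch with floor division

print(floor_steps, ceil_steps, skipped)  # 125 126 3
```

For the shuffled training generator the skipped images change from epoch to epoch, but the validation generator uses shuffle=False, so with floor division the same last few validation images are never evaluated.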
Step 8: Training
We use the fit method to train the model on images from train_generator_wildfire and validate it against images from validation_generator. We pass the number of steps per epoch, the number of epochs (from the W&B config), and the number of validation steps, and include the custom callback to log metrics.
model.fit(
    train_generator_wildfire,
    steps_per_epoch=steps_per_epoch,
    epochs=config.epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=[CustomWandbCallback()]  # Use the custom callback here
)
Step 9: Model evaluation
After training, we evaluate the model on the validation dataset with the evaluate method and print the validation accuracy to get an idea of how well the model generalizes. Keep in mind that this is the same data used for validation during training, so it is not a fully independent estimate.
val_loss, val_acc = model.evaluate(validation_generator, steps=validation_generator.samples // validation_generator.batch_size)
print('Validation accuracy:', val_acc)

# Log final evaluation metrics
wandb.log({'Validation Loss': val_loss, 'Validation Accuracy': val_acc})
# Save the model locally, then log it to W&B as a versioned artifact
model.save("wildfire_detection_model.h5")
artifact = wandb.Artifact("wildfire_detection_model", type="model", description="A model to detect wildfires from images")
artifact.add_file("wildfire_detection_model.h5")
wandb.log_artifact(artifact)
model.save() saves the trained model locally. wandb.log_artifact() logs the model as an artifact in Weights & Biases, providing versioning and easy access to the trained model for future reference or deployment.
Step 10: Model testing
# Define directory paths for the test data
test_dir_wildfire = '/kaggle/input/wildfire-prediction-dataset/test/wildfire'
test_dir_nowildfire = '/kaggle/input/wildfire-prediction-dataset/test/nowildfire'
Next, load_and_preprocess_images(dir_path, target_size) loads the images in a given directory (test_dir_wildfire or test_dir_nowildfire), resizes them to the target size (350x350), and applies the same 1/255 rescaling used during training. It returns the images stacked into a single array, along with their file paths.
from tensorflow.keras.preprocessing import image
import numpy as np
import os

def load_and_preprocess_images(dir_path, target_size):
    """Load every image in dir_path, resize it, and rescale pixels to [0, 1] as in training."""
    images = []
    file_paths = []
    for file_name in sorted(os.listdir(dir_path)):  # sorted for a reproducible order
        file_path = os.path.join(dir_path, file_name)
        img = image.load_img(file_path, target_size=target_size)
        img_array = image.img_to_array(img) / 255.0  # same rescaling as rescale=1./255
        images.append(np.expand_dims(img_array, axis=0))
        file_paths.append(file_path)
    return np.vstack(images), file_paths

target_size = (350, 350)  # The target size used during training
images_wildfire, files_wildfire = load_and_preprocess_images(test_dir_wildfire, target_size)
images_nowildfire, files_nowildfire = load_and_preprocess_images(test_dir_nowildfire, target_size)
Step 11: Predicting and printing results
Finally, model.predict() returns the model's output probability for each image. Because flow_from_directory() assigns class indices in alphanumeric order of the subdirectory names, nowildfire is class 0 and wildfire is class 1, so this output is the predicted probability of wildfire. If it is 0.5 or higher, the image is classified as wildfire; otherwise, it is classified as no wildfire. Predictions are made separately for the images in the wildfire and nowildfire test directories, and each result is printed alongside its file path.
predictions_wildfire = model.predict(images_wildfire)
for i, pred in enumerate(predictions_wildfire):
    print(f"{files_wildfire[i]} - {'Wildfire' if pred >= 0.5 else 'No Wildfire'}: {pred[0]}")

predictions_nowildfire = model.predict(images_nowildfire)
for i, pred in enumerate(predictions_nowildfire):
    print(f"{files_nowildfire[i]} - {'Wildfire' if pred >= 0.5 else 'No Wildfire'}: {pred[0]}")
Step 12: Logging sample predictions
# Convert predictions to binary labels using a threshold of 0.5
binary_predictions_wildfire = [1 if pred >= 0.5 else 0 for pred in predictions_wildfire]
binary_predictions_nowildfire = [1 if pred >= 0.5 else 0 for pred in predictions_nowildfire]

# Provide the true labels (1 = wildfire, 0 = no wildfire)
true_labels_wildfire = [1] * len(files_wildfire)
true_labels_nowildfire = [0] * len(files_nowildfire)
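With binary predictions and true labels in hand, you can also compute aggregate test metrics instead of only inspecting samples. The sketch below assumes the two per-class lists from this step are concatenated; the helper name summarize is hypothetical, and the toy inputs in the last line are made up for illustration, not results from this model. If you want these numbers in W&B as well, pass them to wandb.log() before the run is finished.

```python
import numpy as np

def summarize(y_true, y_pred):
    """Accuracy, precision, and recall for the positive (wildfire = 1) class, plus confusion counts."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = int(np.sum((y_pred == 1) & (y_true == 1)))
    tn = int(np.sum((y_pred == 0) & (y_true == 0)))
    fp = int(np.sum((y_pred == 1) & (y_true == 0)))
    fn = int(np.sum((y_pred == 0) & (y_true == 1)))
    return {
        "accuracy": (tp + tn) / len(y_true),
        "precision": tp / (tp + fp) if tp + fp else 0.0,
        "recall": tp / (tp + fn) if tp + fn else 0.0,
        "confusion": {"tp": tp, "tn": tn, "fp": fp, "fn": fn},
    }

# In the pipeline, something like:
# summarize(true_labels_wildfire + true_labels_nowildfire,
#           binary_predictions_wildfire + binary_predictions_nowildfire)
print(summarize([1, 1, 1, 0, 0], [1, 0, 1, 0, 1]))  # toy labels, not real results
```

Recall on the wildfire class is often the number to watch in this setting, since a missed fire is usually costlier than a false alarm.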
# Define a function to log predictions, along with their images and labels, to Weights & Biases
def log_predictions_wandb(image_paths, predictions, true_labels, title, num_samples=10):
    columns = ["Image", "Predicted Label", "True Label"]
    wandb_table = wandb.Table(columns=columns)
    for i in range(min(num_samples, len(image_paths))):
        img = wandb.Image(image.load_img(image_paths[i], target_size=target_size))
        pred_label = "Wildfire" if predictions[i] >= 0.5 else "No Wildfire"
        true_label = "Wildfire" if true_labels[i] == 1 else "No Wildfire"
        wandb_table.add_data(img, pred_label, true_label)
    wandb.log({title: wandb_table})
# Log sample predictions for both wildfire and nowildfire images
log_predictions_wandb(files_wildfire, binary_predictions_wildfire, true_labels_wildfire, "Wildfire Predictions")
log_predictions_wandb(files_nowildfire, binary_predictions_nowildfire, true_labels_nowildfire, "Nowildfire Predictions")

# Finish the run
wandb.finish()
This part of the code logs up to 10 sample predictions per class to Weights & Biases, creating one table for wildfire images and one for nowildfire images. Each row contains the image, its predicted label, and its true label, allowing for visual inspection and evaluation of the model's predictions.
Here are the 10 samples for “wildfire”:

Source: Author

Source: Author
And for the “no wildfire”:

Source: Author

Source: Author
Some important charts logged by W&B:
These include the loss and accuracy curves, which are worth checking for any training run: for example, validation loss rising while training loss keeps falling is a common sign of overfitting.

Source: Author
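The charts can be complemented with a quick numeric check. Assuming you keep the return value of training (history = model.fit(...)), Keras stores per-epoch metrics in history.history as lists keyed by metric name. The hypothetical helper best_epoch_report below finds the epoch with the lowest validation loss (0-indexed, matching the epoch value logged by the custom callback) and flags whether validation loss rose afterward while training loss kept falling. The numbers in toy_history are invented for illustration.

```python
def best_epoch_report(history):
    """Return the best epoch by val_loss and whether later epochs suggest overfitting."""
    val_loss = history["val_loss"]
    loss = history["loss"]
    best = min(range(len(val_loss)), key=val_loss.__getitem__)  # first epoch with lowest val_loss
    later = range(best + 1, len(val_loss))
    # Overfitting signal: validation loss got worse while training loss kept improving
    overfitting = any(val_loss[i] > val_loss[best] and loss[i] < loss[best] for i in later)
    return {"best_epoch": best, "best_val_loss": val_loss[best], "overfitting_signal": overfitting}

toy_history = {"loss": [0.60, 0.40, 0.30, 0.20], "val_loss": [0.55, 0.45, 0.47, 0.52]}
print(best_epoch_report(toy_history))
```

If the signal fires, the best epoch is a natural checkpoint to keep, and Keras's EarlyStopping callback can automate stopping at that point.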
This code demonstrates an end-to-end workflow for training a deep learning model on an image classification task, including preprocessing, model training, evaluation, prediction, and experiment logging with Weights & Biases. It illustrates key concepts such as transfer learning with pre-trained models, custom callbacks for integrating with experiment tracking tools, and preprocessing images for prediction.
Conclusion
Image classification is a powerful tool for solving real-world problems, but achieving high model performance requires rigorous experiment tracking. Weights & Biases simplifies the process, providing a centralized platform to log, compare, and optimize deep learning models.
By integrating Weights & Biases into wildlife conservation applications, researchers can enhance species monitoring and environmental protection while improving model accuracy through better experiment tracking.
Whether you’re working on conservation, healthcare, or industrial applications, Weights & Biases helps keep your experiments reproducible and makes it easier to optimize and continuously improve your models.