logistic-regression-iris
A logistic regression model trained on the Iris dataset.
It takes two inputs, 'PetalLengthCm' and 'PetalWidthCm', and predicts whether the species is 'Iris-setosa'.
It is a PyTorch adaptation of the scikit-learn model in Chapter 10 of Aurelien Geron's book 'Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow'.
Code: https://github.com/sambitmukherjee/handson-ml3-pytorch/blob/main/chapter10/logistic_regression_iris.ipynb
Model: https://huggingface.co/sadhaklal/logistic-regression-iris
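For reference, a minimal PyTorch sketch of such a model (the class and variable names below are illustrative assumptions, not taken from the notebook):

import torch
import torch.nn as nn

class LogisticRegression(nn.Module):  # illustrative name, not from the notebook
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)  # two inputs: 'PetalLengthCm', 'PetalWidthCm'

    def forward(self, x):
        return torch.sigmoid(self.linear(x))  # probability that the species is 'Iris-setosa'

model = LogisticRegression()
prob = model(torch.tensor([[1.4, 0.2]]))  # petal length (cm), petal width (cm)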
Project visibility: Public
Last active: 3/11/2024, 7:51:23 AM
Owner: sadhaklal
Contributors: 1 user
Total runs: 1
Total compute: 2 minutes
Export & update data
Use our Public API to export or update data, such as editing the config of existing runs. Learn more in the docs →
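For example, to list every run in this project before drilling into a specific one (a small sketch; api.runs takes an <entity>/<project> path):

import wandb
api = wandb.Api()
# iterate over all runs in the project
for run in api.runs("sadhaklal/logistic-regression-iris"):
    print(run.id, run.name, run.state)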
Find the run path
To use the Public API, you'll often need the run path, which is <entity>/<project>/<run_id>. In the app UI, open a run page, then click the Overview tab to see the run path.
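Once a run is loaded through the API, the same three components can also be read back from the Run object. A small sketch (the run id is a placeholder, as above):

import wandb
api = wandb.Api()
run = api.run("sadhaklal/logistic-regression-iris/<run_id>")
# entity, project, and id are standard attributes of the Run object
print(f"{run.entity}/{run.project}/{run.id}")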
Update config for an existing run
This example updates one of your configuration settings.
import wandb
api = wandb.Api()
# run is specified by <entity>/<project>/<run_id>
run = api.run("sadhaklal/logistic-regression-iris/<run_id>")
# "key" and updated_value are placeholders for the config entry you want to change
run.config["key"] = updated_value
# persist the change to the W&B server
run.update()
Export metrics from a single run to a CSV file
This snippet finds all the metrics saved for a single run and saves them to a CSV file.
import wandb
api = wandb.Api()
# run is specified by <entity>/<project>/<run_id>
run = api.run("sadhaklal/logistic-regression-iris/<run_id>")
# save the metrics for the run to a CSV file
metrics_dataframe = run.history()
metrics_dataframe.to_csv("metrics.csv")
Read metrics for a run
Pull down the accuracy and timestamps for logged metric data. In this example, data was logged with wandb.log({"accuracy": acc}) to a run with the run path <entity>/<project>/<run_id>.
import wandb
api = wandb.Api()
run = api.run("sadhaklal/logistic-regression-iris/<run_id>")
if run.state == "finished":
    for i, row in run.history().iterrows():
        print(row["_timestamp"], row["accuracy"])
Get unsampled metric data
When you pull data from history, by default it's sampled to 500 points. Get all the logged data points using run.scan_history(). Here's an example downloading all the loss data points logged in history.
import wandb
api = wandb.Api()
run = api.run("sadhaklal/logistic-regression-iris/<run_id>")
history = run.scan_history()
losses = [row["loss"] for row in history]
Download the best model file from a sweep
This snippet downloads the model file with the highest validation accuracy from a sweep whose runs saved model files to model.h5.
import wandb
api = wandb.Api()
sweep = api.sweep("sadhaklal/logistic-regression-iris/<sweep_id>")
runs = sorted(sweep.runs,
              key=lambda run: run.summary.get("val_acc", 0), reverse=True)
val_acc = runs[0].summary.get("val_acc", 0)
print(f"Best run {runs[0].name} with {val_acc}% validation accuracy")
runs[0].file("model.h5").download(replace=True)
print("Best model saved to model-best.h5")
