
object_argsort_autoregressive

Experiments evaluating learning curves of Abstractor and Transformer models on an autoregressive object-sorting task.

N = 48 'objects' are generated as 12-dimensional Gaussian vectors. They are built from two sets of 'attribute objects', each with an associated ordering. The first attribute set is 4-dimensional with 4 possible values (each a Gaussian vector); the second is 8-dimensional with 12 possible values (each a Gaussian vector). The 48 objects are formed by the Cartesian product of the two attribute sets, where the first attribute forms the primary key and the second forms the secondary key. This induces a strict ordering on the 48 objects. Models are trained to sort random permutations of the objects autoregressively: the task is to predict the `argsort` of sequences of 10 randomly permuted objects. For each model, learning curves are evaluated by training on subsets of the training data of varying size and evaluating on a held-out test set. A RelationalAbstractor, a SimpleAbstractor, and a Transformer model are evaluated.

A second sorting task is formed by shuffling the ordering underlying the first attribute while keeping the second attribute's ordering fixed. We evaluate pre-training on this task; the pre-training dataset contains 1000 sequences.

The Transformer and Abstractor hyperparameters are chosen so that the models have roughly equal parameter counts (in fact, the Transformer is larger):

transformer_kwargs = dict(
    num_layers=4, num_heads=2, dff=64,
    input_vocab='vector', target_vocab=seqs_length+1,
    output_dim=seqs_length, embedding_dim=64)
# param count = 469,898

rel_abstracter_kwargs = dict(
    num_layers=2, num_heads=2, dff=64,
    input_vocab='vector', target_vocab=seqs_length+1,
    output_dim=seqs_length, embedding_dim=64,
    rel_attention_activation='softmax')
# param count = 386,954

simple_abstractor_kwargs = dict(
    embedding_dim=64, input_vocab='vector',
    target_vocab=seqs_length+1, output_dim=seqs_length,
    abstractor_kwargs=dict(num_layers=1, num_heads=4, dff=64,
                           use_pos_embedding=False, mha_activation_type='softmax'),
    decoder_kwargs=dict(num_layers=1, num_heads=4, dff=64, dropout_rate=0.1))
# param count = 219,210

symbolic_abstracter_kwargs = dict(
    num_layers=2, num_heads=2, dff=64,
    input_vocab='vector', target_vocab=seqs_length+1,
    output_dim=seqs_length, embedding_dim=64,
    rel_attention_activation='softmax')
# param count = 386,954

ablation_abstractor_kwargs = dict(
    num_layers=2, num_heads=2, dff=64,
    input_vocab='vector', target_vocab=seqs_length+1,
    output_dim=seqs_length, embedding_dim=64,
    use_self_attn=True, use_encoder=True,
    mha_activation_type='softmax')
# param count = 386,954
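For concreteness, here is a minimal NumPy sketch of how the objects and argsort targets described above could be generated. The variable names and exact construction are illustrative assumptions, not the project's actual code:

import numpy as np

rng = np.random.default_rng(0)
seqs_length = 10

# Two sets of 'attribute objects', each a collection of Gaussian vectors
# with an associated (index) ordering.
attr1 = rng.normal(size=(4, 4))    # 4 values, 4-dimensional (primary key)
attr2 = rng.normal(size=(12, 8))   # 12 values, 8-dimensional (secondary key)

# Cartesian product: 4 * 12 = 48 objects, each 12-dimensional
# (a value from the first attribute set concatenated with one from the second).
objects = np.array([np.concatenate([a, b]) for a in attr1 for b in attr2])

# The objects are built in lexicographic (primary key, secondary key) order,
# so object k has rank k. A training example is then a random sequence of
# distinct objects plus the argsort of their ranks.
perm = rng.permutation(48)[:seqs_length]   # 10 distinct object indices
x = objects[perm]                          # model input: (10, 12) object vectors
y = np.argsort(perm)                       # target: indices that sort the sequence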
Project visibility: Public
Last active: 5/27/2023, 6:42:51 AM
Contributors: 1 user
Total runs: 3162
Total compute: 4 days
Export & update data
Use our Public API to export or update data, such as editing the config of existing runs. Learn more in the docs.
Find the run path
To use the public API, you'll often need the run path, which is <entity>/<project>/<run_id>. In the app UI, open a run page, then click on the Overview tab to see the run path.
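If you'd rather look up run paths programmatically than through the UI, the Public API can list a project's runs. A minimal sketch (the entity name is a placeholder you'd replace with your own):

import wandb

api = wandb.Api()
# Iterate over the project's runs and print each run's path.
for run in api.runs("<entity>/object_argsort_autoregressive"):
    print(f"{run.entity}/{run.project}/{run.id}", run.name, run.state)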
Update config for an existing run
This example updates one of your configuration settings.
import wandb

api = wandb.Api()
run = api.run("abstractor/object_argsort_autoregressive/<run_id>")

# Edit a config value and persist the change to the server.
run.config["key"] = updated_value
run.update()
Export metrics from a single run to a CSV file
This snippet finds all the metrics saved for a single run and saves them to a CSV file.
import wandb

api = wandb.Api()
# run is specified by <entity>/<project>/<run_id>
run = api.run("abstractor/object_argsort_autoregressive/<run_id>")

# save the metrics for the run to a csv file
metrics_dataframe = run.history()
metrics_dataframe.to_csv("metrics.csv")
Read metrics for a run
Pull down the accuracy and timestamps for logged metric data. In this example, data was logged with wandb.log({"accuracy": acc}) to a run with the run path <entity>/<project>/<run_id>.
import wandb

api = wandb.Api()
run = api.run("abstractor/object_argsort_autoregressive/<run_id>")

if run.state == "finished":
    for i, row in run.history().iterrows():
        print(row["_timestamp"], row["accuracy"])
Get unsampled metric data
When you pull data from history, by default it's sampled to 500 points. Get all the logged data points using run.scan_history(). Here's an example downloading all the loss data points logged in history.
import wandb

api = wandb.Api()
run = api.run("abstractor/object_argsort_autoregressive/<run_id>")

history = run.scan_history()
losses = [row["loss"] for row in history]
Download the best model file from a sweep
This snippet downloads the model file with the highest validation accuracy from a sweep whose runs saved model files to model.h5.
import wandb

api = wandb.Api()
sweep = api.sweep("abstractor/object_argsort_autoregressive/<sweep_id>")

# Sort the sweep's runs by validation accuracy, best first.
runs = sorted(sweep.runs, key=lambda run: run.summary.get("val_acc", 0), reverse=True)
val_acc = runs[0].summary.get("val_acc", 0)
print(f"Best run {runs[0].name} with {val_acc}% validation accuracy")

# Download the best run's saved model file (kept under its original name).
runs[0].file("model.h5").download(replace=True)
print("Best model saved to model.h5")