
DL4J with W&B

A simple example of using W&B for tracking training of Neural Networks on the JVM via DL4J

Introduction

Setting aside for a moment that Python has become the dominant language for deep learning, we're going to visit DL4J, a great Java framework for training deep neural nets, and see how we can use W&B to track our training experiments, all from the comfort of a Scala notebook.
I have always enjoyed coding in Scala, and had you asked me 7 years ago, I would have thought it would become a much more prominent language in Data Science and ML, but it never really seemed to take off. Aside from being the language behind Spark and Akka, it is a JVM language with significant interoperability with Java libraries, which opens up a very rich ecosystem. As for Java, it might be considered the language of the enterprise, given the huge investments made in Java-based systems, and that alone is a reason why a deep learning framework for the JVM might be relevant, if not necessary.

W&B Client

Getting up and running with W&B for any Python-based ML project is painless. All that is required is a
pip install wandb
and from there, you can be tracking your experiments in as little as 60 seconds! But what about non-Python workflows, specifically those originating in a JVM language? In this report, we'll look at two methods.
The W&B Java Client follows a builder pattern to create experiments that will be tracked. It is limited to logging metrics from your workflow and does not currently support any richer features. To extend it, you should be fairly competent with Java and gRPC to get everything working as expected. I can't claim to be either, so I opted for a different approach: using py4j to provide tracking for my Java-based projects.

W&B + py4j

Py4j offers a bridge between Python and Java (and other JVM languages by extension). I'll be the first to admit that this implementation is fairly obtuse and clunky, but it works, and it was up and running in a reasonable amount of time.
This approach requires a Java interface and a Python class implementing that interface. We then have a Scala wrapper for the Python class, which starts the Python process, creates the gateway, and lets you interact with W&B from your Scala project.
The Java interface is shown below.
package py4j.wandb;

import java.util.List;
import java.util.Map;

public interface IWandB {
    // The most commonly used functions/objects are:
    // - wandb.login  -> login
    // - wandb.init   -> initialize a new run at the top of your training script (in wandb.sdk.wandb_init.py line 742)
    // - wandb.config -> track hyperparameters and metadata
    // - wandb.log    -> log metrics and media over time within your training loop
    public Boolean login(
        String anonymous,
        String key,
        Boolean relogin,
        String host,
        Boolean force,
        Integer timeout
    );

    public Boolean log(Map<String, Object> data, Integer step, Boolean commit, Boolean sync);

    public Boolean logArtifact(String artifactPath, String name, String type, List<String> aliases);

    public Boolean init(String job_type, String dir, Map<String, Object> config,
                        String project, String entity, Boolean reinit, List<String> tags,
                        String group, String name, String notes,
                        Map<String, Object> magic, List<String> configExcludeKeys, List<String> configIncludeKeys,
                        String anonymous, String mode, Boolean allowValChange,
                        Boolean resume, Boolean force, Object tensorboard, Object syncTensorboard, Object monitorGym, Object saveCode,
                        String id, Map<String, Object> settings);

    public Boolean finish();
}
The signatures of these functions closely resemble those of their Python counterparts. In some cases, when it wasn't clear what the argument type was, a Java Object was used. The Python class implementing this interface is written so that each method invokes the corresponding W&B function. For example, the init method in the Python class calls wandb.init with all the specified arguments. See below for how the Python class looks. Not all of the methods are listed, but it should convey the main idea of how this works.
import wandb


class PyIWandB(object):
    def __init__(self):
        self.active_run = None
        self.config = None

    def login(
        self,
        anonymous=None,
        key=None,
        relogin=None,
        host=None,
        force=None,
        timeout=None,
    ):
        res = wandb.login(
            anonymous=anonymous,
            key=key,
            relogin=relogin,
            host=host,
            force=force,
            timeout=timeout,
        )
        return res

    ## more methods

    class Java:
        implements = ["py4j.wandb.IWandB"]

The last piece of this is the Scala wrapper for our Python class implementing the Java interface. This wrapper will:
  • Start the Python process and create either a JavaGateway or a ClientServer.
  • Pass a Python instance implementing the Java interface as a python_server_entry_point parameter.
  • Start a GatewayServer or ClientServer on the Java side.
  • Call getPythonServerEntryPoint, providing the list of interfaces the Python entry point is expected to implement.
  • Expose methods to log information regarding your experiment (a sketch follows this list).
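Here's a minimal sketch of what that wrapper could look like. The class name, the script name wandb_entry_point.py, the sleep-based handshake, and the exact null-padding of optional arguments are all illustrative assumptions, not the definitive implementation:
import java.io.File
import py4j.ClientServer
import py4j.wandb.IWandB
import scala.collection.JavaConverters._

class WandB(logFile: String) {
  // Launch the Python process hosting PyIWandB. The script name
  // "wandb_entry_point.py" is hypothetical; on the Python side it would
  // create a ClientServer with python_server_entry_point=PyIWandB().
  private val pythonProcess = new ProcessBuilder("python", "wandb_entry_point.py")
    .redirectErrorStream(true)
    .redirectOutput(new File(logFile))
    .start()

  // Crude handshake: give the Python side a moment to start listening.
  Thread.sleep(2000)

  // JVM side of the bridge; we only consume the Python entry point.
  private val clientServer = new ClientServer(null)
  private val bridge: IWandB = clientServer
    .getPythonServerEntryPoint(Array[Class[_]](classOf[IWandB]))
    .asInstanceOf[IWandB]

  def login(key: String): Boolean =
    bridge.login(null, key, null, null, null, null)

  // null-pad the many optional arguments of IWandB.init
  def init(project: String,
           config: Map[String, Any] = Map.empty,
           tags: java.util.List[String] = null): Boolean =
    bridge.init(null, null,
      config.asJava.asInstanceOf[java.util.Map[String, Object]],
      project, null, null, tags,
      null, null, null, null, null, null, null, null, null,
      null, null, null, null, null, null, null, null)

  def log(data: Map[String, Any], step: Integer = null): Boolean =
    bridge.log(data.asJava.asInstanceOf[java.util.Map[String, Object]],
      step, null, null)

  def finish(): Boolean = bridge.finish()

  def shutdownGateway(): Unit = {
    clientServer.shutdown()
    pythonProcess.destroy()
  }
}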
Example usage could look like
import py4j.wandb._
import py4j.wandb.Implicits._
val wandb = new WandB("wandb-session.log")
wandb.login(key=API_KEY)
wandb.init(project = "myproject")
val data = Map[String, Any]("metric" -> 0.1)
wandb.log(data)
wandb.finish
wandb.shutdownGateway
Only a very limited set of bindings is currently available:
  • login
  • init - when called, creates a run on the Python side.
  • log - logs data to the run.
  • logArtifact - currently only works with files and directories.
  • finish - finishes the run.
The WandB Scala constructor takes a single argument, the location where logging should be saved.
Other immediately useful methods that would be simple to implement include:
  • useArtifact
  • logMedia - media types would probably be passed around as byte streams, strings, or file paths (a sketch follows this list).
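For instance, a hypothetical logMedia binding might pass a file path through the bridge and let the Python side wrap it before logging. Everything here — the method names, the bridge call, and the Python behavior it assumes — is illustrative:
// Hypothetical addition to the WandB wrapper sketched earlier.
// Assumes a matching logMedia method is added to IWandB and PyIWandB,
// where the Python side would do: wandb.log({key: wandb.Image(path)}, step=step)
def logImage(key: String, path: String, step: Integer = null): Boolean = {
  val data = Map[String, Any](key -> path).asJava
    .asInstanceOf[java.util.Map[String, Object]]
  bridge.logMedia(data, step)
}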

ND4J

ND4J is a scientific computing library for the JVM. It is meant to be used in production environments rather than as a research tool, which means routines are designed to run fast with minimal RAM requirements. Its main feature is a versatile n-dimensional array object supporting a mix of numpy-style and tensorflow/pytorch-style operations; it is often described as numpy++ for Java.
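To give a flavor of the API, here is a tiny example (assuming the ND4J dependency is on the classpath):
import org.nd4j.linalg.factory.Nd4j

// roughly numpy's: ones((2, 3)) + 5, then .mean()
val arr = Nd4j.ones(2, 3).add(5.0)
println(arr)              // 2x3 INDArray filled with 6.0
println(arr.meanNumber()) // 6.0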

DL4J

DL4J provides a high-level API to build MultiLayerNetworks and ComputationGraphs with a variety of layers, including custom ones. It supports importing Keras models from h5, including tf.keras models (as of 1.0.0-beta7), and also supports distributed training on Apache Spark.
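As a taste of that API, a minimal MultiLayerNetwork configuration for an MNIST-sized MLP might look like the sketch below. The layer sizes and hyperparameter values are illustrative; the training snippets later in this report assume a similar modelConf:
import org.deeplearning4j.nn.conf.NeuralNetConfiguration
import org.deeplearning4j.nn.conf.layers.{DenseLayer, OutputLayer}
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.learning.config.Nesterovs
import org.nd4j.linalg.lossfunctions.LossFunctions

// A simple one-hidden-layer MLP: 28*28 inputs -> 500 hidden -> 10 classes
val modelConf = new NeuralNetConfiguration.Builder()
  .seed(123)
  .updater(new Nesterovs(0.015, 0.98)) // learning rate, momentum
  .list()
  .layer(new DenseLayer.Builder()
    .nIn(28 * 28)
    .nOut(500)
    .activation(Activation.RELU)
    .build())
  .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
    .nIn(500)
    .nOut(10)
    .activation(Activation.SOFTMAX)
    .build())
  .build()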

Data

Our old friend MNIST.
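DL4J ships a ready-made iterator for it. A minimal sketch of how the train and test objects used below could be built — the batch size and seed are arbitrary, and since the listener scores a single DataSet, the test split is materialized with next():
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.nd4j.linalg.dataset.DataSet

// args: batch size, train split?, rng seed
val train = new MnistDataSetIterator(128, true, 123)
// materialize the full 10k test split as one DataSet for scoring
val test: DataSet = new MnistDataSetIterator(10000, false, 123).next()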

Training

Callbacks

In DL4J, callbacks are referred to as listeners. Listeners let you "hook" into certain events in DL4J, allowing you to collect and print information during tasks like training. A simple W&B-aware listener could look like the following:
import org.deeplearning4j.optimize.api.BaseTrainingListener
import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.dataset.DataSet
import java.io.Serializable

case class WandBListener(logIteration: Int = 10,
                         testDataset: DataSet,
                         run: WandB) extends BaseTrainingListener with Serializable {

  if (logIteration <= 0) throw new Exception(s"logIteration must be greater than 0")

  override def iterationDone(model: Model, iteration: Int, epoch: Int): Unit = {
    if (iteration % logIteration == 0) {
      val trainingScore = model.score()
      val testScore = model.asInstanceOf[MultiLayerNetwork].score(testDataset)
      val data = Map[String, Any](
        "epoch" -> epoch, "iteration" -> iteration,
        "test_metric" -> testScore, "train_metric" -> trainingScore
      )
      run.log(data = data, step = iteration)
      print(s"Score on train dataset at iteration $iteration is $trainingScore")
      println(s"\t|| Score on test dataset at iteration $iteration is $testScore")
    }
  }
}
This should work pretty well in a distributed setting too, provided we rethink the setup of the Scala wrapper. Using the wandb java-client, you can set up the listener in a very similar way; the full version appears in the W&B + gRPC section below.

wandb.init

And when it comes time for training and tracking:
import scala.collection.JavaConverters._

val rate = 0.015
val numEpochs = 100
val numColumns = 28
val randomSeed = 123
val numRows = 28
val outputNum = 10

val wandbConfig = Map[String, Any](
  "numRows" -> numRows, "numColumns" -> numColumns, "outputNum" -> outputNum,
  "randomSeed" -> randomSeed, "numEpochs" -> numEpochs, "rate" -> rate
)
// set some tags
val wandbTags = List("scala", "dl4j").asJava
// start the wandb run
wandb.init(project = projectName, config = wandbConfig, tags = wandbTags)

val model = new MultiLayerNetwork(modelConf)
// print and log the score every 10 iterations
model.setListeners(new WandBListener(10, test, wandb))
// initialize the model
model.init()
// start training
(1 to numEpochs).foreach { _ => model.fit(train) }

// finish the wandb run
wandb.finish()

Evaluation


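Beyond the loss curves logged during training, DL4J's Evaluation class makes it easy to compute summary metrics on the test split and push them to the run. A minimal sketch, reusing the py4j wrapper from earlier:
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.nd4j.evaluation.classification.Evaluation

// run the model over the test split and collect classification metrics
val testIter = new MnistDataSetIterator(128, false, 123)
val eval: Evaluation = model.evaluate(testIter)
println(eval.stats())

// log the summary metrics to the active run
wandb.log(Map[String, Any](
  "test_accuracy" -> eval.accuracy(),
  "test_f1" -> eval.f1()
))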


W&B + gRPC

Using the available W&B Java Client is also entirely possible. While the repository hasn't seen a PR in a long while, you can still get it up and running without issue.
pip install wandb[grpc]==0.10.32 -q --upgrade
git clone https://github.com/wandb/client-java.git
wandb login $WANDB_API_KEY
apt-get install maven &> /dev/null
cd client-java
make install
Once this is done, set the configuration
import scala.collection.JavaConverters._
import com.wandb.client._
import org.json.JSONObject

val rate = 0.0015
val numEpochs = 1
val numColumns = 28
val randomSeed = 123
val numRows = 28
val outputNum = 10

val config = new JSONObject()
val configMap = Map(
  "numRows" -> numRows, "numColumns" -> numColumns, "outputNum" -> outputNum,
  "randomSeed" -> randomSeed, "numEpochs" -> numEpochs, "rate" -> rate
)
// box the mixed Int/Double values so they satisfy JSONObject.put(String, Object)
configMap.foreach { case (k, v) => config.put(k, v.asInstanceOf[AnyRef]) }

val tags = List("scala", "dl4j", "client-java").asJava
Create the Listener
import org.deeplearning4j.optimize.api.BaseTrainingListener
import org.deeplearning4j.nn.api.Model
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.nd4j.linalg.dataset.DataSet
import java.io.Serializable

case class WandBListener(logIteration: Int = 10,
                         testDataset: DataSet,
                         run: WandbRun) extends BaseTrainingListener with Serializable {

  if (logIteration <= 0) throw new Exception(s"logIteration must be greater than 0")

  override def iterationDone(model: Model, iteration: Int, epoch: Int): Unit = {
    if (iteration % logIteration == 0) {
      val trainingScore = model.score()
      val testScore = model.asInstanceOf[MultiLayerNetwork].score(testDataset)
      val data = new JSONObject()
      data.put("epoch", epoch)
      data.put("iteration", iteration)
      data.put("train_loss", trainingScore)
      data.put("test_loss", testScore)
      run.log(data)
      println(s"Score on train dataset at iteration $iteration is $trainingScore")
      println(s"Score on test dataset at iteration $iteration is $testScore")
    }
  }
}
Initialize the run, and start training the model.
val runBuilder = new WandbRun.Builder()
runBuilder
  .withConfig(config)
  .withProject("dl4j-wandb-java-client")
  .setTags(tags)
  .setJobType("training")
val run = runBuilder.build

// log and print the score every 10 iterations
model.setListeners(new WandBListener(10, test, run))
model.init()

(1 to 500).foreach { i =>
  model.fit(train)
  if (i % 50 == 0) {
    val testLogLoss = model.score(test) // test is a single DataSet
    val trainingLogLoss = model.score() // score of the most recent minibatch
    val data = new JSONObject()
    data.put("epoch_train_loss", trainingLogLoss)
    data.put("epoch_test_loss", testLogLoss)
    run.log(data)
  }
}
// mark the run as finished
run.done()