TorchExplorer: interactive neural network inspection
A brief overview of the TorchExplorer tool.
Created on November 13|Last edited on December 11
As part of my research, I tend to work with pretty wacky, nonstandard architectures. When they inevitably don't work, I often find myself wanting a tool that lets me quickly and interactively see what's happening at each layer in the network. The result is TorchExplorer.
TorchExplorer integration is as easy as it gets:
import torchexplorer
import wandb

torchexplorer.setup()  # Call once before wandb.init(), not needed for standalone
wandb.init()
model = ...
torchexplorer.watch(model, backend='wandb')  # Or 'standalone'
# Training loop...
This automatically uploads a Vega Custom Chart to wandb containing an interactive, module-level view of the network. Standalone deployment to a local web server is also possible. Feel free to play around with the CIFAR-10 ResNet18 demo below (fullscreen the application to visualize more module panels). To navigate the left-hand panel, click a module to open its interior graph, and click an entry in the parent list in the upper-left corner to get back to where you started. Drag and drop a module into one of the column panels to see its histograms in more detail; enlarging the window shows more columns.
Here are two more demos showing the kinds of architectures TorchExplorer can handle. Both the VQVAE and the transformer encoder contain modules with multiple inputs / outputs. The VQVAE also has a nondifferentiable VectorQuantizer block, which correctly appears disconnected in the graph. Note that the data here is dummy data for the demo; there's no meaningful training task.
Lastly, this wouldn't be possible without the ambitious custom chart support from the excellent Weights & Biases team. I was also heavily inspired by the wandb watch tool. While I did add a standalone option, wandb is really the intended use case.
User interface

Explorer
The left-hand panel contains a module-level graph of your network architecture, automatically extracted from the autograd graph. Clicking on a module will open its "internal" submodules. To return to a parent module, click on the appropriate element in the top-left expanding list.
Nodes. A node in the explorer graph is either a) an input/output placeholder for the visualized module, or b) a specific invocation of a submodule of the visualized module. If the visualized module's forward function takes multiple inputs, these appear as multiple nodes ("Input 0", "Input 1", ...). The same logic applies to outputs. All other nodes represent a distinct submodule invocation. This means that if a particular submodule is called twice in one forward pass, the two invocations show up separately in the explorer graph. Their histograms and "internal" submodules will also be distinct.
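A forward hook fires once per call, not once per module, which is one way to see why a submodule invoked twice yields two separate records. The sketch below uses a hypothetical SharedTwice module and a hypothetical record_call hook; it illustrates the idea and is not TorchExplorer's code.

```python
import torch
import torch.nn as nn


class SharedTwice(nn.Module):
    """Hypothetical module that calls the same submodule twice per forward pass."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return self.fc(self.fc(x))


invocations = []  # one entry per call, not per module


def record_call(module, inputs, output):
    # Forward hooks run on every invocation, so each call gets its own record.
    invocations.append((type(module).__name__, tuple(output.shape)))


model = SharedTwice()
model.fc.register_forward_hook(record_call)
model(torch.randn(2, 4))
print(invocations)  # [('Linear', (2, 4)), ('Linear', (2, 4))]
```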
Edges. An edge between two nodes means that there exists an autograd trace from some output of the parent node to some input of the child node. The number of incoming / outgoing edges of a node is unrelated to how many inputs/outputs its forward function takes. To illustrate this, consider a Linear node with two incoming edges from two distinct parent nodes. This can arise if, say, the outputs of the parent modules are added together and the sum is passed as the single input to the Linear forward function. Conversely, consider a TransformerEncoderLayer node, which accepts multiple inputs. There may still be only one incoming edge from a parent module, if all the inputs to the TransformerEncoderLayer are computed from this source.
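The kind of autograd trace described above can be inspected directly through a tensor's grad_fn and its next_functions. The sketch below walks that graph for the "two parents summed into one Linear input" case; the helper autograd_nodes is hypothetical, and this is only an illustration of the underlying autograd structure, not how TorchExplorer builds its graph.

```python
import torch
import torch.nn as nn


def autograd_nodes(tensor):
    """Collect the class names of autograd nodes reachable from tensor.grad_fn."""
    seen, stack, names = set(), [tensor.grad_fn], []
    while stack:
        fn = stack.pop()
        # Inputs that don't require grad show up as None in next_functions.
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions holds (parent_node, output_index) pairs.
        stack.extend(parent for parent, _ in fn.next_functions)
    return names


a, b, fc = nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)
x = torch.randn(2, 3)
out = fc(a(x) + b(x))  # two parent modules feed fc's single forward input
names = autograd_nodes(out)
print(names)
```

Following the trace back from fc's output passes through an add node and reaches both a and b, which is why such a Linear node would have two incoming edges. The AccumulateGrad nodes are where the six parameters (three weights, three biases) enter the graph.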
Tooltips. Mousing over explorer graph nodes displays a helpful tooltip. The first few lines summarize the shapes of the input / output tensors, recorded once from the first forward pass through the network. The subsequent lines parse key information from module.extra_repr(). This string parsing is designed around common PyTorch extra_repr() implementations (e.g., nn.Conv2d). The string is first split on commas, with each resulting substring becoming one row in the tooltip. If a substring is of the form "{key}={value}", it becomes a key-value pair in the tooltip. Otherwise the entire substring is treated as a value with an empty key, shown as a dash. This is the case for the in_channels and out_channels attributes of Conv2d, which its extra_repr() prints without keys.
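The parsing rule above can be sketched in a few lines. One caveat: a naive split on every comma would cut "kernel_size=(3, 3)" into two rows, so this sketch (an assumption on my part; the text does not say how TorchExplorer handles it) only splits on commas outside parentheses or brackets. The function name parse_extra_repr is hypothetical.

```python
def parse_extra_repr(extra_repr: str) -> list[tuple[str, str]]:
    """Split an extra_repr() string into (key, value) tooltip rows.

    Commas nested inside (), [] or {} are kept, so "kernel_size=(3, 3)" stays whole.
    Entries without "=" get an empty key (rendered as a dash in the tooltip).
    """
    parts, depth, current = [], 0, ""
    for ch in extra_repr:
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        if ch == "," and depth == 0:
            parts.append(current)
            current = ""
        else:
            current += ch
    parts.append(current)

    rows = []
    for part in (p.strip() for p in parts):
        if not part:
            continue
        key, sep, value = part.partition("=")
        rows.append((key.strip(), value.strip()) if sep else ("", part))
    return rows


# This is the string nn.Conv2d(3, 16, 3).extra_repr() returns.
print(parse_extra_repr("3, 16, kernel_size=(3, 3), stride=(1, 1)"))
# [('', '3'), ('', '16'), ('kernel_size', '(3, 3)'), ('stride', '(1, 1)')]
```

The two leading rows have empty keys, matching the dash-rendered in_channels and out_channels rows described above.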
Panels
To inspect a module in more detail, just drag and drop it into one of the columns on the right. The histogram colors don't represent anything intrinsically—they're just to help identify in the explorer which modules are being visualized.
Histograms
Each vertical "slice" of a histogram encodes the distribution of values at the corresponding x-axis time. The y-axis displays the minimum / maximum bounds of the histogram. Completely white squares mean that no data fell in that bin. A bin with one entry is shaded light gray, with the color intensifying as more values fall in that bin (this encodes the "height" of the histogram). The dashed horizontal line is the $y = 0$ line.
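To make the layout concrete, here is a minimal sketch of how per-step value arrays could be stacked into a bins-by-time count image, where each column is one vertical slice. It assumes all columns share a single y-range taken from the global minimum and maximum, and that a zero count renders white; the function name histogram_image is hypothetical, and the exact binning and coloring in TorchExplorer may differ.

```python
import numpy as np


def histogram_image(steps: list[np.ndarray], bins: int = 10) -> tuple[np.ndarray, float, float]:
    """Stack per-step histograms into a (bins, n_steps) array of counts.

    Every column uses the same bin edges spanning the global min/max, so the
    y-axis bounds are shared. Row 0 is the top (highest-value) bin.
    """
    lo = min(float(v.min()) for v in steps)
    hi = max(float(v.max()) for v in steps)
    edges = np.linspace(lo, hi, bins + 1)
    # np.histogram closes the last bin on the right, so hi itself is counted.
    columns = [np.histogram(v, bins=edges)[0] for v in steps]
    image = np.stack(columns, axis=1)[::-1]  # flip so larger values sit at the top
    return image, lo, hi


rng = np.random.default_rng(0)
steps = [rng.normal(0.0, 1.0 + t, size=100) for t in range(5)]  # widening distribution
image, lo, hi = histogram_image(steps)
print(image.shape)        # (10, 5)
print(image.sum(axis=0))  # [100 100 100 100 100]
```

Cells with a count of 0 correspond to the completely white squares; larger counts would be drawn progressively darker.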
Note that the tensors populating histograms are processed in two ways. First, for performance reasons they are randomly subsampled according to the sample_n parameter. This is 100 by default, and passing None disables subsampling. Because of this sampling, histograms that should be identical may look slightly different (e.g., the output of a parent node and the input of its child node). Second, a fraction of the values farthest from the median is rejected to prevent outliers from compressing the histogram. This fraction is 0.1 by default, and rejection can be disabled by passing 0.0 to the reject_outlier_proportion parameter.
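The two preprocessing steps can be sketched as follows: subsample without replacement, then keep only the fraction of values closest to the median. The parameter names mirror sample_n and reject_outlier_proportion from the text, but the function preprocess_for_histogram is hypothetical and the exact procedure TorchExplorer uses may differ.

```python
import numpy as np


def preprocess_for_histogram(values, sample_n=100, reject_outlier_proportion=0.1, rng=None):
    """Randomly subsample, then drop the values farthest from the median."""
    rng = np.random.default_rng() if rng is None else rng
    flat = np.asarray(values, dtype=float).ravel()

    # Step 1: random subsampling without replacement; None disables it.
    if sample_n is not None and flat.size > sample_n:
        flat = rng.choice(flat, size=sample_n, replace=False)

    # Step 2: outlier rejection; keep the (1 - p) fraction closest to the median.
    if reject_outlier_proportion > 0:
        dist = np.abs(flat - np.median(flat))
        n_keep = int(round(flat.size * (1 - reject_outlier_proportion)))
        flat = flat[np.argsort(dist, kind="stable")[:n_keep]]
    return flat


out = preprocess_for_histogram(np.arange(1000.0), rng=np.random.default_rng(0))
print(out.size)  # 90: 100 sampled values, 10% rejected
```

Because step 1 is random, two tensors holding the same values can produce slightly different histograms, which is the effect noted above for a parent's output and its child's input.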
For the following explanations, I'll be referencing this module:
import torch.nn as nn

class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 20)
        self.activation = nn.ReLU()

    def forward(self, x):
        x1 = self.fc(x)
        x2 = self.activation(x1)
        return x2
Input/output histograms. These histograms represent the values passed into and out of the module's forward method, captured using hooks. For instance, if we are visualizing the fc layer in the above TestModule, the input 0 histogram will be the histogram of x, and the output 0 histogram will be the histogram of x1. If fc accepted two inputs, as in self.fc(x, y), the panel would show both an input 0 and an input 1 histogram. Note that the input 0 histogram of the activation module will look very close to the output 0 histogram of the fc module, with small differences due to random sampling.
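A minimal sketch of hook-based input/output capture on the TestModule above, using PyTorch's register_forward_hook. The capture_io helper and the "name/input i" keys are hypothetical labels chosen for illustration, not TorchExplorer's internals.

```python
import torch
import torch.nn as nn


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 20)
        self.activation = nn.ReLU()

    def forward(self, x):
        x1 = self.fc(x)
        x2 = self.activation(x1)
        return x2


captured = {}


def capture_io(name):
    def hook(module, inputs, output):
        # inputs is the tuple of positional args: one "input i" per element.
        for i, t in enumerate(inputs):
            captured[f"{name}/input {i}"] = t.detach().flatten()
        outputs = output if isinstance(output, tuple) else (output,)
        for i, t in enumerate(outputs):
            captured[f"{name}/output {i}"] = t.detach().flatten()
    return hook


model = TestModule()
for name in ("fc", "activation"):
    getattr(model, name).register_forward_hook(capture_io(name))
model(torch.randn(8, 20))

# Before any subsampling, fc's output 0 (x1) is exactly activation's input 0.
print(torch.equal(captured["fc/output 0"], captured["activation/input 0"]))  # True
```

The raw tensors agree exactly; only the random subsampling applied before histogramming makes the two panels look slightly different.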
Input/output gradient norm histograms. These histograms capture tensor gradients from backward passes through the module. Unlike parameter gradients, here we record the $\ell_2$-norm of the gradients, one value per batch element. This means that if the gradient of the loss with respect to the module input has shape $B \times d_1 \times \dots \times d_k$, we first flatten it to a $B \times D$ matrix (with $D = d_1 \cdots d_k$) and take the row-wise norm to get a length-$B$ vector. These values then populate the histogram. For the fc layer in the above example, input 0 (grad norm) would apply this procedure to the gradient of the loss with respect to x, while output 0 (grad norm) would apply this procedure to the gradient of the loss with respect to x1.
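The per-sample norm computation can be sketched with PyTorch's register_full_backward_hook, which receives the gradients with respect to a module's inputs and outputs. The grad_norm_hook function and the dictionary keys are hypothetical; the input is created with requires_grad=True so that grad_input[0] is populated.

```python
import torch
import torch.nn as nn

grad_norms = {}


def grad_norm_hook(module, grad_input, grad_output):
    # Flatten each gradient from (B, d1, ..., dk) to (B, D), then take the
    # row-wise L2 norm, giving one value per batch element (a length-B vector).
    if grad_input[0] is not None:
        grad_norms["input 0 (grad norm)"] = grad_input[0].flatten(1).norm(dim=1)
    grad_norms["output 0 (grad norm)"] = grad_output[0].flatten(1).norm(dim=1)


fc = nn.Linear(20, 20)
fc.register_full_backward_hook(grad_norm_hook)

x = torch.randn(8, 20, requires_grad=True)
y = fc(x)
y.pow(2).sum().backward()  # dL/dy = 2y

print(grad_norms["input 0 (grad norm)"].shape)  # torch.Size([8])
```

With this toy loss the output gradient is 2y, so its per-row norms can be checked directly, and the input gradient matches x.grad because x feeds only into fc.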
Parameter histograms. After the input/output histograms are extracted, all submodules will have their immediate parameters (module._parameters) logged as histograms. Note that this is not the same as module.parameters(), which would also recurse to include all child parameters. Some modules (particularly activations) have no parameters and nothing will show up in the interface. For instance, TestModule above has no trainable immediate parameters; fc will have weight and bias parameters; and `activation` will again have nothing.
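The difference between immediate parameters and recursive ones can be checked directly. named_parameters(recurse=False) is the public PyTorch equivalent of iterating over module._parameters; the loop below simply lists what each module in TestModule would contribute.

```python
import torch.nn as nn


class TestModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(20, 20)
        self.activation = nn.ReLU()

    def forward(self, x):
        return self.activation(self.fc(x))


model = TestModule()
for name, module in model.named_modules():
    # Only parameters registered directly on this module, no recursion.
    immediate = [n for n, _ in module.named_parameters(recurse=False)]
    print(name or "<root>", immediate)
# <root> []
# fc ['weight', 'bias']
# activation []

print(len(list(model.parameters())))  # 2: parameters() recurses into fc
```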
Parameter gradient histograms. After the backward call is completed, each parameter will have a .grad attribute storing the gradient of the loss with respect to that parameter. This tensor is directly passed to the histogram. Unlike the input/output gradients, no norms are computed.
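As a small illustration, after backward() each parameter's .grad has the same shape as the parameter, and its raw entries can be binned directly with no norm taken. This sketch ignores subsampling and outlier rejection and uses numpy.histogram with its default 10 bins; the toy loss is arbitrary.

```python
import numpy as np
import torch
import torch.nn as nn

fc = nn.Linear(20, 20)
loss = fc(torch.randn(8, 20)).pow(2).mean()
loss.backward()

results = {}
for name, param in fc.named_parameters(recurse=False):
    # Raw gradient entries go straight into the histogram: no norm is computed.
    values = param.grad.detach().flatten().numpy()
    counts, edges = np.histogram(values, bins=10)
    results[name] = (tuple(param.grad.shape), int(counts.sum()))

print(results)  # {'weight': ((20, 20), 400), 'bias': ((20,), 20)}
```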