Activation Statistics

Exploring what we can learn by closely inspecting a model's internal distributions

Adapted from:

We need a way to look inside models and diagnose training issues.


set_seed


def set_seed(seed, deterministic: bool = False)
set_seed(42)
plt.style.use("ggplot")
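The rendered stub above drops the body; here is a minimal sketch of what a `set_seed` helper like this usually does (the exact miniai implementation may differ):

```python
import random

import numpy as np
import torch


def set_seed(seed, deterministic: bool = False):
    # Seed every RNG the training loop touches so runs are comparable
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if deterministic:
        # Trade speed for bit-for-bit reproducible kernels
        torch.use_deterministic_algorithms(True)


set_seed(42)
```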

Baseline

Let’s look at a Fashion-MNIST classification problem.


Conv2dWithReLU


def Conv2dWithReLU(
    *args, nonlinearity=relu, stride=1, padding=0, dilation=1, groups=1,
    bias=True, padding_mode='zeros', device=None, dtype=None
)

Convolutional layer with a built-in activation
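A plausible sketch of such a layer, subclassing `nn.Conv2d` and applying the nonlinearity on the way out (the actual implementation may differ in detail):

```python
import torch
from torch import nn
from torch.nn.functional import relu


class Conv2dWithReLU(nn.Conv2d):
    # A Conv2d that applies its nonlinearity to its own output, so
    # conv + activation read as a single layer in the model
    def __init__(self, *args, nonlinearity=relu, **kwargs):
        super().__init__(*args, **kwargs)
        self.nonlinearity = nonlinearity

    def forward(self, x):
        return self.nonlinearity(super().forward(x))


layer = Conv2dWithReLU(1, 8, kernel_size=3, stride=2, padding=1)
out = layer(torch.randn(2, 1, 28, 28))  # -> (2, 8, 14, 14), all values >= 0
```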


CNN


def CNN()

Six layer convolutional neural network
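The architecture itself isn't rendered; here is a sketch of a six-convolution stack that would fit 28×28 Fashion-MNIST, where the channel widths and kernel sizes are assumptions rather than the notebook's exact values:

```python
import torch
from torch import nn


def CNN():
    # Stride-2 convolutions halve the 28x28 input down to 1x1,
    # and the final layer emits one channel per class
    return nn.Sequential(
        nn.Conv2d(1, 8, 5, padding=2), nn.ReLU(),              # 28x28
        nn.Conv2d(8, 16, 3, stride=2, padding=1), nn.ReLU(),   # 14x14
        nn.Conv2d(16, 32, 3, stride=2, padding=1), nn.ReLU(),  # 7x7
        nn.Conv2d(32, 64, 3, stride=2, padding=1), nn.ReLU(),  # 4x4
        nn.Conv2d(64, 64, 3, stride=2, padding=1), nn.ReLU(),  # 2x2
        nn.Conv2d(64, 10, 3, stride=2, padding=1),             # 1x1
        nn.Flatten(),
    )


logits = CNN()(torch.randn(2, 1, 28, 28))  # -> (2, 10)
```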

Generally, we want a high learning rate, since it tends to push the model toward more generalizable solutions. Let’s start with a relatively high 0.6.

def train(model, extra_cbs=None):
    cbs = [
        MetricsCB(MulticlassAccuracy(num_classes=10)),
        DeviceCB(),
        ProgressCB(plot=True),
    ]
    if extra_cbs:
        cbs.extend(extra_cbs)
    learn = TrainLearner(
        model,
        fashion_mnist(),
        F.cross_entropy,
        lr=0.6,
        cbs=cbs,
    ).fit()
    return learn


train(model=CNN())
MulticlassAccuracy loss epoch train
0.159 2.910 0 train
0.100 2.386 0 eval

Let’s look at the underlying activations

Hooks

Jeremy’s implementation is kind of a mess, so I did a bit of refactoring. Hooks are just another kind of callback in the PyTorch universe, so we can adopt our Callback conventions.


Hook


def Hook(m, f)

Wrapper for a PyTorch hook, facilitating adding instance state
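A sketch of such a wrapper: it registers a forward hook and passes itself as the first argument to the callback, so results can live on the `Hook` instance (the names follow the stub above; the details are assumptions):

```python
from functools import partial

import torch
from torch import nn


class Hook:
    # Register a forward hook whose first argument is this Hook,
    # giving the callback f a place to stash per-layer state
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()


def record_mean(hook, mod, inp, outp):
    hook.mean = outp.mean().item()


layer = nn.Linear(4, 4)
h = Hook(layer, record_mean)
layer(torch.randn(8, 4))
h.remove()  # h.mean now holds the batch's activation mean
```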


HooksCallback


def HooksCallback(
    hook_cls, mods=None, mod_filter=noop, on_train=True, on_valid=False
)

Container for hooks, with clean-up and options to target specific modules
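The real callback attaches and tears down hooks around the Learner's fit; here is a minimal standalone sketch of the container part (context-manager clean-up plus an optional module filter), under the assumption that this is the core behavior:

```python
import torch
from torch import nn


class Hooks(list):
    # Attach one forward hook per (filtered) module; remove them all on exit
    def __init__(self, mods, f, mod_filter=lambda m: True):
        super().__init__(
            m.register_forward_hook(f) for m in mods if mod_filter(m)
        )

    def remove(self):
        for h in self:
            h.remove()

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.remove()


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
fired = []
with Hooks(model, lambda m, i, o: fired.append(type(m).__name__)):
    model(torch.randn(1, 4))
# fired records each module the forward pass went through, in order
```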

With those in place, we can subclass them to add hook behaviors.


StoreModuleStats


def StoreModuleStats(
    m, on_train:bool=True, on_valid:bool=False, periodicity:int=1
):

A hook for storing the activation statistics
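Here is what such a hook might record per batch: the mean, the standard deviation, and a histogram of absolute activations (the bin range and exact statistics are assumptions, shown with a raw `register_forward_hook` for self-containment):

```python
import torch
from torch import nn


def store_stats(stats, mod, inp, outp):
    # Per-batch summary of the layer's output activations
    acts = outp.detach().float().cpu()
    stats["means"].append(acts.mean().item())
    stats["stds"].append(acts.std().item())
    stats["hists"].append(acts.abs().histc(40, 0, 10))  # 40 bins over [0, 10]


layer = nn.Conv2d(1, 8, 3, padding=1)
stats = {"means": [], "stds": [], "hists": []}
handle = layer.register_forward_hook(
    lambda m, i, o: store_stats(stats, m, i, o)
)
layer(torch.randn(4, 1, 28, 28))
handle.remove()
```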


StoreModuleStatsCB


def StoreModuleStatsCB(
    mods=None, mod_filter=noop, on_train=True, on_valid=False,
    hook_kwargs=None
)

Callback for plotting the layer-wise activation statistics

Now, we can rerun while keeping track of the activation stats

model = CNN()
cb = StoreModuleStatsCB(mods=model.layers)
train(model=model, extra_cbs=[cb])
cb.mean_std_plot()
MulticlassAccuracy loss epoch train
0.169 2.274 0 train
0.199 2.096 0 eval

cb.hist_plot()

Jeremy makes the point that this network isn’t training well because most activations are close to 0, making them “dead units.”

⚠️ Generally, the mean should be 0 and the standard deviation should be close to 1.
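A quick illustration of why: with unit-variance Gaussian weights, each matrix multiply scales the activation standard deviation by roughly the square root of the fan-in, so the statistics drift away from (0, 1) exponentially with depth unless the initialization compensates. This toy example (not from the notebook) makes the effect concrete:

```python
import torch

torch.manual_seed(0)


def std_after_layers(scale, depth=10, width=100):
    # Push a batch through `depth` random linear layers and report the
    # final activation standard deviation
    x = torch.randn(256, width)
    for _ in range(depth):
        # Each unscaled layer multiplies the std by ~sqrt(width)
        x = x @ (torch.randn(width, width) * scale)
    return x.std().item()


naive = std_after_layers(1.0)           # blows up: ~10x growth per layer
scaled = std_after_layers(100 ** -0.5)  # stays near 1: variance preserved
```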

Ultimately, Jeremy recommends simply abandoning any training run where the activation variance blows up and then crashes, rather than hoping it recovers.