Activation Statistics

Exploring what we can learn by closely inspecting a model's internal distributions

Adapted from:

We need a way of looking inside models and diagnosing issues.


source

set_seed

 set_seed (seed, deterministic=False)
set_seed(42)
plt.style.use("ggplot")
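For reference, here is a minimal sketch of what a helper like set_seed typically does; the exact body is my assumption, not the library's source:

import random
import numpy as np
import torch

def set_seed(seed, deterministic=False):
    # Seed every RNG the training loop might touch
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Optionally force deterministic kernels (slower, but exactly repeatable)
    torch.use_deterministic_algorithms(deterministic)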

Baseline

Let’s look at a Fashion-MNIST classification problem.


source

Conv2dWithReLU

 Conv2dWithReLU (*args, nonlinearity=<function relu>,
                 stride:Union[int,Tuple[int,int]]=1,
                 padding:Union[str,int,Tuple[int,int]]=0,
                 dilation:Union[int,Tuple[int,int]]=1, groups:int=1,
                 bias:bool=True, padding_mode:str='zeros', device=None,
                 dtype=None)

Convolutional layer with a built-in activation
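Given the signature above, a plausible reconstruction is a thin nn.Conv2d subclass that applies the nonlinearity on the way out:

import torch.nn as nn
import torch.nn.functional as F

class Conv2dWithReLU(nn.Conv2d):
    "Convolutional layer with a built-in activation"
    def __init__(self, *args, nonlinearity=F.relu, **kwargs):
        super().__init__(*args, **kwargs)
        self.nonlinearity = nonlinearity

    def forward(self, x):
        # Convolution and activation fused into one module, so hooks
        # on this module see the post-activation outputs
        return self.nonlinearity(super().forward(x))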


source

CNN

 CNN ()

Six-layer convolutional neural network
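The layer definitions aren't shown here, but a six-layer stack of Conv2dWithReLU blocks for 28×28 Fashion-MNIST images might look like this; the channel counts and kernel sizes are my assumptions:

class CNN(nn.Module):
    "Six-layer convolutional neural network"
    def __init__(self):
        super().__init__()
        # Stride-2 convs halve the spatial size: 28 -> 14 -> 7 -> 4 -> 2 -> 1
        self.layers = nn.Sequential(
            Conv2dWithReLU(1, 8, kernel_size=5, padding=2),               # 28x28
            Conv2dWithReLU(8, 16, kernel_size=3, stride=2, padding=1),    # 14x14
            Conv2dWithReLU(16, 32, kernel_size=3, stride=2, padding=1),   # 7x7
            Conv2dWithReLU(32, 64, kernel_size=3, stride=2, padding=1),   # 4x4
            Conv2dWithReLU(64, 128, kernel_size=3, stride=2, padding=1),  # 2x2
            nn.Conv2d(128, 10, kernel_size=3, stride=2, padding=1),       # 1x1 logits
            nn.Flatten(),
        )

    def forward(self, x):
        return self.layers(x)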

Generally, we want to train at a high learning rate, since that tends to find more generalizable solutions. Let’s start with a relatively high learning rate of 0.6.

def train(model, extra_cbs=None):
    cbs = [
        MetricsCB(MulticlassAccuracy(num_classes=10)),
        DeviceCB(),
        ProgressCB(plot=True),
    ]
    if extra_cbs:
        cbs.extend(extra_cbs)
    learn = TrainLearner(
        model,
        fashion_mnist(),
        F.cross_entropy,
        lr=0.6,
        cbs=cbs,
    )
    learn.fit()  # fit() trains in place; return the learner, not fit()'s result
    return learn


train(model=CNN())
MulticlassAccuracy  loss   epoch  train
0.159               2.910  0      train
0.100               2.386  0      eval

Let’s look at the underlying activations.

Hooks

Jeremy’s implementation is kind of a mess, so I did a bit of refactoring. Hooks are just another kind of callback in the PyTorch universe, so we can adopt our Callback conventions.
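As a refresher on the underlying mechanism, a forward hook is just a function PyTorch calls with the module, its inputs, and its output on every forward pass:

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)

def print_stats(module, inputs, output):
    # Called after every forward pass of `layer`
    print(f"mean={output.mean():.3f} std={output.std():.3f}")

handle = layer.register_forward_hook(print_stats)
layer(torch.randn(8, 4))  # triggers the hook
handle.remove()           # hooks must be removed explicitly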


source

Hook

 Hook (m, f)

Wrapper for a PyTorch hook, making it easy to attach instance state
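A minimal sketch of the wrapper, assuming f receives the Hook itself as its first argument so it has somewhere to stash state:

from functools import partial

class Hook:
    "Wrapper for a PyTorch hook, making it easy to attach instance state"
    def __init__(self, m, f):
        # Prepend the Hook instance, so f(hook, module, inputs, output)
        # can store whatever it computes on `hook`
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()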


source

HooksCallback

 HooksCallback (hook_cls, mods=None, mod_filter=<function noop>,
                on_train=True, on_valid=False)

Container for hooks, with cleanup and options to target certain modules
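Roughly, the container registers one hook per targeted module before fitting and removes them all afterwards. A sketch, assuming the Callback base class and noop filter from earlier in this series:

class HooksCallback(Callback):
    "Container for hooks, with cleanup and options to target certain modules"
    def __init__(self, hook_cls, mods=None, mod_filter=noop, on_train=True, on_valid=False):
        self.hook_cls, self.mods, self.mod_filter = hook_cls, mods, mod_filter
        self.on_train, self.on_valid = on_train, on_valid

    def before_fit(self, learn):
        # Hook the explicitly passed modules, or everything matching the filter
        mods = self.mods if self.mods is not None else filter(self.mod_filter, learn.model.modules())
        self.hooks = [self.hook_cls(m, on_train=self.on_train, on_valid=self.on_valid)
                      for m in mods]

    def after_fit(self, learn):
        for h in self.hooks:
            h.remove()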

With that implemented, we can subclass these to add hook behaviors.


source

StoreModuleStats

 StoreModuleStats (m, on_train=True, on_valid=False, periodicity=1)

A hook for storing the activation statistics
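A sketch of the stats-storing hook; the train/valid gating is elided for brevity, and the histogram bins are my choice:

class StoreModuleStats(Hook):
    "A hook for storing the activation statistics"
    def __init__(self, m, on_train=True, on_valid=False, periodicity=1):
        self.means, self.stds, self.hists = [], [], []
        self.periodicity, self.count = periodicity, 0
        super().__init__(m, type(self)._store)

    def _store(self, mod, inp, outp):
        # Every `periodicity` batches, record the mean, std, and a
        # histogram of the absolute activations
        if self.count % self.periodicity == 0:
            acts = outp.detach().float().cpu()
            self.means.append(acts.mean().item())
            self.stds.append(acts.std().item())
            self.hists.append(acts.abs().histc(40, 0, 10))
        self.count += 1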


source

StoreModuleStatsCB

 StoreModuleStatsCB (mods=None, mod_filter=<function noop>, on_train=True,
                     on_valid=False, hook_kwargs=None)

Callback for plotting the layer-wise activation statistics
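The plotting half might look roughly like this (hook_kwargs plumbing elided; the layout is my choice):

import matplotlib.pyplot as plt

class StoreModuleStatsCB(HooksCallback):
    "Callback for plotting the layer-wise activation statistics"
    def __init__(self, mods=None, mod_filter=noop, on_train=True, on_valid=False):
        super().__init__(StoreModuleStats, mods=mods, mod_filter=mod_filter,
                         on_train=on_train, on_valid=on_valid)

    def mean_std_plot(self):
        # One line per hooked layer: activation means left, stds right
        _, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(10, 4))
        for i, h in enumerate(self.hooks):
            ax_mean.plot(h.means, label=f"layer {i}")
            ax_std.plot(h.stds, label=f"layer {i}")
        ax_mean.set_title("mean")
        ax_std.set_title("std")
        ax_mean.legend()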

Now we can rerun while keeping track of the activation stats.

model = CNN()
cb = StoreModuleStatsCB(mods=model.layers)
train(model=model, extra_cbs=[cb])
cb.mean_std_plot()
MulticlassAccuracy  loss   epoch  train
0.169               2.274  0      train
0.199               2.096  0      eval

cb.hist_plot()

Jeremy makes the point that this network isn’t training well because many activations are close to 0, which makes them “dead units.”

⚠️ Generally, the activation means should be close to 0 and the standard deviations close to 1.
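A quick illustrative check of why default initialization violates this: pass a standard-normal batch through a freshly initialized layer and look at the output stats.

x = torch.randn(64, 1, 28, 28)
layer = Conv2dWithReLU(1, 8, kernel_size=5, padding=2)
out = layer(x)
# With PyTorch's default init, the ReLU output has mean > 0 and std well below 1
print(out.mean().item(), out.std().item())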

Ultimately, Jeremy recommends simply abandoning any training run where the activation variance spikes and then crashes.