```python
set_seed(42)
plt.style.use("ggplot")
```
Activation Statistics
Adapted from:
We need to have a way of looking inside models and diagnosing issues.
set_seed
set_seed (seed, deterministic=False)
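The body of set_seed isn't shown in this section; a minimal sketch of what a helper with this signature typically does (the exact body below is an assumption, not the author's implementation):

```python
import random
import numpy as np
import torch

def set_seed(seed, deterministic=False):
    # Seed every random number generator the training loop might touch.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        # Trade speed for reproducibility where the CUDA kernels allow it.
        torch.use_deterministic_algorithms(True)
        torch.backends.cudnn.benchmark = False
```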
Baseline
Let’s look at a Fashion-MNIST classification problem.
Conv2dWithReLU
Conv2dWithReLU (*args, nonlinearity=<function relu>, stride:Union[int,Tuple[int,int]]=1, padding:Union[str,int,Tuple[int,int]]=0, dilation:Union[int,Tuple[int,int]]=1, groups:int=1, bias:bool=True, padding_mode:str='zeros', device=None, dtype=None)
Convolutional layer with a built-in activation
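Given the signature above, which mirrors nn.Conv2d plus a nonlinearity argument, a sketch of how such a layer could be written (the explicit Conv2d parameters are collapsed into **kwargs here for brevity):

```python
import torch.nn.functional as F
from torch import nn

class Conv2dWithReLU(nn.Conv2d):
    """Conv2d that applies a nonlinearity (ReLU by default) to its output."""
    def __init__(self, *args, nonlinearity=F.relu, **kwargs):
        super().__init__(*args, **kwargs)
        self.nonlinearity = nonlinearity

    def forward(self, x):
        # Run the standard convolution, then the activation, as one module.
        return self.nonlinearity(super().forward(x))
```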
CNN
CNN ()
Six-layer convolutional neural network
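The layer stack isn't reproduced here; a plausible six-layer version for 28×28 Fashion-MNIST images (the kernel sizes and channel widths below are assumptions, not necessarily the author's exact architecture) might look like this:

```python
from torch import nn

class CNN(nn.Module):
    """Six-layer CNN: five strided Conv2dWithReLU blocks followed by a Flatten."""
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            Conv2dWithReLU(1, 8, kernel_size=5, stride=2, padding=2),   # 28x28 -> 14x14
            Conv2dWithReLU(8, 16, kernel_size=3, stride=2, padding=1),  # 14x14 -> 7x7
            Conv2dWithReLU(16, 32, kernel_size=3, stride=2, padding=1), # 7x7   -> 4x4
            Conv2dWithReLU(32, 64, kernel_size=3, stride=2, padding=1), # 4x4   -> 2x2
            Conv2dWithReLU(64, 10, kernel_size=3, stride=2, padding=1,  # 2x2   -> 1x1
                           nonlinearity=nn.Identity()),                 # no ReLU on the logits
            nn.Flatten(),
        )

    def forward(self, x):
        return self.layers(x)
```

Exposing the stack as a layers attribute matters later: it is what we hand to the hooks callback to attach one hook per layer.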
Generally, we want to train at as high a learning rate as we can get away with: it speeds up training and tends to find solutions that generalize better. Let’s start with a relatively high learning rate of 0.6.
```python
def train(model, extra_cbs=None):
    cbs = [
        MetricsCB(MulticlassAccuracy(num_classes=10)),
        DeviceCB(),
        ProgressCB(plot=True),
    ]
    if extra_cbs:
        cbs.extend(extra_cbs)
    learn = TrainLearner(
        model,
        fashion_mnist(),
        F.cross_entropy,
        lr=0.6,
        cbs=cbs,
    )
    learn.fit()
    return learn
```

```python
train(model=CNN())
```
MulticlassAccuracy | loss | epoch | train |
---|---|---|---|
0.159 | 2.910 | 0 | train |
0.100 | 2.386 | 0 | eval |
Let’s look at the underlying activations.
Hooks
Jeremy’s implementation is kind of a mess, so I did a bit of refactoring. Hooks are just another kind of callback in the PyTorch universe, so we can adopt our Callback conventions.
Hook
Hook (m, f)
Wrapper for a PyTorch hook, facilitating adding instance state
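The signature, a module m and a function f, suggests the usual fast.ai-style wrapper: register a forward hook with the Hook instance bound as the first argument of f, so that f can stash whatever state it likes on the instance. A sketch under that assumption:

```python
from functools import partial

class Hook:
    """Wraps a PyTorch forward hook; `f` receives this Hook first, so it can store state on it."""
    def __init__(self, m, f):
        self.hook = m.register_forward_hook(partial(f, self))

    def remove(self):
        self.hook.remove()

    def __del__(self):
        self.remove()
```

Since register_forward_hook calls its function with (module, input, output), wrapping f in partial(f, self) means it actually sees (hook, module, input, output), and the hook argument is where per-module statistics can live.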
HooksCallback
HooksCallback (hook_cls, mods=None, mod_filter=<function noop>, on_train=True, on_valid=False)
Container for hooks, with cleanup and options to target specific modules
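A sketch of such a container, matching the documented signature; it assumes the Callback base class and before_fit/after_fit hook points from the training framework used above, noop from fastcore, and that hook_cls is a Hook subclass accepting (m, on_train=..., on_valid=..., **kwargs):

```python
from fastcore.basics import noop

class HooksCallback(Callback):  # Callback: base class from the training framework used above (assumed)
    """Create one hook per targeted module before fitting; remove them all afterwards."""
    def __init__(self, hook_cls, mods=None, mod_filter=noop, on_train=True, on_valid=False):
        self.hook_cls, self.mods, self.mod_filter = hook_cls, mods, mod_filter
        self.on_train, self.on_valid = on_train, on_valid
        self.hook_kwargs = {}  # extra kwargs for hook_cls; subclasses may fill this in

    def before_fit(self, learn):
        # Either hook the explicitly supplied modules, or filter the model's modules.
        mods = self.mods if self.mods is not None else filter(self.mod_filter, learn.model.modules())
        self.hooks = [self.hook_cls(m, on_train=self.on_train, on_valid=self.on_valid,
                                    **self.hook_kwargs) for m in mods]

    def after_fit(self, learn):
        for h in self.hooks:
            h.remove()
```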
With those implemented, we can subclass them to add hook behaviors.
StoreModuleStats
StoreModuleStats (m, on_train=True, on_valid=False, periodicity=1)
A hook for storing the activation statistics
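A sketch of what this hook might record; the 40-bin histogram of absolute activations is an assumption borrowed from the course's version, and the real binning may differ:

```python
import torch

def append_stats(hook, mod, inp, outp):
    """Forward-hook function: record activation statistics on the StoreModuleStats instance."""
    if mod.training and not hook.on_train: return
    if not mod.training and not hook.on_valid: return
    hook.count += 1
    if hook.count % hook.periodicity: return   # only record every `periodicity` batches
    acts = outp.detach().float().cpu()
    hook.means.append(acts.mean().item())
    hook.stds.append(acts.std().item())
    hook.hists.append(acts.abs().histc(40, 0, 10))  # near-zero activations land in the lowest bins

class StoreModuleStats(Hook):
    """Hook that stores per-batch means, stds, and histograms of one module's activations."""
    def __init__(self, m, on_train=True, on_valid=False, periodicity=1):
        self.on_train, self.on_valid, self.periodicity = on_train, on_valid, periodicity
        self.means, self.stds, self.hists, self.count = [], [], [], 0
        super().__init__(m, append_stats)
```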
StoreModuleStatsCB
StoreModuleStatsCB (mods=None, mod_filter=<function noop>, on_train=True, on_valid=False, hook_kwargs=None)
Callback for plotting the layer-wise activation statistics
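The callback then ties things together: it fixes hook_cls to StoreModuleStats, forwards hook_kwargs, and adds the two plotting helpers used below (mean_std_plot and hist_plot). The plotting details here are a sketch, not the author's exact figures:

```python
import matplotlib.pyplot as plt
import torch
from fastcore.basics import noop

class StoreModuleStatsCB(HooksCallback):
    """HooksCallback specialised to StoreModuleStats, with layer-wise plotting helpers."""
    def __init__(self, mods=None, mod_filter=noop, on_train=True, on_valid=False, hook_kwargs=None):
        super().__init__(StoreModuleStats, mods=mods, mod_filter=mod_filter,
                         on_train=on_train, on_valid=on_valid)
        self.hook_kwargs = hook_kwargs or {}

    def mean_std_plot(self):
        # One line per hooked module: activation means (left) and standard deviations (right).
        fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=(10, 4))
        for i, h in enumerate(self.hooks):
            ax_mean.plot(h.means, label=f"layer {i}")
            ax_std.plot(h.stds, label=f"layer {i}")
        ax_mean.set_title("activation means")
        ax_std.set_title("activation stds")
        ax_mean.legend()

    def hist_plot(self):
        # One panel per hooked module: log-scaled histogram of absolute activations over time.
        fig, axes = plt.subplots(1, len(self.hooks), figsize=(3 * len(self.hooks), 3), squeeze=False)
        for ax, h in zip(axes.flat, self.hooks):
            ax.imshow(torch.stack(h.hists).t().log1p(), origin="lower", aspect="auto")
```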
Now we can rerun training while keeping track of the activation stats.
```python
model = CNN()
cb = StoreModuleStatsCB(mods=model.layers)
train(model=model, extra_cbs=[cb])
cb.mean_std_plot()
```
MulticlassAccuracy | loss | epoch | train |
---|---|---|---|
0.169 | 2.274 | 0 | train |
0.199 | 2.096 | 0 | eval |
```python
cb.hist_plot()
```
Jeremy makes the point that this network isn’t training well because most of its activations are close to 0, making them “dead units.”
⚠️ Generally, the activations’ mean should be close to 0 and their standard deviation close to 1.
Ultimately, Jeremy recommends simply abandoning any training run in which the activation variance spikes and then crashes.
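One concrete way to quantify the “dead unit” problem from the stored histograms (assuming the binning in the StoreModuleStats sketch above, where the lowest bins hold near-zero absolute activations) is to plot, per layer, the fraction of activations falling into those bottom bins over the course of training:

```python
import matplotlib.pyplot as plt
import torch

def dead_fraction(hook, n_low_bins=2):
    """Per-step fraction of activations in the lowest histogram bins (near-zero magnitude)."""
    hists = torch.stack(hook.hists)                      # (steps, n_bins)
    return hists[:, :n_low_bins].sum(1) / hists.sum(1)   # fraction of near-zero activations

# `cb` is the StoreModuleStatsCB from the run above.
for i, h in enumerate(cb.hooks):
    plt.plot(dead_fraction(h), label=f"layer {i}")
plt.ylim(0, 1)
plt.legend()
plt.title("fraction of near-zero activations per layer")
```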