Minibatch training

Reviewing cross entropy, the logsumexp trick, and training a categorical model with appropriate loss functions

Adapted from:

Cross entropy loss

Continuing the simple model from the previous notebook, we need to implement a formally appropriate loss function. Regression is inappropriate for categorical outputs because it implies that different categories are different “distances” from each other depending on their ordinal position.

The proper output is a probability for each category, and the corresponding loss function is known as cross-entropy loss.

In information theory, the cross-entropy between two probability distributions \(p\) and \(q\) over the same underlying set of events measures the average number of bits needed to identify an event drawn from the set if a coding scheme used for the set is optimized for an estimated probability distribution \(q\), rather than the true distribution \(p\).

https://en.wikipedia.org/wiki/Cross-entropy
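
To make the definition above concrete, here is a minimal sketch in plain Python that computes the entropy of a distribution and the cross-entropy between two distributions, both in bits. The distributions `p` and `q` and the helper names are made up for illustration and do not come from the original notebook.

```python
import math


def entropy_bits(p):
    """Average bits needed when the code is optimized for the true distribution p."""
    return -sum(pi * math.log2(pi) for pi in p if pi > 0)


def cross_entropy_bits(p, q):
    """Average bits needed when events follow p but the code is optimized for q."""
    return -sum(pi * math.log2(qi) for pi, qi in zip(p, q) if pi > 0)


p = [0.5, 0.25, 0.25]  # illustrative "true" distribution
q = [1 / 3, 1 / 3, 1 / 3]  # illustrative estimated distribution

h_p = entropy_bits(p)  # 1.5 bits
h_pq = cross_entropy_bits(p, q)  # log2(3), roughly 1.585 bits
```

The cross-entropy is never smaller than the entropy of `p`, which is why minimizing it pushes the estimate `q` toward `p`.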

This works by:

  1. The model outputs an unnormalized logit for each category (\(\vec{z}\))
  2. The softmax of the output (i.e., exponentiating and dividing by the sum of the exponentiated values) is taken

\[p_{y_i}=\sigma_{\vec{z}}(z_i)=\frac{e^{z_i}}{\sum_j e^{z_j}}\]

  3. The cross-entropy is computed between each softmax output and its corresponding label

\[ -\sum_{i=1}^{N} \left[ y_i \log( p_{y_i} ) + ( 1 - y_i ) \log( 1 - p_{y_i} ) \right] \]

Where \(y_i \in \{0,1\}\), such that for a single label output distribution, this simplifies to

\[ -\log( p_{y_i} ) \]

More information here.
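
As a small worked example of the three steps above, the following plain-Python sketch takes one logit vector, applies the softmax, and computes the loss for a single target class. The logits and the target index are arbitrary values chosen for illustration.

```python
import math

z = [2.0, 1.0, 0.1]  # illustrative logits for three categories
target = 0  # illustrative index of the true category

# Step 2: softmax (exponentiate, then normalize)
exps = [math.exp(zi) for zi in z]
total = sum(exps)
probs = [e / total for e in exps]  # roughly [0.659, 0.242, 0.099]

# Step 3: for a single label, the loss is -log of the target's probability
loss = -math.log(probs[target])  # roughly 0.417
```

A confident, correct prediction drives `probs[target]` toward 1 and the loss toward 0, while a confident wrong prediction makes the loss grow without bound.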

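import torch
import torch.nn.functional as F  # used further down for F.nll_loss and F.cross_entropy

# Number of predictions
N = 5

# Assign some random prediction logits to demonstrate the operation of log-softmax
prd = torch.rand(N, 10)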
def log_softmax_naive(x):
    softmax = x.exp() / x.exp().sum(axis=-1, keepdim=True)
    return softmax.log()


lsm_prd = log_softmax_naive(prd)
lsm_prd
tensor([[-2.1, -2.1, -2.6, -2.0, -2.6, -2.4, -2.7, -2.2, -2.0, -2.8],
        [-2.0, -2.4, -2.1, -2.4, -2.2, -2.5, -2.1, -2.4, -2.7, -2.3],
        [-2.4, -2.2, -2.4, -1.8, -2.6, -2.4, -2.3, -2.5, -2.1, -2.7],
        [-2.0, -2.9, -2.1, -2.4, -2.6, -2.1, -2.4, -2.0, -2.4, -2.6],
        [-2.3, -2.6, -2.2, -2.0, -2.1, -2.7, -2.1, -2.3, -2.2, -2.7]])

In general, \(\log\)s are handy because they turn products and quotients into sums and differences, which are more numerically stable. We can take advantage of this because we have a division within a log:

\[ \begin{align*} \log(p_{y_i}) &= \log\left(\frac{e^{z_i}}{\sum_j e^{z_j}}\right) \\ &= \log(e^{z_i}) - \log\left(\sum_j e^{z_j}\right) \\ &= z_i - \log\left(\sum_j e^{z_j}\right) \end{align*} \]

def log_softmax_less_naive(x):
    return x - x.exp().sum(axis=-1, keepdim=True).log()


assert torch.isclose(lsm_prd, log_softmax_less_naive(prd)).all()

One more trick. These sums can get large, and we can work with smaller, more stable sums using the LogSumExp trick.

Let \(a=\max(\vec{z})\). Then, \[ \begin{align*} \sum_j e^{z_j-a} &= e^{z_1-a} + \dots + e^{z_n-a} \\ &= \frac{e^{z_1} + \dots + e^{z_n}}{e^{a}} \\ &= \frac{ \sum_j e^{z_j} }{e^{a}} \end{align*} \] Therefore, \[ \begin{align*} \sum_j e^{z_j} &= e^a \left( \sum_j e^{z_j-a} \right) \\ \log \left( \sum_j e^{z_j} \right) &= \log \left( e^a \cdot \left( \sum_j e^{z_j-a} \right) \right) \\ &= \log(e^a) + \log \left( \sum_j e^{z_j-a} \right) \\ &= a + \log \left( \sum_j e^{z_j-a} \right) \end{align*} \]

def logsumexp(x):
    # Since we're using a matrix instead of a vector, we take the row-wise max
    # to vectorize across all rows
    a = x.max(dim=-1).values
    # We also convert `a` into a column vector so that each row's max is
    # broadcast across every column of that row
    return a + (x - a[:, None]).exp().sum(-1).log()


# Compare with a tolerance: exact float equality is fragile
assert torch.isclose(logsumexp(prd), prd.logsumexp(axis=1)).all()


def log_softmax(x):
    a = x.max(dim=-1).values[:, None]
    # This gives us the log-sum-exp term; equivalently, x.logsumexp(dim=-1, keepdim=True)
    lse = a + (x - a).exp().sum(axis=-1, keepdim=True).log()
    # We subtract this from x to give the final log softmax
    return x - lse


assert torch.isclose(log_softmax(prd), lsm_prd).all()
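
To see why subtracting the max matters, here is a minimal plain-Python sketch comparing a naive log-sum-exp with the stabilized one on deliberately large logits. The function names and the input values are illustrative only. In plain Python the naive version raises an `OverflowError`; with float tensors the same computation would typically yield `inf` instead of raising.

```python
import math


def logsumexp_naive(z):
    # Exponentiates the raw values, which overflows for large inputs
    return math.log(sum(math.exp(zj) for zj in z))


def logsumexp_stable(z):
    # Shift by the max so the largest exponent is exp(0) = 1
    a = max(z)
    return a + math.log(sum(math.exp(zj - a) for zj in z))


big = [1000.0, 1000.0, 1000.0]  # illustrative, deliberately large logits

try:
    logsumexp_naive(big)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

stable = logsumexp_stable(big)  # 1000 + log(3)
```

For small inputs the two functions agree, so the stabilized version can be used everywhere without changing results.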

Now, for a one-hot target \(x\) and predicted probabilities \(p(x)\), the cross-entropy loss is given by \[ -\sum_{i=1}^{N} x_i \cdot \log ( p(x_i) ) \]

But since the \(x\)’s are one-hot encoded, this is simply \(-\log(p(x_{target}))\). We can index into this target by building (row_index, target_index) pairs like so:

tgt = torch.randint(0, 10, size=(N,))  # upper bound is exclusive, so this covers classes 0..9
tgt, prd.shape, prd[range(N), tgt]
(tensor([3, 0, 1, 0, 7]),
 torch.Size([5, 10]),
 tensor([1.0, 0.9, 0.4, 1.0, 0.6]))
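
The claim that the one-hot sum collapses to a single indexed term can be checked with a short plain-Python sketch. The probabilities and target index below are invented for illustration.

```python
import math

log_probs = [math.log(0.7), math.log(0.2), math.log(0.1)]  # illustrative log-probabilities
target = 1  # illustrative target class
one_hot = [1.0 if i == target else 0.0 for i in range(len(log_probs))]

# Full cross-entropy sum over all categories
full_sum = -sum(x * lp for x, lp in zip(one_hot, log_probs))

# Shortcut: just pick out the target's log-probability
picked = -log_probs[target]
```

Every term with a zero in the one-hot vector vanishes, so only the target's term survives, which is what the (row_index, target_index) indexing exploits.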

Wrapping this indexing in a function gives the mean negative log likelihood:

def nll(inp, tgt):
    """mean negative log likelihood loss"""
    (n,) = tgt.shape
    return -inp[range(n), tgt].mean()

This is equivalent to F.nll_loss.

nll(log_softmax(prd), tgt), F.nll_loss(F.log_softmax(prd, dim=-1), tgt)
(tensor(2.1), tensor(2.1))

Training the model

Here, we’ll take what we have implemented by hand and substitute the PyTorch equivalents.
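
As a quick sanity check on that substitution, this sketch confirms that `F.cross_entropy` applied to raw logits matches `F.nll_loss` applied to `F.log_softmax` of the same logits. The tensors are random placeholders, not the MNIST data used below.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(5, 10)  # illustrative raw model outputs
targets = torch.randint(0, 10, size=(5,))  # illustrative class labels

# Two-step version: log-softmax followed by negative log likelihood
two_step = F.nll_loss(F.log_softmax(logits, dim=-1), targets)

# One-step version: cross_entropy takes raw logits directly
one_step = F.cross_entropy(logits, targets)
```

Because `F.cross_entropy` expects unnormalized logits, the model below can return the output of its last linear layer without applying a softmax itself.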

💿 Set up the data

dm = MNISTDataModule()
dm.setup()
X_trn, y_trn = dm.as_matrix("trn")
X_trn = rearrange(X_trn, "n w h -> n (w h)")
bs = 128
n, m = X_trn.shape
nh = 50  # num. hidden dimensions
n_output_categories = y_trn.max().item() + 1
bs, m, n, nh, n_output_categories
(128, 784, 60000, 50, 10)

🗺️ Define the model

class Model(torch.nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [
            torch.nn.Linear(n_in, nh),
            torch.nn.ReLU(),
            torch.nn.Linear(nh, n_out),
        ]

    def __call__(self, x):
        for l in self.layers:
            x = l(x)
        return x


model = Model(m, nh, n_output_categories)

🧐 Do a single prediction

xb = X_trn[:bs, :]
yb = y_trn[:bs]
preds = model(xb)
preds, preds.shape
preds.argmax(axis=1), yb
(tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 6, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7, 2, 8, 6, 9, 4, 0, 9, 1,
         1, 2, 4, 3, 2, 7, 3, 8, 6, 9, 0, 5, 6, 0, 7, 6, 1, 8, 7, 9, 3, 9, 8, 5,
         9, 3, 3, 0, 7, 4, 9, 8, 0, 9, 4, 1, 4, 4, 6, 0, 4, 5, 6, 1, 0, 0, 1, 7,
         1, 6, 3, 0, 2, 1, 1, 7, 9, 0, 2, 6, 7, 8, 3, 9, 0, 4, 6, 7, 4, 6, 8, 0,
         7, 8, 3, 1, 5, 7, 1, 7, 1, 1, 6, 3, 0, 2, 9, 3, 1, 1, 0, 4, 9, 2, 0, 0,
         2, 0, 2, 7, 1, 8, 6, 4]))
F.cross_entropy(preds, yb)
tensor(2.3, grad_fn=<NllLossBackward0>)
accuracy = (preds.argmax(axis=1) == yb).sum() / bs
f"{accuracy:.2%}"
'15.62%'

🏃 Train in a loop

epochs = 3
lr = 0.5

for epoch in range(epochs):
    for i in range(0, n, bs):
        mask = slice(i, min(n, i + bs))
        xb = X_trn[mask]
        yb = y_trn[mask]
        preds = model(xb)
        loss = F.cross_entropy(preds, yb)
        loss.backward()
        if i == 0:
            (bs,) = yb.shape
            accuracy = (preds.argmax(axis=1) == yb).sum() / bs
            print(f"{epoch=}: loss={loss.item():.2f}, accuracy={accuracy.item():.2%}")
        with torch.no_grad():
            for l in model.layers:
                if hasattr(l, "weight"):  # i.e., trainable
                    l.weight -= l.weight.grad * lr
                    l.bias -= l.bias.grad * lr
                    l.weight.grad.zero_()
                    l.bias.grad.zero_()
epoch=0: loss=2.31, accuracy=15.62%
epoch=1: loss=0.14, accuracy=96.09%
epoch=2: loss=0.10, accuracy=96.88%

At this point, Jeremy refactors the training loop to:

  • incorporate a module/parameter registry to make it cleaner to update the weights
  • reimplement the models from the previous notebook as torch.nn.Modules
  • implement an optimizer class that stores the parameters and updates them based on the gradients computed by torch itself
  • replace the optimizer with the torch.optim equivalent
  • refactor the data loader with the appropriate torch primitives

I’m skipping all this because I’m pretty solid with the PyTorch fundamentals already.
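
For reference, here is a rough sketch of what the hand-written optimizer step in that refactor might look like. The class name `SimpleSGD`, the tiny model, and the random data are all hypothetical stand-ins, and this is not the course's actual code. It mirrors the manual weight update from the training loop above, but works on registered parameters.

```python
import torch
import torch.nn.functional as F


class SimpleSGD:  # hypothetical name, loosely mirroring torch.optim.SGD's interface
    def __init__(self, params, lr=0.5):
        self.params = list(params)  # store the parameters once
        self.lr = lr

    def step(self):
        # Plain gradient descent: p <- p - lr * grad
        with torch.no_grad():
            for p in self.params:
                if p.grad is not None:
                    p -= p.grad * self.lr

    def zero_grad(self):
        for p in self.params:
            p.grad = None


torch.manual_seed(0)
# nn.Sequential registers its layers, so .parameters() finds every weight and bias
model = torch.nn.Sequential(
    torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 3)
)
opt = SimpleSGD(model.parameters(), lr=0.1)
xb = torch.randn(16, 4)  # illustrative inputs
yb = torch.randint(0, 3, size=(16,))  # illustrative labels

losses = []
for _ in range(20):
    loss = F.cross_entropy(model(xb), yb)
    loss.backward()
    opt.step()
    opt.zero_grad()
    losses.append(loss.item())
```

Swapping `SimpleSGD(model.parameters(), lr=0.1)` for `torch.optim.SGD(model.parameters(), lr=0.1)` should leave the loop unchanged, since both expose `step()` and `zero_grad()`.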

A few nice tips

  1. In fastcore, how do I take the *args, **kwargs of a constructor and populate the object state?

import fastcore.all as fc


class Foo:
    def __init__(self, bar, baz="quz"):
        fc.store_attr()


Foo("qux").baz
'quz'

This is a bit like the __init__ that a dataclass generates automatically.

  2. How do you control the generation of indices to sample?

torch.utils.data.Sampler classes are used to specify the sequence of indices/keys used in data loading. They represent iterable objects over the indices to datasets. E.g., in the common case with stochastic gradient descent (SGD), a Sampler could randomly permute a list of indices and yield each one at a time, or yield a small number of them for mini-batch SGD.

https://pytorch.org/docs/stable/data.html
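
As a small illustration of the quoted description, the sketch below combines the built-in `SequentialSampler`, `RandomSampler`, and `BatchSampler` from `torch.utils.data` to produce mini-batches of indices. The dataset is just `range(10)` as a stand-in for a real dataset.

```python
from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

indices = range(10)  # stand-in for a dataset with 10 items

# In-order indices, grouped into mini-batches of 4 (the last batch is smaller)
seq_batches = list(
    BatchSampler(SequentialSampler(indices), batch_size=4, drop_last=False)
)
# seq_batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# Randomly permuted indices, grouped the same way
rand_batches = list(
    BatchSampler(RandomSampler(indices), batch_size=4, drop_last=False)
)
```

Each batch of indices can then be used to slice the data, much like the `slice(i, min(n, i + bs))` mask in the training loop above, except that the random variant visits the examples in a different order each epoch.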