VAE

In this module, we train a variational autoencoder.

Adapted from: https://www.youtube.com/watch?v=8AgZ9jcQ9v8&list=PLfYUBJiXbdtRUvTUYpLdfHHp9a58nWVXP&index=17


Perceptron


def Perceptron(
    c_in, c_out, bias:bool=True, act:type=SiLU
):

A sequential container.

Modules will be added to it in the order they are passed in the constructor. Alternatively, an OrderedDict of modules can be passed in. The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

The value a Sequential provides over manually calling a sequence of modules is that it allows treating the whole container as a single module, such that performing a transformation on the Sequential applies to each of the modules it stores (which are each a registered submodule of the Sequential).

What’s the difference between a Sequential and a torch.nn.ModuleList? A ModuleList is exactly what it sounds like: a list for storing Modules. The layers in a Sequential, on the other hand, are connected in a cascading way.

Example:

from collections import OrderedDict

from torch import nn

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
    nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
)

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(
    OrderedDict(
        [
            ("conv1", nn.Conv2d(1, 20, 5)),
            ("relu1", nn.ReLU()),
            ("conv2", nn.Conv2d(20, 64, 5)),
            ("relu2", nn.ReLU()),
        ]
    )
)
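The signature Perceptron(c_in, c_out, bias, act) together with the inherited nn.Sequential docstring suggests a linear layer followed by an activation. The following is a minimal sketch under that assumption; the class name PerceptronSketch is hypothetical, and the real Perceptron may contain additional layers.

```python
import torch
from torch import nn


class PerceptronSketch(nn.Sequential):
    """Hypothetical sketch of Perceptron: a Linear layer followed by an activation."""

    def __init__(self, c_in: int, c_out: int, bias: bool = True, act: type = nn.SiLU):
        # Passing the layers to nn.Sequential registers them in call order
        super().__init__(nn.Linear(c_in, c_out, bias=bias), act())


# Map a batch of flattened 28x28 images to 400 features
layer = PerceptronSketch(28 * 28, 400)
out = layer(torch.randn(8, 28 * 28))
print(out.shape)  # torch.Size([8, 400])
```

Subclassing nn.Sequential means no forward() is needed: the container chains the Linear layer into the activation.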

KaimingMixin


def KaimingMixin(
    *args, **kwargs
):

Helper to initialize the network using Kaiming (He) initialization.
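The training cell below builds the model with VAE.kaiming(...), which suggests that the mixin exposes a kaiming classmethod that constructs the model and then re-initializes it. The later note to "initialize leakily" suggests the leaky-ReLU form of Kaiming initialization. This is a sketch under those assumptions: the names KaimingMixinSketch and TinyMLP, the leak=0.2 slope, and the zeroed biases are all hypothetical.

```python
import torch
from torch import nn


class KaimingMixinSketch:
    """Hypothetical mixin adding a `kaiming` constructor with Kaiming init."""

    @classmethod
    def kaiming(cls, *args, leak: float = 0.2, **kwargs):
        model = cls(*args, **kwargs)
        for m in model.modules():
            if isinstance(m, nn.Linear):
                # Kaiming normal init with the gain for a leaky ReLU of slope `leak` (assumed value)
                nn.init.kaiming_normal_(m.weight, a=leak, nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        return model


class TinyMLP(KaimingMixinSketch, nn.Sequential):
    def __init__(self, c_in: int, c_out: int):
        super().__init__(nn.Linear(c_in, c_out), nn.SiLU())


torch.manual_seed(0)
model = TinyMLP.kaiming(28 * 28, 400)
```

Using a classmethod keeps the normal constructor untouched: TinyMLP(784, 400) still uses PyTorch's default initialization, and TinyMLP.kaiming(784, 400) opts into the Kaiming scheme.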


VAE


def VAE(
    c_in, c_hidden, c_bottleneck, layers:int=1
):

Variational autoencoder
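The following is a minimal sketch of an MLP VAE that matches the signature VAE(c_in, c_hidden, c_bottleneck, layers). It returns (x_pred, μ, log σ), which is consistent with how vae_loss and the visualization code below consume the model's output. It assumes the following: each encoder and decoder uses `layers` Linear+SiLU blocks, log_σ is the log standard deviation, and the decoder outputs logits. The class name VAESketch is hypothetical.

```python
import torch
from torch import nn


class VAESketch(nn.Module):
    """Hypothetical MLP VAE; encoder and decoder have the same number of hidden layers."""

    def __init__(self, c_in: int, c_hidden: int, c_bottleneck: int, layers: int = 1):
        super().__init__()
        enc = [nn.Linear(c_in, c_hidden), nn.SiLU()]
        dec = [nn.Linear(c_bottleneck, c_hidden), nn.SiLU()]
        for _ in range(layers - 1):
            enc += [nn.Linear(c_hidden, c_hidden), nn.SiLU()]
            dec += [nn.Linear(c_hidden, c_hidden), nn.SiLU()]
        self.encoder = nn.Sequential(*enc)
        self.to_mu = nn.Linear(c_hidden, c_bottleneck)
        self.to_log_sigma = nn.Linear(c_hidden, c_bottleneck)
        # Final projection back to the input size; outputs are logits
        self.decoder = nn.Sequential(*dec, nn.Linear(c_hidden, c_in))

    def forward(self, x):
        h = self.encoder(x)
        mu, log_sigma = self.to_mu(h), self.to_log_sigma(h)
        # Reparameterization trick: z = mu + sigma * eps keeps sampling differentiable
        z = mu + log_sigma.exp() * torch.randn_like(mu)
        return self.decoder(z), mu, log_sigma


vae = VAESketch(28 * 28, 400, 200, layers=2)
x_pred, mu, log_sigma = vae(torch.rand(4, 28 * 28))
```

Predicting log σ rather than σ lets the linear head output any real number while σ = exp(log σ) stays positive.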

Without an extra penalty, the encoder can push σ toward 0, which makes the sampled latent an almost deterministic copy of the activations and turns the model into a plain autoencoder. To keep the hidden distribution close to a standard normal, we add a second loss term, the Kullback–Leibler divergence (“KLD”) loss. This term reaches its minimum when μ is 0 and σ is 1.
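For a Gaussian latent compared against a standard normal, the KLD has a closed form. Per latent dimension, with σ the standard deviation:

```latex
D_{\mathrm{KL}}\big(\mathcal{N}(\mu,\sigma^2)\,\|\,\mathcal{N}(0,1)\big)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1 - \log\sigma^2\right)
  = \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1\right) - \log\sigma

\frac{\partial D_{\mathrm{KL}}}{\partial \mu} = \mu = 0,
\qquad
\frac{\partial D_{\mathrm{KL}}}{\partial \sigma} = \sigma - \frac{1}{\sigma} = 0 \;\Rightarrow\; \sigma = 1,
\qquad
D_{\mathrm{KL}}\big|_{\mu=0,\,\sigma=1} = 0
```

Both partial derivatives vanish only at μ = 0 and σ = 1, where the divergence is 0. This is the minimum described above. The loss sums or averages this quantity over the latent dimensions.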


kld_loss


def kld_loss(
    μ, log_σ, eps:int=0, dim:NoneType=None
):
import matplotlib.pyplot as plt
import torch

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.set(xlabel=r"$\mu$", ylabel="KL divergence loss")
# One curve per value of log σ
for log_sigma_ in torch.linspace(0, 3, 5):
    mu = torch.linspace(-3, 3, 100).unsqueeze(0)
    log_sigma = torch.full((1, 1), log_sigma_)
    loss = kld_loss(mu, log_sigma, dim=0)
    ax.plot(mu.squeeze(), loss, label=r"$\log\sigma=${:.2f}".format(log_sigma_.item()))
fig.legend();
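The following is a minimal sketch of the closed-form KLD used in the plot. It assumes that log_σ is the log standard deviation and that the reduction is a mean over `dim` (or over all elements when dim is None). The real kld_loss may reduce differently, and this sketch does not model its eps argument. The name kld_loss_sketch is hypothetical.

```python
import torch


def kld_loss_sketch(mu, log_sigma, dim=None):
    # Closed-form KL(N(mu, sigma^2) || N(0, 1)) per element, with sigma = exp(log_sigma)
    kld = 0.5 * (mu.pow(2) + (2 * log_sigma).exp() - 1) - log_sigma
    # Assumed reduction: mean over `dim`, or over all elements when dim is None
    return kld.mean() if dim is None else kld.mean(dim=dim)


print(kld_loss_sketch(torch.zeros(4), torch.zeros(4)))  # tensor(0.)

# Same broadcasting as the plot: mu is (1, 100), log_sigma is (1, 1), reduced over dim 0
curve = kld_loss_sketch(torch.linspace(-3, 3, 100).unsqueeze(0), torch.full((1, 1), 0.75), dim=0)
```

The first call evaluates the loss at μ = 0, σ = 1, where it is exactly zero. The curve shows that every fixed σ still has its minimum at μ = 0.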

To obtain the total VAE loss, this KLD term is added to an ordinary reconstruction loss (binary cross-entropy, “BCE”, here).


vae_loss


def vae_loss(
    inputs, x_pred, μ, log_σ
):
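The following is a minimal sketch of the combined loss, BCE reconstruction plus KLD. The training log below is consistent with this combination, since loss ≈ kld + bce in every row. The sketch assumes that x_pred holds logits (the visualization code applies .sigmoid() to it), that log_σ is the log standard deviation, and that both terms are averaged over all elements. The name vae_loss_sketch is hypothetical.

```python
import torch
import torch.nn.functional as F


def vae_loss_sketch(inputs, x_pred, mu, log_sigma):
    # Reconstruction term: inputs are pixel intensities in [0, 1], x_pred are logits
    bce = F.binary_cross_entropy_with_logits(x_pred, inputs)
    # Regularizer: closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged (assumed reduction)
    kld = (0.5 * (mu.pow(2) + (2 * log_sigma).exp() - 1) - log_sigma).mean()
    return bce + kld


inputs = torch.rand(4, 10)
loss = vae_loss_sketch(inputs, torch.zeros(4, 10), torch.zeros(4, 3), torch.zeros(4, 3))
```

Using the logits version of BCE is numerically more stable than applying sigmoid and then binary_cross_entropy.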

To see how the KLD term evolves during training, we track it in its own metric.


MeanlikeMetric


def MeanlikeMetric(
    **kwargs
):

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

1. Handles the transfer of metric states to the correct device.
2. Handles the synchronization of metric states across processes.
3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward(), and reset(), which should almost never be overridden by child classes. Instead, child classes should override update() and compute().

Args: kwargs: additional keyword arguments; see the Metric kwargs section of the TorchMetrics documentation for more info.

    - **compute_on_cpu**:
        If metric state should be stored on CPU during computations. Only works for list states.
    - **dist_sync_on_step**:
        If metric state should synchronize on ``forward()``. Default is ``False``.
    - **process_group**:
        The process group on which the synchronization is called. Default is the world.
    - **dist_sync_fn**:
        Function that performs the all-gather operation on the metric state. Default is a custom
        implementation that calls ``torch.distributed.all_gather`` internally.
    - **distributed_available_fn**:
        Function that checks if the distributed backend is available. Defaults to a
        check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
    - **sync_on_compute**:
        If metric state should synchronize when ``compute`` is called. Default is ``True``.
    - **compute_with_cache**:
        If results from ``compute`` should be cached. Default is ``True``.

KLDMetric


def KLDMetric(
    **kwargs
):


BCEMetric


def BCEMetric(
    **kwargs
):


MetricsCBWithKLDAndBCE


def MetricsCBWithKLDAndBCE(
    *ms, **metrics
):

Update and print metrics


VAETrainCB


def VAETrainCB(
    *args, **kwargs
):

Training specific behaviors for the Learner
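VAETrainCB plugs into the author's Learner, whose API is not shown here. The following plain-PyTorch sketch shows what a VAE training step has to do differently from a standard supervised step. First, it flattens the images and uses them as their own target. Second, it passes the inputs and all three model outputs to the loss. The names vae_train_step, TinyVAE, and bce_plus_kld are hypothetical. Adam with lr=4e-3 (the train default) is an assumption, since the actual optimizer is not shown.

```python
import torch
import torch.nn.functional as F
from torch import nn


def vae_train_step(model, xb, opt, loss_fn):
    """Hypothetical single optimization step for a VAE (plain PyTorch)."""
    # Flatten images (b, c, h, w) -> (b, c*h*w); for reconstruction the input is also the target
    xb = xb.flatten(start_dim=1)
    # The VAE returns three tensors instead of a single prediction...
    x_pred, mu, log_sigma = model(xb)
    # ...so the loss needs the inputs and all three outputs
    loss = loss_fn(xb, x_pred, mu, log_sigma)
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()


# Tiny demo with a hypothetical one-layer VAE and a BCE + KLD loss
class TinyVAE(nn.Module):
    def __init__(self, c_in: int = 28 * 28, c_z: int = 8):
        super().__init__()
        self.enc = nn.Linear(c_in, 2 * c_z)  # outputs [mu, log_sigma]
        self.dec = nn.Linear(c_z, c_in)  # outputs logits

    def forward(self, x):
        mu, log_sigma = self.enc(x).chunk(2, dim=-1)
        z = mu + log_sigma.exp() * torch.randn_like(mu)
        return self.dec(z), mu, log_sigma


def bce_plus_kld(inputs, x_pred, mu, log_sigma):
    kld = (0.5 * (mu.pow(2) + (2 * log_sigma).exp() - 1) - log_sigma).mean()
    return F.binary_cross_entropy_with_logits(x_pred, inputs) + kld


torch.manual_seed(0)
model = TinyVAE()
opt = torch.optim.Adam(model.parameters(), lr=4e-3)
# One fixed binary "image" repeated 16 times, shaped like a FashionMNIST batch
xb = (torch.rand(1, 1, 28, 28) > 0.5).float().repeat(16, 1, 1, 1)
losses = [vae_train_step(model, xb, opt, bce_plus_kld) for _ in range(50)]
```

Because the loss signature is (inputs, x_pred, μ, log σ) rather than (pred, target), a generic training loop needs a hook like this one. That is presumably why the module defines a dedicated training callback.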


train


def train(
    model, dls, lr:float=0.004, n_epochs:int=4, extra_cbs:list=[], loss_fn:function=vae_loss
):

Set up the training run

dls = fashion_mnist(normalize=False, bs=256)
x, _ = dls.peek()
x.shape
torch.Size([256, 1, 28, 28])
x.min(), x.max()
(tensor(0.), tensor(1.))

FashionMNISTForReconstruction


def FashionMNISTForReconstruction(
    *args, **kwargs
):

Modify the training behavior

vae = train(
    VAE.kaiming((28**2), 400, 200, layers=2),
    dls,
    extra_cbs=[
        FashionMNISTForReconstruction(),
    ],
    n_epochs=20,
    lr=3e-2,
)
loss kld bce epoch train
0.984 0.427 0.557 0 train
0.723 0.234 0.489 0 eval
0.546 0.097 0.448 1 train
0.453 0.046 0.408 1 eval
0.419 0.034 0.385 2 train
0.401 0.035 0.365 2 eval
0.393 0.034 0.359 3 train
0.390 0.039 0.350 3 eval
0.381 0.033 0.348 4 train
0.377 0.032 0.345 4 eval
0.371 0.031 0.341 5 train
0.369 0.029 0.340 5 eval
0.362 0.029 0.333 6 train
0.359 0.028 0.330 6 eval
0.354 0.028 0.325 7 train
0.354 0.028 0.325 7 eval
0.350 0.029 0.321 8 train
0.352 0.029 0.323 8 eval
0.348 0.029 0.319 9 train
0.350 0.029 0.320 9 eval
0.347 0.029 0.318 10 train
0.348 0.029 0.319 10 eval
0.346 0.030 0.316 11 train
0.347 0.030 0.316 11 eval
0.345 0.030 0.315 12 train
0.345 0.030 0.315 12 eval
0.344 0.030 0.314 13 train
0.345 0.030 0.314 13 eval
0.343 0.030 0.313 14 train
0.343 0.030 0.313 14 eval
0.343 0.030 0.312 15 train
0.342 0.030 0.312 15 eval
0.342 0.030 0.312 16 train
0.342 0.030 0.311 16 eval
0.341 0.030 0.311 17 train
0.341 0.030 0.310 17 eval
0.341 0.031 0.311 18 train
0.341 0.030 0.310 18 eval
0.341 0.031 0.310 19 train
0.341 0.030 0.310 19 eval

It took quite a bit of work to make these results match Howard’s:

  • Use SiLU instead of ReLU
  • Initialize leakily
  • Normalize the output before visualizing
  • Use the same number of encoder layers as decoder layers
  • Use the original 28x28 FashionMNIST, not the version upsampled to 32x32 for DDPM.

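import torch
from einops import rearrange

with torch.no_grad():
    # Flatten each image to a 784-dim vector, as during training
    xb = rearrange(x, "b c h w -> b (c h w)")
    xb_pred, _, _ = to_cpu(vae(xb.cuda()))
# Apply sigmoid to map the raw outputs to [0, 1], then restore the image shape
xb_pred = rearrange(xb_pred.sigmoid(), "b (c h w) -> b c h w", c=1, h=28, w=28)
xb_pred = xb_pred.float()
show_images(x[:8], imsize=0.8);

show_images(xb_pred[:8], imsize=0.8);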