VAE

In this module, we train a variational autoencoder (VAE) on FashionMNIST.

Adapted from: https://www.youtube.com/watch?v=8AgZ9jcQ9v8&list=PLfYUBJiXbdtRUvTUYpLdfHHp9a58nWVXP&index=17


source

Perceptron

 Perceptron (c_in, c_out, bias=True, act=nn.SiLU)

*A sequential container.

Modules will be added to it in the order they are passed in the constructor. Alternatively, an OrderedDict of modules can be passed in. The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

The value a Sequential provides over manually calling a sequence of modules is that it allows treating the whole container as a single module, such that performing a transformation on the Sequential applies to each of the modules it stores (which are each a registered submodule of the Sequential).

What’s the difference between a Sequential and a torch.nn.ModuleList? A ModuleList is exactly what it sounds like: a list for storing Modules! On the other hand, the layers in a Sequential are connected in a cascading way.

Example:

# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))*
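
Given the signature and the inherited Sequential docs above, a plausible sketch of Perceptron is a linear layer followed by its activation (an inference from the docs, not the confirmed source):

import torch.nn as nn

class Perceptron(nn.Sequential):
    def __init__(self, c_in, c_out, bias=True, act=nn.SiLU):
        # One fully connected "layer": affine transform, then a nonlinearity
        super().__init__(nn.Linear(c_in, c_out, bias=bias), act())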

source

KaimingMixin

 KaimingMixin ()

Helper to initialize the network using Kaiming
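
A minimal sketch of what this mixin might provide; the kaiming classmethod name comes from the VAE.kaiming(...) call later on this page, and the leaky slope is an assumption based on the “initialize leakily” note below:

class KaimingMixin:
    @classmethod
    def kaiming(cls, *args, **kwargs):
        model = cls(*args, **kwargs)
        for m in model.modules():
            if isinstance(m, nn.Linear):
                # Kaiming init with a small negative slope ("initialize leakily")
                nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        return model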


source

VAE

 VAE (c_in, c_hidden, c_bottleneck, layers=1)

Variational autoencoder
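
The call vae(xb) near the end of this page unpacks three return values, so the forward pass presumably looks something like the following sketch, using the standard reparameterization trick (the encoder/decoder/head attribute names here are assumptions):

def forward(self, x):
    h = self.encoder(x)                                  # c_in -> c_hidden
    μ, log_σ = self.mu_head(h), self.log_sigma_head(h)   # c_bottleneck heads (assumed names)
    # Reparameterization trick: sampling stays differentiable w.r.t. μ and log σ
    z = μ + torch.randn_like(μ) * log_σ.exp()
    return self.decoder(z), μ, log_σ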

Without a constraint, the network can drive σ to 0 and preserve the data exactly in the activations, so we need an extra loss term that pushes the latent distribution toward a standard normal. This is the Kullback–Leibler divergence (“KLD”) loss, which reaches its minimum when μ is 0 and σ is 1.
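
For a diagonal Gaussian posterior N(μ, σ²) measured against a standard normal prior, the KLD has the closed form 0.5 * (μ² + σ² - 1 - log σ²) per latent dimension, which is zero exactly at μ = 0, σ = 1. A minimal sketch of kld_loss under that formula (the handling of eps and dim is an assumption):

def kld_loss(μ, log_σ, eps=0, dim=None):
    # KL(N(μ, σ²) || N(0, 1)) = 0.5 * (μ² + σ² - 1 - log σ²) per dimension
    var = (2 * log_σ).exp()
    # eps is presumably a numerical-stability fudge; its placement here is a guess
    kld = 0.5 * (μ.pow(2) + var - 1 - 2 * log_σ + eps)
    return kld.mean() if dim is None else kld.mean(dim=dim)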


source

kld_loss

 kld_loss (μ, log_σ, eps=0, dim=None)

The plot below sweeps μ for several values of log σ; each curve is minimized at μ = 0, and the curves grow with log σ:

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.set(xlabel=r"$\mu$", ylabel="KL divergence loss")
for log_sigma_ in torch.linspace(0, 3, 5):
    mu = torch.linspace(-3, 3, 100).unsqueeze(0)
    log_sigma = torch.full((1, 1), log_sigma_)
    loss = kld_loss(mu, log_sigma, dim=0)
    ax.plot(mu.squeeze(), loss, label=r"$\log\sigma=${}".format(log_sigma_.item()))
fig.legend();

This KLD term is added to a standard reconstruction loss; here we use binary cross-entropy on the reconstructed pixels.


source

vae_loss

 vae_loss (inputs, x_pred, μ, log_σ)
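
A sketch of how the combined objective might look, assuming binary cross-entropy with logits for the reconstruction term (consistent with the sigmoid applied to predictions later on this page; the equal weighting of the two terms is an assumption):

import torch.nn.functional as F

def vae_loss(inputs, x_pred, μ, log_σ):
    bce = F.binary_cross_entropy_with_logits(x_pred, inputs)  # reconstruction term
    kld = kld_loss(μ, log_σ)                                  # prior-matching term
    return bce + kld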

We want to be able to keep track of the KLD loss over time, so let’s track it in a metric.
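
In torchmetrics, a running mean is typically kept as two synchronized states; a plausible sketch of the machinery MeanlikeMetric builds on (the state and argument names are assumptions):

import torch
from torchmetrics import Metric

class MeanlikeMetric(Metric):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # States registered this way are moved to the right device and synced across processes
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")
    def update(self, value, n=1):
        self.total += value * n
        self.count += n
    def compute(self):
        return self.total / self.count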


source

MeanlikeMetric

 MeanlikeMetric (**kwargs)

*Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

1. Handles the transfer of metric states to the correct device
2. Handles the synchronization of metric states across processes

The three core methods of the base class are add_state(), forward(), and reset(), which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: update() and compute().

Args: kwargs: additional keyword arguments, see Metric kwargs for more info.

    - compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
    - dist_sync_on_step: If metric state should synchronize on ``forward()``. Default is ``False``
    - process_group: The process group on which the synchronization is called. Default is the world.
    - dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom
      implementation that calls ``torch.distributed.all_gather`` internally.
    - distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a
      check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
    - sync_on_compute: If metric state should synchronize when ``compute`` is called. Default is ``True``
    - compute_with_cache: If results from ``compute`` should be cached. Default is ``True``*

source

KLDMetric

 KLDMetric (**kwargs)

Running mean of the KLD loss term; inherits the torchmetrics interface described under MeanlikeMetric above.

source

BCEMetric

 BCEMetric (**kwargs)

Running mean of the binary cross-entropy (BCE) reconstruction loss term; inherits the torchmetrics interface described under MeanlikeMetric above.

source

MetricsCBWithKLDAndBCE

 MetricsCBWithKLDAndBCE (*ms, **metrics)

Update and print metrics


source

VAETrainCB

 VAETrainCB ()

Training-specific behaviors for the Learner
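
A plausible sketch, assuming a miniai-style TrainCB whose predict and get_loss hooks can be overridden (those hook names follow that convention and are not confirmed here):

class VAETrainCB(TrainCB):
    def predict(self, learn):
        # The VAE returns (x_pred, μ, log_σ) rather than a single tensor
        learn.preds = learn.model(learn.batch[0])
    def get_loss(self, learn):
        x_pred, μ, log_σ = learn.preds
        # vae_loss needs μ and log_σ alongside the usual inputs and predictions
        learn.loss = learn.loss_func(learn.batch[0], x_pred, μ, log_σ)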


source

train

 train (model, dls, lr=0.004, n_epochs=4, extra_cbs=[], loss_fn=<function
        vae_loss>)

Set up the training run
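
A sketch of how train might wire things together, assuming a miniai-style Learner (the exact constructor shown here is an assumption):

def train(model, dls, lr=0.004, n_epochs=4, extra_cbs=[], loss_fn=vae_loss):
    cbs = [VAETrainCB(), MetricsCBWithKLDAndBCE(), *extra_cbs]
    learn = Learner(model, dls, loss_func=loss_fn, lr=lr, cbs=cbs)
    learn.fit(n_epochs)
    return learn.model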

dls = fashion_mnist(normalize=False, bs=256)
x, _ = dls.peek()
x.shape

torch.Size([256, 1, 28, 28])

x.min(), x.max()

(tensor(0.), tensor(1.))

source

FashionMNISTForReconstruction

 FashionMNISTForReconstruction ()

Modify the training behavior
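
A guess at what this callback changes, based on the flattening in the reconstruction demo below: present each flattened image as both input and target (the before_batch hook name follows miniai conventions and is an assumption):

class FashionMNISTForReconstruction(Callback):
    def before_batch(self, learn):
        x, _ = learn.batch
        x = rearrange(x, "b c h w -> b (c h w)")
        # Autoencoding: the input is its own reconstruction target
        learn.batch = (x, x)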

vae = train(
    VAE.kaiming((28**2), 400, 200, layers=2),
    dls,
    extra_cbs=[
        FashionMNISTForReconstruction(),
    ],
    n_epochs=20,
    lr=3e-2,
)
loss kld bce epoch phase
0.984 0.427 0.557 0 train
0.723 0.234 0.489 0 eval
0.546 0.097 0.448 1 train
0.453 0.046 0.408 1 eval
0.419 0.034 0.385 2 train
0.401 0.035 0.365 2 eval
0.393 0.034 0.359 3 train
0.390 0.039 0.350 3 eval
0.381 0.033 0.348 4 train
0.377 0.032 0.345 4 eval
0.371 0.031 0.341 5 train
0.369 0.029 0.340 5 eval
0.362 0.029 0.333 6 train
0.359 0.028 0.330 6 eval
0.354 0.028 0.325 7 train
0.354 0.028 0.325 7 eval
0.350 0.029 0.321 8 train
0.352 0.029 0.323 8 eval
0.348 0.029 0.319 9 train
0.350 0.029 0.320 9 eval
0.347 0.029 0.318 10 train
0.348 0.029 0.319 10 eval
0.346 0.030 0.316 11 train
0.347 0.030 0.316 11 eval
0.345 0.030 0.315 12 train
0.345 0.030 0.315 12 eval
0.344 0.030 0.314 13 train
0.345 0.030 0.314 13 eval
0.343 0.030 0.313 14 train
0.343 0.030 0.313 14 eval
0.343 0.030 0.312 15 train
0.342 0.030 0.312 15 eval
0.342 0.030 0.312 16 train
0.342 0.030 0.311 16 eval
0.341 0.030 0.311 17 train
0.341 0.030 0.310 17 eval
0.341 0.031 0.311 18 train
0.341 0.030 0.310 18 eval
0.341 0.031 0.310 19 train
0.341 0.030 0.310 19 eval

It took quite a bit of work to ensure that these results matched Jeremy Howard’s:

  • Use SiLU instead of ReLU
  • Initialize leakily
  • Normalize the output before visualizing
  • Ensure there were the same number of encoder layers as decoder layers
  • Use the original FashionMNIST, not the one upsampled to 32x32 for DDPM.
with torch.no_grad():
    # Flatten images to vectors to match the VAE's input shape
    xb = rearrange(x, "b c h w -> b (c h w)")
    xb_pred, _, _ = to_cpu(vae(xb.cuda()))
# The decoder outputs logits, so apply a sigmoid before reshaping back to images
xb_pred = rearrange(xb_pred.sigmoid(), "b (c h w) -> b c h w", c=1, h=28, w=28)
xb_pred = xb_pred.float()
show_images(x[:8], imsize=0.8);  # the original images

show_images(xb_pred[:8], imsize=0.8);  # the reconstructions