```python
# Sweep log σ and plot the KLD loss as a function of μ
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
ax.set(xlabel=r"$\mu$", ylabel="KL divergence loss")
for log_sigma_ in torch.linspace(0, 3, 5):
    mu = torch.linspace(-3, 3, 100).unsqueeze(0)
    log_sigma = torch.full((1, 1), log_sigma_)
    loss = kld_loss(mu, log_sigma, dim=0)
    ax.plot(mu.squeeze(), loss, label=r"$\log\sigma=${}".format(log_sigma_.item()))
fig.legend();
```
VAE
Adapted from: https://www.youtube.com/watch?v=8AgZ9jcQ9v8&list=PLfYUBJiXbdtRUvTUYpLdfHHp9a58nWVXP&index=17
Perceptron
Perceptron (c_in, c_out, bias=True, act=<class 'torch.nn.modules.activation.SiLU'>)
A sequential container.

Modules will be added to it in the order they are passed in the constructor. Alternatively, an `OrderedDict` of modules can be passed in. The `forward()` method of `Sequential` accepts any input and forwards it to the first module it contains. It then “chains” outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

The value a `Sequential` provides over manually calling a sequence of modules is that it allows treating the whole container as a single module, such that performing a transformation on the `Sequential` applies to each of the modules it stores (which are each a registered submodule of the `Sequential`).

What’s the difference between a `Sequential` and a `torch.nn.ModuleList`? A `ModuleList` is exactly what it sounds like: a list for storing `Module`s! On the other hand, the layers in a `Sequential` are connected in a cascading way.
Example:

```python
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU()
)

# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(1, 20, 5)),
    ('relu1', nn.ReLU()),
    ('conv2', nn.Conv2d(20, 64, 5)),
    ('relu2', nn.ReLU())
]))
```
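Given the signature, `Perceptron` is presumably a linear layer plus an activation wrapped in a `Sequential`. A minimal sketch (the pairing with `nn.Linear` is an assumption; the actual source may differ):

```python
import torch.nn as nn

class Perceptron(nn.Sequential):
    "Sketch: a linear layer followed by an activation."
    def __init__(self, c_in, c_out, bias=True, act=nn.SiLU):
        super().__init__(nn.Linear(c_in, c_out, bias=bias), act())
```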
KaimingMixin
KaimingMixin ()
Helper to initialize the network using Kaiming initialization
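A hedged sketch of what the mixin could provide, assuming a `kaiming` constructor classmethod (consistent with the `VAE.kaiming(...)` call later on this page); the use of `kaiming_normal_` and the slope value are assumptions:

```python
import torch.nn as nn

class KaimingMixin:
    "Sketch: build the network, then Kaiming-initialize its linear layers."

    @classmethod
    def kaiming(cls, *args, **kwargs):
        net = cls(*args, **kwargs)
        for m in net.modules():
            if isinstance(m, nn.Linear):
                # leaky variant ("initialize leakily"); a=0.1 is an assumed slope
                nn.init.kaiming_normal_(m.weight, a=0.1, nonlinearity="leaky_relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        return net
```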
VAE
VAE (c_in, c_hidden, c_bottleneck, layers=1)
Variational autoencoder
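The forward pass follows the standard VAE recipe: encode to a mean and a log standard deviation, sample with the reparameterization trick, then decode. A minimal sketch built from the `Perceptron` and `KaimingMixin` pieces above (the exact layer layout is an assumption; only the signature and the `(x_pred, μ, log_σ)` return shape are taken from this page):

```python
import torch
import torch.nn as nn

class VAE(KaimingMixin, nn.Module):
    "Sketch: encoder -> (mu, log_sigma) -> sampled z -> decoder logits."
    def __init__(self, c_in, c_hidden, c_bottleneck, layers=1):
        super().__init__()
        self.encoder = nn.Sequential(
            Perceptron(c_in, c_hidden),
            *[Perceptron(c_hidden, c_hidden) for _ in range(layers - 1)],
        )
        self.mu = nn.Linear(c_hidden, c_bottleneck)
        self.log_sigma = nn.Linear(c_hidden, c_bottleneck)
        self.decoder = nn.Sequential(
            *[Perceptron(c_hidden if i else c_bottleneck, c_hidden) for i in range(layers)],
            nn.Linear(c_hidden, c_in),  # logits; sigmoid is applied downstream
        )

    def forward(self, x):
        h = self.encoder(x)
        mu, log_sigma = self.mu(h), self.log_sigma(h)
        z = mu + log_sigma.exp() * torch.randn_like(mu)  # reparameterization trick
        return self.decoder(z), mu, log_sigma
```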
Left unconstrained, the encoder can drive σ toward 0 so that the activations preserve the input exactly, collapsing the stochastic bottleneck into a plain autoencoder. We therefore need an extra loss term that pushes the hidden distribution toward a standard normal: the Kullback–Leibler divergence ("KLD") loss. It reaches its minimum when μ is 0 and σ is 1.
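Per latent dimension, the KL divergence from the approximate posterior $\mathcal{N}(\mu, \sigma^2)$ to the standard-normal prior has a closed form:

$$
\mathrm{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right)
= \tfrac{1}{2}\left(\mu^2 + \sigma^2 - 1\right) - \log \sigma
$$

which is zero exactly at $\mu = 0$, $\sigma = 1$, and blows up as $\sigma \to 0$.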
kld_loss
kld_loss (μ, log_σ, eps=0, dim=None)
This is added to a normal reconstruction loss.
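A hedged sketch of an implementation matching this signature, using the closed form above (treating `eps` as a variance floor and `dim` as the reduction axis are assumptions):

```python
import torch

def kld_loss(mu, log_sigma, eps=0, dim=None):
    "Sketch: KL to a standard normal, reduced over `dim` (or over everything)."
    sigma2 = (2 * log_sigma).exp() + eps  # sigma^2, optionally floored by eps
    kld = 0.5 * (mu**2 + sigma2 - 1) - log_sigma
    return kld.mean() if dim is None else kld.mean(dim=dim)
```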
vae_loss
vae_loss (inputs, x_pred, μ, log_σ)
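A sketch consistent with this signature and with the `kld` and `bce` columns in the training table below; the mean reduction and the use of `binary_cross_entropy_with_logits` (the model appears to output logits, given the later `xb_pred.sigmoid()` call) are assumptions:

```python
import torch.nn.functional as F

def vae_loss(inputs, x_pred, mu, log_sigma):
    "Sketch: reconstruction term (BCE on logits) plus the KLD regularizer."
    bce = F.binary_cross_entropy_with_logits(x_pred, inputs)
    return bce + kld_loss(mu, log_sigma)
```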
We want to be able to keep track of the KLD loss over time, so let’s track it in a metric.
MeanlikeMetric
MeanlikeMetric (**kwargs)
Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

1. Handles the transfer of metric states to the correct device
2. Handles the synchronization of metric states across processes

The three core methods of the base class are `add_state()`, `forward()`, and `reset()`, which should almost never be overwritten by child classes. Instead, the following methods should be overwritten: `update()` and `compute()`.

Args: kwargs: additional keyword arguments, see Metric kwargs for more info.

- compute_on_cpu: If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step: If metric state should synchronize on `forward()`. Default is `False`.
- process_group: The process group on which the synchronization is called. Default is the world.
- dist_sync_fn: Function that performs the allgather operation on the metric state. Default is a custom implementation that calls `torch.distributed.all_gather` internally.
- distributed_available_fn: Function that checks if the distributed backend is available. Defaults to a check of `torch.distributed.is_available()` and `torch.distributed.is_initialized()`.
- sync_on_compute: If metric state should synchronize when `compute` is called. Default is `True`.
- compute_with_cache: If results from `compute` should be cached. Default is `True`.
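The name suggests a metric that keeps a running mean. A minimal sketch of such a base on top of `torchmetrics.Metric` (the state names and the `update` signature are assumptions, not the repository's actual code):

```python
import torch
from torchmetrics import Metric

class MeanlikeMetric(Metric):
    "Sketch: accumulate a weighted sum and a count; compute their ratio."
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, value, n=1):
        self.total += value * n
        self.count += n

    def compute(self):
        return self.total / self.count
```

Subclasses like `KLDMetric` and `BCEMetric` then presumably only need to pick the right term out of the loss.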
KLDMetric
KLDMetric (**kwargs)
Shares the `torchmetrics.Metric` base-class docstring shown under `MeanlikeMetric` above.
BCEMetric
BCEMetric (**kwargs)
Shares the `torchmetrics.Metric` base-class docstring shown under `MeanlikeMetric` above.
MetricsCBWithKLDAndBCE
MetricsCBWithKLDAndBCE (*ms, **metrics)
Update and print metrics
VAETrainCB
VAETrainCB ()
Training-specific behaviors for the Learner
train
train (model, dls, lr=0.004, n_epochs=4, extra_cbs=[], loss_fn=<function vae_loss>)
Set up the training run
```python
dls = fashion_mnist(normalize=False, bs=256)
x, _ = dls.peek()
x.shape
```

```
torch.Size([256, 1, 28, 28])
```

```python
x.min(), x.max()
```

```
(tensor(0.), tensor(1.))
```
FashionMNISTForReconstruction
FashionMNISTForReconstruction ()
Modify the training behavior
```python
vae = train(
    VAE.kaiming((28**2), 400, 200, layers=2),
    dls,
    extra_cbs=[
        FashionMNISTForReconstruction(),
    ],
    n_epochs=20,
    lr=3e-2,
)
```
loss | kld | bce | epoch | train |
---|---|---|---|---|
0.984 | 0.427 | 0.557 | 0 | train |
0.723 | 0.234 | 0.489 | 0 | eval |
0.546 | 0.097 | 0.448 | 1 | train |
0.453 | 0.046 | 0.408 | 1 | eval |
0.419 | 0.034 | 0.385 | 2 | train |
0.401 | 0.035 | 0.365 | 2 | eval |
0.393 | 0.034 | 0.359 | 3 | train |
0.390 | 0.039 | 0.350 | 3 | eval |
0.381 | 0.033 | 0.348 | 4 | train |
0.377 | 0.032 | 0.345 | 4 | eval |
0.371 | 0.031 | 0.341 | 5 | train |
0.369 | 0.029 | 0.340 | 5 | eval |
0.362 | 0.029 | 0.333 | 6 | train |
0.359 | 0.028 | 0.330 | 6 | eval |
0.354 | 0.028 | 0.325 | 7 | train |
0.354 | 0.028 | 0.325 | 7 | eval |
0.350 | 0.029 | 0.321 | 8 | train |
0.352 | 0.029 | 0.323 | 8 | eval |
0.348 | 0.029 | 0.319 | 9 | train |
0.350 | 0.029 | 0.320 | 9 | eval |
0.347 | 0.029 | 0.318 | 10 | train |
0.348 | 0.029 | 0.319 | 10 | eval |
0.346 | 0.030 | 0.316 | 11 | train |
0.347 | 0.030 | 0.316 | 11 | eval |
0.345 | 0.030 | 0.315 | 12 | train |
0.345 | 0.030 | 0.315 | 12 | eval |
0.344 | 0.030 | 0.314 | 13 | train |
0.345 | 0.030 | 0.314 | 13 | eval |
0.343 | 0.030 | 0.313 | 14 | train |
0.343 | 0.030 | 0.313 | 14 | eval |
0.343 | 0.030 | 0.312 | 15 | train |
0.342 | 0.030 | 0.312 | 15 | eval |
0.342 | 0.030 | 0.312 | 16 | train |
0.342 | 0.030 | 0.311 | 16 | eval |
0.341 | 0.030 | 0.311 | 17 | train |
0.341 | 0.030 | 0.310 | 17 | eval |
0.341 | 0.031 | 0.311 | 18 | train |
0.341 | 0.030 | 0.310 | 18 | eval |
0.341 | 0.031 | 0.310 | 19 | train |
0.341 | 0.030 | 0.310 | 19 | eval |
It took quite a bit of work to ensure that these results matched Howard's:

- Use SiLU instead of ReLU
- Initialize with a leaky Kaiming scheme
- Normalize the output before visualizing
- Ensure there were the same number of encoder layers as decoder layers
- Use the original FashionMNIST, not the version upsampled to 32x32 for DDPM
```python
with torch.no_grad():
    xb = rearrange(x, "b c h w -> b (c h w)")
    xb_pred, _, _ = to_cpu(vae(xb.cuda()))
    xb_pred = rearrange(xb_pred.sigmoid(), "b (c h w) -> b c h w", c=1, h=28, w=28)
xb_pred = xb_pred.float()
```

```python
show_images(x[:8], imsize=0.8);
show_images(xb_pred[:8], imsize=0.8);
```