"ggplot") plt.style.use(
DDPM and Mixed Precision
Adapted from:
Prelude
The Hugging Face U-Net requires images whose height and width are a power of 2.
I’m also removing global normalization, as that seemed to degrade performance compared to scaling between -0.5 and 0.5.
get_dls
get_dls (bs=128)
```python
dls = get_dls()
xb, _ = dls.peek()
xb.shape
```

```
torch.Size([128, 1, 32, 32])
```

```python
show_image(xb[0, ...])
```

```python
xb.min(), xb.max()
```

```
(tensor(-0.5000), tensor(0.5000))
```

```python
plt.hist(xb.view(-1))
```
```
(array([70274.,  4709.,  4792.,  5277.,  6619.,  6254.,  7328.,  9274.,
        11526.,  5019.]),
 array([-0.5       , -0.40000001, -0.30000001, -0.2       , -0.1       ,
         0.        ,  0.1       ,  0.2       ,  0.30000001,  0.40000001,
         0.5       ]),
 <BarContainer object of 10 artists>)
```
The diffusion model paper is available here and proposes an important simplification that improved the performance of this kind of modeling.
Jeremy mentions that, even shortly after the Stable Diffusion release, better models were already in progress that achieve comparable performance without using diffusion. “Iterative refinement” may be a better term, since it also encompasses other model families, such as GANs.
See this clarifying post for details.
Review, lesson 9b
“Generative modeling” produces a complicated probability distribution, such that sampling from it yields a complicated example. GANs and VAEs can produce such a distribution, and iterative refinement models can do so as well.
The forward process is defined recursively like so: \[ q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1-\beta_t}\, x_{t-1}, \beta_t I) \]
During the forward process of diffusion, the contribution of the original image to the mean shrinks while the noise variance grows. In other words, the original image is gradually lost and replaced by pure noise.
It’s easy to compute the forward process for a given \(x_0\) and \(t\). There is, in fact, a closed-form solution for a given \(\{\beta_0,...,\beta_T\}\):
\[ \begin{align*} \epsilon &\sim \mathcal{N}(0, I) \\ \alpha_t &= 1-\beta_t \\ \bar{\alpha_t} &= \prod_{i=0}^t \alpha_i \\ x_t &= \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1-\bar{\alpha_t}} \epsilon \end{align*} \]
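As a quick illustration of this closed form (a minimal sketch using an assumed linear β schedule that matches the `DDPM` defaults further below, not the notebook's actual helper):

```python
import torch

# Assumed linear schedule (βmin=0.0001, βmax=0.02, matching the DDPM defaults below)
n_steps = 1000
beta = torch.linspace(0.0001, 0.02, n_steps)
alpha = 1.0 - beta
alpha_bar = alpha.cumprod(dim=0)

def q_sample(x0, t):
    "Jump straight from x_0 to x_t using the closed-form expression"
    eps = torch.randn_like(x0)
    x_t = alpha_bar[t].sqrt() * x0 + (1 - alpha_bar[t]).sqrt() * eps
    return x_t, eps

# e.g. noise a batch of images (scaled to [-0.5, 0.5]) to timestep 500
x0 = torch.rand(16, 1, 32, 32) - 0.5
x_t, eps = q_sample(x0, t=500)
```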
```python
s = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)
```

```python
fig, (a0, a1, a2) = plt.subplots(1, 3, figsize=(8, 3))
a0.plot(s.betas)
a0.set(xlabel="Time", ylabel=r"$\beta$")
a1.plot(s.sigmas)
a1.set(xlabel="Time", ylabel=r"$\sigma$")
a2.plot(s.alphas_cumprod)
a2.set(xlabel="Time", ylabel=r"$\bar{\alpha}$")
fig.tight_layout()
```
To review:
- \(\beta_t\) is the variance of the noise added to the image at step \(t\) (the previous image is scaled by \(\sqrt{1-\beta_t}\))
- \(\sigma_t\) is the standard deviation of the noise at time \(t\)
- \(\bar{\alpha}_t\) is the cumulative product of \(\alpha_t = 1-\beta_t\), i.e. how much of the original signal remains at a particular time; \(1-\bar{\alpha}_t\) is the cumulative noise variance
To train this model, we:
- Randomly select a timestep \(t\) and add the appropriate amount of noise to the image
- Predict the added noise and backpropagate with MSE loss
Specifically, we need the loss from Algorithm 1 of the paper:
\[ \begin{align*} t &\sim Uniform(\{1,...,T\}) \\ \epsilon &\sim \mathcal{N}(0, I) \\ loss &= MSE \left( \epsilon, \epsilon_\theta ( \sqrt{\bar{\alpha_t}} x_0 + \sqrt{1-\bar{\alpha_t}} \epsilon, t ) \right) \\ \end{align*} \]
Here, \(\epsilon_\theta\) is the neural network. It predicts the noise (from which the reverse-process mean is derived), while the variance is kept fixed.
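In code, a single training step of Algorithm 1 looks roughly like this (a sketch only; `model` is assumed to return the noise prediction directly, whereas the diffusers U-Net used below returns it as `.sample`, and the course's `DDPM` callback implements this inside the Learner):

```python
import torch
import torch.nn.functional as F

def ddpm_training_loss(model, x0, alpha_bar, n_steps=1000):
    "One step of Algorithm 1: predict the noise added at a random timestep"
    bs = x0.shape[0]
    t = torch.randint(0, n_steps, (bs,), device=x0.device)  # t ~ Uniform({1,...,T})
    eps = torch.randn_like(x0)                               # ε ~ N(0, I)
    ab = alpha_bar[t].reshape(-1, 1, 1, 1)                   # broadcast over image dims
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps             # noised input
    return F.mse_loss(model(x_t, t), eps)                    # MSE between true and predicted noise
```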
DDPM
DDPM (n_steps=1000, βmin=0.0001, βmax=0.02)
Modify the training behavior
By iteratively predicting the noise and removing it, we ultimately end up with a high-probability data point.
This is defined in Algorithm 2 of the DDPM paper:
\[ \begin{align*} z &\sim \mathcal{N}(0, I) \text{ if } t > 1 \text{ else } z=0 \\ \hat{x}_{t-1} &= \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \right) + \sigma_t z \end{align*} \]
Let’s break this down into manageable chunks
Equation | Description |
---|---|
\[z \sim \mathcal{N}(0, I) \text{ if } t > 1 \text{ else } z=0\] | Sample a noise vector from a standard normal distribution if \(t > 1\); otherwise set \(z = 0\) |
\[\epsilon_\theta(x_t, t)\] | The neural network's noise prediction |
\[\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\] | The scaling factor applied to the predicted noise so that subtracting it removes the right amount of noise for one step |
\[x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}} \epsilon_\theta(x_t, t) \] | An estimate of the denoised image, \(\hat{x}_0\) (up to scaling) |
\[ \frac{1}{\sqrt{\alpha_t}} \] | The overall scaling factor required so that adding \(\sigma_t z\) yields a mean and variance appropriate to the schedule |
\[\sigma_t z\] | The noise added to the prediction to obtain \(x_{t-1}\) |
See Tanishq’s post for details on the derivation of the coefficients.
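Putting the pieces together, a single reverse step might look like the following (a sketch under the same notation; the notebook's `sample` implements essentially this loop, and the diffusers U-Net's prediction is accessed via `.sample`):

```python
import torch

def reverse_step(model, x_t, t, alpha, alpha_bar, sigma):
    "One step of Algorithm 2: x_t -> x_{t-1}"
    bs = x_t.shape[0]
    z = torch.randn_like(x_t) if t > 1 else torch.zeros_like(x_t)       # z ~ N(0, I), or 0 at the last step
    t_batch = torch.full((bs,), t, device=x_t.device, dtype=torch.long)
    eps_pred = model(x_t, t_batch).sample                               # ε_θ(x_t, t)
    coef = (1 - alpha[t]) / (1 - alpha_bar[t]).sqrt()                   # predicted-noise scaling factor
    mean = (x_t - coef * eps_pred) / alpha[t].sqrt()                    # 1/√α_t · (x_t − coef · ε_θ)
    return mean + sigma[t] * z                                          # add σ_t z
```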
sample
sample (model, sz=(16, 1, 32, 32), device='cpu', return_all=False)
Architecture
Our model is a U-Net, with some modern tricks like attention.
```python
UNet2DModel().down_blocks[0]
```

```
DownBlock2D(
  (resnets): ModuleList(
    (0-1): 2 x ResnetBlock2D(
      (norm1): GroupNorm(32, 224, eps=1e-05, affine=True)
      (conv1): LoRACompatibleConv(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (time_emb_proj): LoRACompatibleLinear(in_features=896, out_features=224, bias=True)
      (norm2): GroupNorm(32, 224, eps=1e-05, affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (conv2): LoRACompatibleConv(224, 224, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (nonlinearity): SiLU()
    )
  )
  (downsamplers): ModuleList(
    (0): Downsample2D(
      (conv): LoRACompatibleConv(224, 224, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
  )
)
```
train
train (model, lr=0.004, n_epochs=2, bs=128, opt_func=<class 'torch.optim.adam.Adam'>, extra_cbs=[], ddpm=<__main__.DDPM object at 0x7f20958bc9d0>)
fashion_unet
fashion_unet ()
Let’s give this a test run.
```python
imgs = ddpm.sample(unet)
show_images(imgs)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [00:14<00:00, 68.60time step/s]
Initialization
`diffusers` does not perform any special weight initialization. Let's fix this by:
- Making the non-residual weights into a zero-convolution, so that the initial training dynamics resemble those of a smaller, more stable network
- Using orthogonal initialization for the downsampler blocks
Jeremy mentions here that he spoke to Katherine Crowson, who wrote k-diffusion and recommended using orthogonal initialization. To understand the motivation behind orthogonal weights, Jeremy recommends taking Rachel Thomas's computational linear algebra course.
I’ll take a shot at summarizing this. Suppose \(W\) is a matrix whose rows are approximately orthonormal: each row is roughly orthogonal to every other row and has unit magnitude. Under such a linear transformation, the scale of the input vector is preserved. (An identity matrix has the same property.)
Moreover, orthonormal feature extractors may learn different representations of the input features. The transformation is essentially a change of basis, so each neuron sees a different mixture of the features from the underlying layer, but the overall magnitude does not change.
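A quick check of the norm-preservation claim (an illustrative snippet, not from the notebook):

```python
import torch
from torch import nn

# An orthogonally-initialized square matrix approximately preserves input norms
w = torch.empty(64, 64)
nn.init.orthogonal_(w)
x = torch.randn(64)
print(x.norm().item(), (w @ x).norm().item())  # the two norms should match closely
```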
```python
init.orthogonal_?
```

```
Signature: init.orthogonal_(tensor, gain=1)
Docstring:
Fills the input `Tensor` with a (semi) orthogonal matrix, as described in
`Exact solutions to the nonlinear dynamics of learning in deep linear neural
networks` - Saxe, A. et al. (2013). The input tensor must have at least 2
dimensions, and for tensors with more than 2 dimensions the trailing
dimensions are flattened.

Args:
    tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
    gain: optional scaling factor

Examples:
    >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
    >>> w = torch.empty(3, 5)
    >>> nn.init.orthogonal_(w)

File:      ~/micromamba/envs/slowai/lib/python3.11/site-packages/torch/nn/init.py
Type:      function
```
init_
init_ ()
```python
unet = fashion_unet()
unet.init_()
train(unet, n_epochs=1);
```
loss | epoch | train |
---|---|---|
0.839 | 0 | train |
0.996 | 0 | eval |
This (sometimes) explodes at the highest learning rate. Recall that in RMSProp (and Adam), the gradients are divided by the square root of an exponentially weighted moving average of the squared gradients, plus \(\epsilon\). If the gradients are small and \(\epsilon\) is also small, the denominator becomes tiny and the effective step size becomes extremely large.
Increasing \(\epsilon\) can mitigate this.
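To make this concrete (the numbers below are illustrative, not taken from the training run):

\[ \begin{align*} \Delta w &= -\eta\, \frac{m}{\sqrt{v} + \epsilon} \\ |g| \approx 10^{-6},\ \epsilon = 10^{-8}:&\quad \sqrt{v} \approx 10^{-6} \text{ dominates the denominator, so noisy ratios } m/\sqrt{v} \text{ pass through at full size} \\ |g| \approx 10^{-6},\ \epsilon = 10^{-5}:&\quad \text{the denominator is floored at } 10^{-5}, \text{ so } |\Delta w| \lesssim \eta \cdot 10^{-6}/10^{-5} = 0.1\,\eta \end{align*} \]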
```python
unet = fashion_unet()
unet.init_()
ddpm = train(unet, opt_func=partial(torch.optim.Adam, eps=1e-5))
```
loss | epoch | train |
---|---|---|
0.129 | 0 | train |
0.022 | 0 | eval |
0.020 | 1 | train |
0.018 | 1 | eval |
CPU times: user 1min 2s, sys: 4.74 s, total: 1min 7s
Wall time: 1min 8s
This might slightly improve results. I was getting about twice the loss of Jeremy's version, but after removing the global normalization I now get competitive performance even without orthogonal initialization.
Speeding up training
GPUs are extremely fast at 16-bit precision. Unfortunately, not everything can be done at such low precision. Most training is performed using “mixed precision.”
We can adapt the example in the PyTorch docs into a Callback:
```python
use_amp = True

net = make_model(in_size, out_size, num_layers)
opt = torch.optim.SGD(net.parameters(), lr=0.001)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

start_timer()
for epoch in range(epochs):
    for input, target in zip(data, targets):
        with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
            output = net(input)
            loss = loss_fn(output, target)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        opt.zero_grad()  # set_to_none=True here can modestly improve performance
```
MixedPrecision
MixedPrecision ()
Training specific behaviors for the Learner
Since `fp16` takes only a fraction of the memory and processing time of the original representation, we can dramatically increase the batch size. However, since we then have fewer opportunities to update per epoch, we need to compensate with a higher learning rate and more epochs.
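As a rough sanity check (this is the linear-scaling heuristic, not something the lesson prescribes): the batch size grows 4× over the default of 128, suggesting the default learning rate of 4e-3 be scaled by a similar factor.

\[ \frac{512}{128} = 4 \quad\Rightarrow\quad \eta \approx 4 \times 4\times10^{-3} = 1.6\times10^{-2} \quad (\text{the run below uses a slightly more conservative } 10^{-2}) \]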
```python
unet = fashion_unet()
unet.init_()
train(
    unet,
    lr=1e-2,
    n_epochs=8,
    bs=512,
    opt_func=partial(torch.optim.Adam, eps=1e-5),
    extra_cbs=[MixedPrecision()],
)
```
loss | epoch | train |
---|---|---|
0.262 | 0 | train |
0.028 | 0 | eval |
0.026 | 1 | train |
0.026 | 1 | eval |
0.022 | 2 | train |
0.021 | 2 | eval |
0.019 | 3 | train |
0.019 | 3 | eval |
0.018 | 4 | train |
0.017 | 4 | eval |
0.017 | 5 | train |
0.016 | 5 | eval |
0.016 | 6 | train |
0.016 | 6 | eval |
0.016 | 7 | train |
0.016 | 7 | eval |
CPU times: user 1min 44s, sys: 19.3 s, total: 2min 4s
Wall time: 2min 7s
```python
# Let's save a copy of this for the FID notebook
torch.save(unet, "../models/fashion_unet.pt")
```
```python
imgs = ddpm.sample(unet)
show_images(imgs)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [00:14<00:00, 67.09time step/s]
This makes training for longer more feasible, ultimately resulting in a lower loss.
I tried a lot of configurations for this loop. It would usually diverge until I increased the number of epochs to 8, which probably spread the learning-rate schedule over more steps.
```python
# For the homework, let's keep a copy of these
unet_hw, ddpm_hw = unet, ddpm
```
Accelerate
Accelerate was developed by Sylvain Gugger at Hugging Face to convert existing PyTorch code (and PyTorch-based frameworks like ours) to use optimizations such as the following; a minimal sketch of the basic API appears after the list.
- ZeRO-Offload: enables multi-billion parameter model training by offloading tensor storage to the CPU and reloading strategically onto the GPU before computation (Author presentation here)
- Fully-sharded data parallelism (FSDP): shards model parameters, gradients, and optimizer state across GPUs, gathering them only when they are needed for computation
- Mixed Precision Training
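For context, the core of the vanilla Accelerate API looks roughly like this (a sketch only; `model`, `opt`, `train_dl`, and `loss_fn` are assumed to exist, and our `AccelerateCB` below wraps these calls inside the Learner's callback hooks):

```python
from accelerate import Accelerator

accelerator = Accelerator(mixed_precision="fp16")

# prepare() wraps the model, optimizer, and dataloader so they run on the
# right device(s) with the requested precision
model, opt, train_dl = accelerator.prepare(model, opt, train_dl)

for xb, yb in train_dl:
    loss = loss_fn(model(xb), yb)
    accelerator.backward(loss)  # replaces loss.backward(), handling gradient scaling
    opt.step()
    opt.zero_grad()
```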
AccelerateCB
AccelerateCB (mixed_precision='fp16')
Training specific behaviors for the Learner
```python
unet = fashion_unet()
unet.init_()
train(
    unet,
    lr=1e-2,
    n_epochs=8,
    bs=512,
    opt_func=partial(torch.optim.Adam, eps=1e-5),
    extra_cbs=[AccelerateCB()],
)
```
loss | epoch | train |
---|---|---|
0.203 | 0 | train |
0.032 | 0 | eval |
0.029 | 1 | train |
0.026 | 1 | eval |
0.023 | 2 | train |
0.021 | 2 | eval |
0.021 | 3 | train |
0.020 | 3 | eval |
0.018 | 4 | train |
0.018 | 4 | eval |
0.017 | 5 | train |
0.017 | 5 | eval |
0.016 | 6 | train |
0.016 | 6 | eval |
0.016 | 7 | train |
0.016 | 7 | eval |
CPU times: user 1min 57s, sys: 19.1 s, total: 2min 16s
Wall time: 2min 35s
Homework
Do the same training setup, but use only the final 200 steps of the noise scheduler we used in this lesson.
Let’s think. This assignment assumes we only need the values to the right of the dashed line.
```python
fig, (a0, a1, a2) = plt.subplots(1, 3, figsize=(8, 3))
a0.plot(s.betas)
a0.set(xlabel="Time", ylabel=r"$\beta$")
a1.plot(s.sigmas)
a1.set(xlabel="Time", ylabel=r"$\sigma$")
a2.plot(s.alphas_cumprod)
a2.set(xlabel="Time", ylabel=r"$\bar{\alpha}$")
fig.tight_layout()

for y_min, y_max, ax in [(0.00085, 0.012, a0), (0, 15, a1), (0, 1, a2)]:
    ax.plot([800, 800], [y_min, y_max], linestyle="--", color="black")
```
What happens if we use the values of \(\beta\) from the dashed line to 0.012?

```python
s.betas[1000 - 200]
```

```
tensor(0.0087)
```
```python
unet = fashion_unet()
ddpm = train(unet, ddpm=DDPM(200, 0.0087, 0.012), n_epochs=5)
```
loss | epoch | train |
---|---|---|
0.093 | 0 | train |
0.023 | 0 | eval |
0.021 | 1 | train |
0.020 | 1 | eval |
0.018 | 2 | train |
0.018 | 2 | eval |
0.017 | 3 | train |
0.016 | 3 | eval |
0.016 | 4 | train |
0.015 | 4 | eval |
CPU times: user 2min 46s, sys: 9.92 s, total: 2min 56s
Wall time: 2min 59s
```python
imgs = ddpm.sample(unet)
show_images(imgs)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 199/199 [00:02<00:00, 68.44time step/s]
Quite blurry.
Let’s try something else. Let’s use a model trained “normally”, but only take the last 200 steps.
```python
unet, ddpm = unet_hw, ddpm_hw
```
```python
@torch.no_grad()
def sample_using_only_n_steps(
    n_steps,
    ddpm,
    model,
    sz=(16, 1, 32, 32),
    device=def_device,
):
    ᾱ, ɑ, σ = ddpm.ᾱ.to(device), ddpm.ɑ.to(device), ddpm.σ.to(device)
    x_t = torch.randn(sz, device=device) * σ[ddpm.n_steps - n_steps]
    bs, *_ = sz
    preds = []
    iter_ = list(reversed(range(1, ddpm.n_steps)))
    iter_ = iter_[-n_steps:]
    for t in tqdm(iter_, unit="time step"):
        # Predict the noise for each example in the image
        t_batch = torch.full((bs,), fill_value=t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_batch).sample

        # Predict the image without noise
        x_0_pred = x_t - (1 - ɑ[t]) / torch.sqrt(1 - ᾱ[t]) * noise_pred

        # Add noise to the predicted noiseless image such that it ultimately
        # has slightly less noise than before
        x_t_minus_1 = x_0_pred / ɑ[t].sqrt() + (σ[t] * torch.randn(sz, device=device))

        # Repeat
        x_t = x_t_minus_1

    # At the last step, simply rescale and do not add noise
    x_0 = x_0_pred / ɑ[0].sqrt()
    return x_0
```
```python
imgs = sample_using_only_n_steps(200, ddpm, unet)
show_images(imgs)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:02<00:00, 68.70time step/s]
Not really working. Maybe we can take every nth step until step 200, and then sample normally.
```python
@torch.no_grad()
def sample_skip(
    skip: int,
    until: int,
    ddpm,
    model,
    sz=(16, 1, 32, 32),
    device=def_device,
):
    """For the first n steps, reuse the noise prediction"""
    ᾱ, ɑ, σ = ddpm.ᾱ.to(device), ddpm.ɑ.to(device), ddpm.σ.to(device)
    x_t = torch.randn(sz, device=device)
    bs, *_ = sz
    preds = []
    ts = list(reversed(range(1, ddpm.n_steps)))

    skip_ts, nonskip_ts = ts[:until][::skip], ts[until:]

    try:
        for t in tqdm(skip_ts, unit="time step"):
            # Predict the noise for each example in the image
            t_batch = torch.full((bs,), fill_value=t, device=device, dtype=torch.long)
            noise_pred = model(x_t, t_batch).sample
            epsilon = torch.randn(sz, device=device)

            for t_offset in range(skip):
                # Predict the image without noise
                if t + t_offset > until:
                    raise StopIteration
                K = (1 - ɑ[t - t_offset]) / torch.sqrt(1 - ᾱ[t - t_offset])
                x_0_pred = x_t - K * noise_pred

                # Add noise to the predicted noiseless image such that it ultimately
                # has slightly less noise than before
                x_t_minus_1 = x_0_pred / ɑ[t].sqrt() + (σ[t] * epsilon)

                # Repeat
                x_t = x_t_minus_1
    except StopIteration:
        ...

    for t in tqdm(nonskip_ts, unit="time step"):
        t_batch = torch.full((bs,), fill_value=t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_batch).sample
        x_0_pred = x_t - (1 - ɑ[t]) / torch.sqrt(1 - ᾱ[t]) * noise_pred
        x_t_minus_1 = x_0_pred / ɑ[t].sqrt() + (σ[t] * torch.randn(sz, device=device))
        x_t = x_t_minus_1

    x_0 = x_0_pred / ɑ[0].sqrt()
    return x_0
```
```python
imgs = sample_skip(100, 300, ddpm, unet)
show_images(imgs)
```
0%| | 0/3 [00:00<?, ?time step/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 699/699 [00:10<00:00, 68.30time step/s]
Let’s try one more thing where we skip more at the beginning than at the end.
```python
@torch.no_grad()
def sample_skip_schedule(schedule, ddpm, model, sz=(16, 1, 32, 32), device=def_device):
    assert sum(schedule) < ddpm.n_steps
    ᾱ, ɑ, σ = ddpm.ᾱ.to(device), ddpm.ɑ.to(device), ddpm.σ.to(device)
    x_t = torch.randn(sz, device=device)
    bs, *_ = sz
    preds = []
    ts = reversed(range(1, ddpm.n_steps))
    iter_ts = iter(tqdm(ts, total=ddpm.n_steps - 1))

    for block in schedule:
        t = next(iter_ts)
        t_batch = torch.full((bs,), fill_value=t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_batch).sample
        epsilon = torch.randn(sz, device=device)

        # Reuse the predicted and sampled noise
        for _ in range(block - 1):
            t = next(iter_ts)
            K = (1 - ɑ[t]) / torch.sqrt(1 - ᾱ[t])
            x_0_pred = x_t - K * noise_pred
            x_t_minus_1 = x_0_pred / ɑ[t].sqrt() + (σ[t] * epsilon)

            # Repeat
            x_t = x_t_minus_1

    for t in iter_ts:
        t_batch = torch.full((bs,), fill_value=t, device=device, dtype=torch.long)
        noise_pred = model(x_t, t_batch).sample
        x_0_pred = x_t - (1 - ɑ[t]) / torch.sqrt(1 - ᾱ[t]) * noise_pred
        x_t_minus_1 = x_0_pred / ɑ[t].sqrt() + (σ[t] * torch.randn(sz, device=device))
        x_t = x_t_minus_1

    x_0 = x_0_pred / ɑ[0].sqrt()
    return x_0
```
```python
sched = [200] + [25] * 10 + [5] * 25
n_steps = ddpm.n_steps - sum(sched) + len(sched)
imgs = sample_skip_schedule(sched, ddpm, unet)
show_images(imgs)
n_steps
```
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 999/999 [00:06<00:00, 149.67it/s]
461
None of these worked very well. I was able to get it down to 100 steps in the cosine_revisited notebook.