Cosine scheduler, revisited
Adapted from
aesthetics
aesthetics ()
Improve the look and feel of our visualizations
```python
aesthetics()
```
The idea we’ll be exploring is removing the concept of having \(\frac{t}{T}\).
Often in diffusion, we refer to time steps like so:
```python
T = 1000
for t in range(T):
    print(f"Progress: {t/T}")
```
Jeremy notes that we can simply use a progress variable \(\in [0, 1]\). This allows us to simplify the \(\bar{\alpha}\) expression like so:
ᾱ
ᾱ (t, reshape=True, device='cpu')
```python
x = torch.linspace(0, 1, 100)
plt.plot(x, ᾱ(x, reshape=False));
```
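For reference, here is a minimal sketch of what this simplified \(\bar{\alpha}\) might look like. The \(\cos^2\) form is an assumption based on the cosine schedule; the notebook's actual body isn't shown above.
```python
import math
import torch

# Hypothetical sketch of ᾱ, assuming the cosine schedule ᾱ(t) = cos(t·π/2)²
# with t being the continuous progress value in [0, 1].
def ᾱ(t, reshape=True, device='cpu'):
    t = torch.as_tensor(t, dtype=torch.float32, device=device)
    out = torch.cos(t * math.pi / 2) ** 2
    # reshape to (batch, 1, 1, 1) so it broadcasts against image tensors
    return out.reshape(-1, 1, 1, 1) if reshape else out
```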
Furthermore, the noisify function can be simplified like so:
noisify
noisify (x_0, t=None)
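As a rough sketch (the exact body isn't shown here), noisify plausibly samples one progress value per image when t isn't given and returns ((x_t, t), ε), matching how it is called later in this notebook:
```python
# Hypothetical sketch of the simplified noisify; the uniform sampling of t is an assumption.
def noisify(x_0, t=None):
    if t is None:
        # one continuous "progress" value per sample, in [0, 1)
        t = torch.rand(x_0.shape[0], device=x_0.device)
    ε = torch.randn_like(x_0)
    ᾱ_t = ᾱ(t, device=x_0.device)      # broadcastable (batch, 1, 1, 1)
    x_t = ᾱ_t.sqrt() * x_0 + (1 - ᾱ_t).sqrt() * ε
    return (x_t, t), ε
```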
Now, (a) we don’t have to deal with knowing the time steps or computing \(\alpha\) and \(\beta\), and (b) the process is continuous.
Let’s add this to the DDPM training callback. Notice that the constructor has been deleted.
ContinuousDDPM
ContinuousDDPM (n_steps=1000, βmin=0.0001, βmax=0.02)
Modify the training behavior
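As a heavily hedged sketch, the callback might look like the following, assuming a miniai-style DDPM parent callback where before_batch replaces the batch; those names come from the typical course setup, not from this notebook.
```python
# Hypothetical sketch: DDPM, learn, and before_batch are assumed from a
# miniai-style callback API; only the batch-replacement idea is essential.
class ContinuousDDPM(DDPM):
    # No __init__: the inherited n_steps/βmin/βmax defaults are no longer needed
    # by the continuous schedule.
    def before_batch(self, learn):
        # Replace (x_0, y) with ((x_t, t), ε) so the model learns to predict the noise.
        learn.batch = noisify(learn.batch[0])
```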
```python
dls = get_dls()
xb, _ = dls.peek()
xb.shape  # Note: 32x32
```
torch.Size([128, 1, 32, 32])
We can use the parameters from our training run in the previous notebook.
= Path("../models/fashion_unet_2x_continuous.pt")
fp = ContinuousDDPM(βmax=0.01)
ddpm if fp.exists():
= torch.load(fp)
unet else:
= fashion_unet()
unet
train(
unet,=1e-2,
lr=25,
n_epochs=128,
bs=partial(torch.optim.Adam, eps=1e-5),
opt_func=ddpm,
ddpm
) torch.save(unet, fp)
f"{sum(p.numel() for p in unet.parameters() if p.requires_grad):,}"
'15,890,753'
To denoise, we need to reverse the noisification. Recall that, for a given sample \(x_0\), the noised sample is defined as:
\[ x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon \]
Thus,
\[ x_0 = \frac{x_t - \sqrt{1 - \bar{\alpha}_t}\, \epsilon}{\sqrt{\bar{\alpha}_t}} \]
denoisify
denoisify (x_t, noise, t)
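Given the equation above and this signature, a minimal sketch of denoisify (assuming the same \(\bar{\alpha}\) helper as before) is just the rearrangement:
```python
# Invert the noising equation to recover an estimate of x_0 from x_t and the noise.
def denoisify(x_t, noise, t):
    ᾱ_t = ᾱ(t, device=x_t.device)
    return (x_t - (1 - ᾱ_t).sqrt() * noise) / ᾱ_t.sqrt()
```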
```python
(x_t, ts), eps = noisify(xb)
show_images(x_t[:8], titles=[f"{t.item():.2f}" for t in ts[:8]])
```
This looks impressive for one step, but recall that xb is part of the training data.
```python
eps_pred = unet(x_t.to(def_device), ts.to(def_device)).sample
x_0 = denoisify(x_t, eps_pred.cpu(), ts)
show_images(x_0[:8])
```
Finally, we can rewrite the sampling algorithm without any \(\alpha\)s or \(\beta\)s. The only function we need is \(\bar{\alpha}_t\), which is already part of the noisify and denoisify functions. This also means we can make the sampler a single, elegant function.
ddpm
ddpm (model, sz=(16, 1, 32, 32), device='cpu', n_steps=100)
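The exact update rule isn't shown above, but given the signature and the bug notes below, a plausible sketch (an assumption, not the notebook's code) is to estimate \(x_0\) with denoisify at each step and then re-noise that estimate to the next progress value with noisify:
```python
# Hypothetical sketch of the continuous-time sampler, built only from noisify and denoisify.
@torch.no_grad()
def ddpm(model, sz=(16, 1, 32, 32), device='cpu', n_steps=100):
    x_t = torch.randn(sz, device=device)
    ts = torch.linspace(1, 0, n_steps)                    # progress runs from 1 down to 0
    for t, t_next in zip(ts[:-1], ts[1:]):                # n_steps - 1 denoising iterations
        t_batch = t.repeat(sz[0]).to(device)
        noise_pred = model(x_t, t_batch).sample
        x_0_pred = denoisify(x_t, noise_pred, t_batch)    # estimate x_0 from the predicted noise
        # re-noise the estimate down to the next (smaller) progress value
        (x_t, _), _ = noisify(x_0_pred, t_next.repeat(sz[0]).to(device))
    return x_0_pred.cpu()
```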
```python
x_0 = ddpm(unet, sz=(8, 1, 32, 32), n_steps=100)
show_images(x_0, imsize=0.8)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:01<00:00, 67.69time step/s]
This code had a few bugs in it initially that led to deep-fried results:
- denoisify was given torch.randn instead of noise_pred
- The last denoising iteration was given t=1 instead of t=0
Let’s try this with DDIM.
ddim_noisify
ddim_noisify (η, x_0_pred, noise_pred, t, t_next)
ddim
ddim (model, sz=(16, 1, 32, 32), device='cpu', n_steps=100, eta=1.0, noisify=<function ddim_noisify>)
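For reference, here is a sketch of what ddim_noisify could look like if it follows the standard DDIM update rule (Song et al., 2020); this is inferred from the signature above, not copied from the notebook. In this formulation \(\eta\) scales how much fresh noise is injected at each step, with \(\eta = 0\) fully deterministic.
```python
# Hypothetical sketch of the DDIM step, reusing the ᾱ helper sketched earlier.
def ddim_noisify(η, x_0_pred, noise_pred, t, t_next):
    ᾱ_t = ᾱ(t, device=x_0_pred.device)
    ᾱ_next = ᾱ(t_next, device=x_0_pred.device)
    # σ controls the stochasticity: 0 when η=0, growing with η
    σ = η * ((1 - ᾱ_next) / (1 - ᾱ_t) * (1 - ᾱ_t / ᾱ_next)).sqrt()
    return (ᾱ_next.sqrt() * x_0_pred
            + (1 - ᾱ_next - σ**2).sqrt() * noise_pred
            + σ * torch.randn_like(x_0_pred))
```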
```python
x_0 = ddim(unet, sz=(8, 1, 32, 32), n_steps=100)
show_images(x_0, imsize=0.8)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:01<00:00, 70.26time step/s]
```python
bs = 128
eval = ImageEval.fashion_mnist(bs=bs)
x_0 = ddim(unet, sz=(bs, 1, 32, 32), n_steps=100)
eval.fid(x_0)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:03<00:00, 30.48time step/s]
764.1282958984375
Directly comparing this to DDPM:
```python
x_0 = ddim(unet, sz=(bs, 1, 32, 32), n_steps=100, eta=0.0)  # eta=0 makes this DDPM
eval.fid(x_0)
```
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:03<00:00, 30.55time step/s]
818.5997314453125
and a real batch of data:
```python
xb, _ = dls.peek("test")
eval.fid(xb)
```
158.8096923828125