import math
import matplotlib.pyplot as plt
import torch
from slowai.ddpm import DDPM, fashion_unet, get_dls
from slowai.utils import show_images
Noise schedules
Investigating the curvature of the noise schedules
Adapted from
"ggplot") plt.style.use(
Normally, we have a noise schedule like this:
def diff(x, dt):
    # Finite-difference approximation of the derivative
    return (x[1:] - x[:-1]) / dt
nsteps = 1000
dt = (0.02 - 0.0001) / nsteps
beta = torch.linspace(0.0001, 0.02, nsteps)  # linear β schedule
alpha = 1 - beta
ᾱ = alpha.cumprod(dim=0)  # ᾱ_t, the cumulative signal kept after t steps
dᾱ_dt = diff(ᾱ, dt)

fig, (a0, a1) = plt.subplots(1, 2, figsize=(4.5, 2.5))
a0.set(title=r"$\bar{\alpha}$")
a0.plot(ᾱ)
a1.set(title=r"$\frac{d\bar{\alpha}}{dt}$")
a1.plot(dᾱ_dt)
fig.tight_layout()
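A quick check (a small sketch reusing the ᾱ just computed; the 0.05 threshold is arbitrary) quantifies how much of the schedule is spent with the signal essentially gone:
# Fraction of the 1000 steps where ᾱ has already dropped below 0.05,
# i.e. the image is mostly noise and later steps change very little
(ᾱ < 0.05).float().mean()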
There are a number of steps in this process where the change in noise is almost nothing. Compare this to a cosine schedule.
def ᾱ_cos(t, T):
    # Squared-cosine ᾱ schedule, clamped just below 1
    return (((t / T) * math.pi / 2).cos() ** 2).clamp(0.0, 0.999)
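A couple of spot checks (a small sketch, not part of the original notebook) confirm the endpoints of this schedule:
# ᾱ starts just below 1 (clamped to 0.999), is 0.5 at the midpoint,
# and decays to nearly 0 by the final step
T = 1000
ᾱ_cos(torch.tensor([0.0, T / 2, T - 1.0]), T)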
fig, (a0, a1) = plt.subplots(1, 2, figsize=(6.5, 3.5))
a0.set(title=r"$\bar{\alpha}$")
a1.set(title=r"$\frac{d\bar{\alpha}}{dt}$")
# Linear
a0.plot(ᾱ)
a1.plot(dᾱ_dt, label="linear")
# Cosine
x = ᾱ_cos(torch.linspace(0, nsteps - 1, nsteps), nsteps)
a0.plot(x)
a1.plot(diff(x, dt), label="cos")
fig.legend(loc="center right")
fig.tight_layout()
This is a more consistent noise schedule, especially when considering the slope. Notice that lowering \(\beta_{max}\) actually makes the curvature of the linear schedule more cosine-like.
nsteps = 1000
dt = (0.01 - 0.0001) / nsteps
beta = torch.linspace(
    0.0001,
    0.01,  # 👈
    nsteps,
)
alpha = 1 - beta
ᾱ = alpha.cumprod(dim=0)
dᾱ_dt = diff(ᾱ, dt)  # recompute the slope for the new β_max
fig, (a0, a1) = plt.subplots(1, 2, figsize=(6.5, 3.5))
a0.set(title=r"$\bar{\alpha}$")
a1.set(title=r"$\frac{d\bar{\alpha}}{dt}$")
# Linear
a0.plot(ᾱ)
a1.plot(dᾱ_dt, label="linear")
# Cosine
x = ᾱ_cos(torch.linspace(0, nsteps - 1, nsteps), nsteps)
a0.plot(x)
a1.plot(diff(x, dt), label="cos")
fig.legend(loc="center right")
fig.tight_layout()
Even if we adopt a linear schedule, we'll clearly want to use a lower value of \(\beta_{max}\) so that the noise is added more gradually over time. Robin Rombach, an author of Stable Diffusion, observed that lower values of \(\beta_{max}\) improved sampling.
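To make this concrete, here is a small sketch (the loop and variable names are mine, reusing the same linear-schedule construction) comparing the terminal \(\bar{\alpha}_T\) for the two values of \(\beta_{max}\):
# How much of the original signal survives at the final step?
for β_max in (0.02, 0.01):
    β = torch.linspace(0.0001, β_max, 1000)
    print(β_max, (1 - β).cumprod(dim=0)[-1].item())
With \(\beta_{max} = 0.01\), noticeably more signal survives at the last step, i.e. the destruction of the image is spread more evenly across the trajectory.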
If we were to train with this schedule, we would see that more of the noised training examples remain clearly recognizable.
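Concretely, noisify picks a random timestep \(t\) for each image and applies the closed-form forward process \(x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \varepsilon\) with \(\varepsilon \sim \mathcal{N}(0, I)\), returning the noised batch, the timesteps, and the noise the model must learn to predict.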
def noisify(x0, ᾱ, n_steps=100):
    device = x0.device
    n = len(x0)
    # Pick a random timestep for each image in the batch
    t = torch.randint(0, n_steps, (n,), dtype=torch.long)
    ε = torch.randn(x0.shape, device=device)
    ᾱ_t = ᾱ[t].reshape(-1, 1, 1, 1).to(device)
    # x_t = √ᾱ_t · x_0 + √(1 − ᾱ_t) · ε
    xt = ᾱ_t.sqrt() * x0 + (1 - ᾱ_t).sqrt() * ε
    return (xt, t.to(device)), ε
dls = get_dls()
xb, _ = dls.peek()
(out, _), _ = noisify(xb, x)  # x is the cosine ᾱ computed above
out.shape
torch.Size([128, 1, 32, 32])
show_images(out[:16])