dls = get_fashion_dls(bs=512)
Diffusion U-Net
Adapted from
Time embedding
Supposedly, time embeddings are necessary to achieve high modeling performance
emb_dim = 16
# This was thought to be the longest sequence that a transformer
# should be able to handle, even though nowadays sequences can
# be much longer
max_period = 10_000
print(f"{-math.log(max_period)=}")
exponent = -math.log(max_period) * torch.linspace(0, 1, emb_dim // 2)
plt.plot(exponent)
-math.log(max_period)=-9.210340371976184
They are computed by taking the outer product of the time vector, ts, with the exponentiated exponent…
bs = 100
ts = torch.linspace(-10, 10, bs)
embedding = ts[:, None].float() * exponent.exp()[None, :]
embedding.shape
torch.Size([100, 8])
Note that, so far, the embeddings aren’t very different.
fig, ax = plt.subplots(figsize=(4, 4))
for i in range(0, bs, 25):
    ax.plot(embedding[i], label=f"Timestep #{i+1} Embedding")
ax.set(ylabel="Logit", xlabel="Dimension")
fig.legend()
…and projected into sine and cosine space, and then concatenated.
embedding = torch.cat([embedding.sin(), embedding.cos()], dim=-1)
embedding.shape
torch.Size([100, 16])
fig, ax = plt.subplots(figsize=(4, 4))
for i in range(2):
    ax.plot(embedding[:, i], label=f"Sine Dimension #{i+1}")
for i in range(8, 10):
    ax.plot(embedding[:, i], label=f"Cosine Dimension #{i+1}")
ax.set(ylabel="Logit", xlabel="Timestep")
fig.legend()
Overall, this results in time embeddings that are similar to their neighbors yet diverse across the full range of timesteps. The figure below demonstrates this: the first column is similar to the second column, but dissimilar to the 100th column.
show_image(embedding.T)
We can consolidate this into a single function.
timestep_embedding
timestep_embedding (ts, emb_dim, max_period=10000)
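The body isn’t reproduced here, but a minimal sketch that simply consolidates the steps above would look like this (nothing beyond what we just computed, plus moving the frequencies onto the device of `ts`):

```python
import math
import torch

def timestep_embedding(ts, emb_dim, max_period=10_000):
    # Exponentially spaced frequencies, exactly as derived above.
    exponent = -math.log(max_period) * torch.linspace(
        0, 1, emb_dim // 2, device=ts.device
    )
    # Outer product of the timesteps with the frequencies...
    emb = ts[:, None].float() * exponent.exp()[None, :]
    # ...projected into sine and cosine space and concatenated.
    return torch.cat([emb.sin(), emb.cos()], dim=-1)
```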
Increasing the number of dimensions increases the column size.
show_image(timestep_embedding(ts, 50).T)
Decreasing the max_period increases inter-column heterogeneity.
show_image(timestep_embedding(ts, 16, 100).T)
Let’s reimplement the U-Net with time embeddings.
Conv
Conv (c_in, c_out, ks=3, stride=1)
Wrapper for a Conv block with normalization and activation
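As a hedged sketch of what such a wrapper could look like (the choice of BatchNorm and SiLU here is an assumption, not taken from the source):

```python
import torch.nn as nn

def Conv(c_in, c_out, ks=3, stride=1):
    # Hypothetical sketch: convolution + normalization + activation.
    # The real block may use different norm/activation layers or ordering.
    return nn.Sequential(
        nn.Conv2d(c_in, c_out, ks, stride=stride, padding=ks // 2),
        nn.BatchNorm2d(c_out),
        nn.SiLU(),
    )
```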
EmbeddingPreactResBlock
EmbeddingPreactResBlock (t_embed, c_in, c_out, ks=3, stride=2)
Conv res block with the preactivation configuration
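In the preactivation configuration, normalization and activation come before each convolution, and the time embedding is typically projected and added channel-wise between the two convolutions. A hypothetical sketch under those assumptions:

```python
import torch.nn as nn

class EmbeddingPreactResBlock(nn.Module):
    # Hypothetical sketch of a preactivation res block with a time-embedding input.
    def __init__(self, t_embed, c_in, c_out, ks=3, stride=2):
        super().__init__()
        self.norm1, self.norm2 = nn.BatchNorm2d(c_in), nn.BatchNorm2d(c_out)
        self.act = nn.SiLU()
        self.conv1 = nn.Conv2d(c_in, c_out, ks, stride=stride, padding=ks // 2)
        self.conv2 = nn.Conv2d(c_out, c_out, ks, padding=ks // 2)
        self.t_proj = nn.Linear(t_embed, c_out)                # project the time embedding
        self.skip = nn.Conv2d(c_in, c_out, 1, stride=stride)   # match the residual's shape

    def forward(self, x, t):
        h = self.conv1(self.act(self.norm1(x)))                # preactivation: norm -> act -> conv
        h = h + self.t_proj(t)[:, :, None, None]               # add time info per channel
        h = self.conv2(self.act(self.norm2(h)))
        return h + self.skip(x)                                # residual connection
```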
SaveTimeActivationMixin
SaveTimeActivationMixin ()
Helper to save the output of the downblocks to consume in the upblocks
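One way such a mixin could be written (a hypothetical sketch; the real class may store activations differently):

```python
class SaveTimeActivationMixin:
    # Hypothetical sketch: stash the block's output on the module so the
    # matching upblock can retrieve it later via `.saved`.
    def forward(self, x, t):
        self.saved = super().forward(x, t)
        return self.saved
```

A `TResBlock` could then plausibly be `class TResBlock(SaveTimeActivationMixin, EmbeddingPreactResBlock): pass`, though that composition is an assumption.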
TResBlock
TResBlock (t_embed, c_in, c_out, ks=3, stride=2)
Res block with saved outputs
TDownblock
TDownblock (t_embed, c_in, c_out, downsample=True, n_layers=1)
A superblock consisting of many downblocks of similar resolutions
TUpblock
TUpblock (t_embed, c_in, c_out, upsample=True, n_layers=1)
A superblock consisting of many upblocks of similar resolutions and logic to use the activations of the counterpart downblock.
TimeEmbeddingMLP
TimeEmbeddingMLP (c_in, c_out)
Small neural network to alter the “raw” time embeddings
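Such an MLP is typically just two linear layers with a nonlinearity; a hedged sketch (the width and the SiLU activation are assumptions):

```python
import torch.nn as nn

def TimeEmbeddingMLP(c_in, c_out):
    # Hypothetical sketch: lift the sinusoidal embedding into the width the
    # res blocks expect, with a nonlinearity in between.
    return nn.Sequential(
        nn.Linear(c_in, c_out),
        nn.SiLU(),
        nn.Linear(c_out, c_out),
    )
```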
TUnet
TUnet (nfs=(224, 448, 672, 896), n_blocks=(3, 2, 2, 1, 1), color_channels=3)
Diffusion U-net with a diffusion time dimension
FashionDDPM
FashionDDPM ()
Training-specific behaviors for the Learner
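The callback’s body isn’t shown, but DDPM training usually reduces to: sample a noise level per image, mix in that much Gaussian noise, and train the model to predict the noise with MSE. A hypothetical sketch of that transform (the cosine schedule and the exact interface with the Learner are assumptions):

```python
import torch

def noisify(x0):
    # Hypothetical sketch of the DDPM training transform; the notebook's
    # FashionDDPM callback may implement this differently.
    t = torch.rand(x0.shape[0], device=x0.device)     # one noise level per image
    abar = torch.cos(t * torch.pi / 2) ** 2           # assumed cosine schedule: fraction of signal kept
    eps = torch.randn_like(x0)                        # the noise the model must predict
    x_t = (abar.sqrt()[:, None, None, None] * x0
           + (1 - abar).sqrt()[:, None, None, None] * eps)
    return (x_t, t), eps                              # model inputs, MSE target
```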
train
train (model, dls, lr=0.004, n_epochs=25, extra_cbs=[], loss_fn=<function mse_loss>)
un = train(
    TUnet(
        color_channels=1,
        nfs=(32, 64, 128, 256, 384),
        n_blocks=(3, 2, 1, 1, 1, 1),
    ),
    dls,
    lr=4e-3,
    n_epochs=25,
)
loss | epoch | split |
---|---|---|
0.150 | 0 | train |
0.087 | 0 | eval |
0.056 | 1 | train |
0.054 | 1 | eval |
0.046 | 2 | train |
0.051 | 2 | eval |
0.040 | 3 | train |
0.044 | 3 | eval |
0.037 | 4 | train |
0.044 | 4 | eval |
0.035 | 5 | train |
0.041 | 5 | eval |
0.034 | 6 | train |
0.044 | 6 | eval |
0.032 | 7 | train |
0.037 | 7 | eval |
0.031 | 8 | train |
0.031 | 8 | eval |
0.030 | 9 | train |
0.033 | 9 | eval |
0.029 | 10 | train |
0.035 | 10 | eval |
0.029 | 11 | train |
0.032 | 11 | eval |
0.029 | 12 | train |
0.035 | 12 | eval |
0.028 | 13 | train |
0.030 | 13 | eval |
0.028 | 14 | train |
0.028 | 14 | eval |
0.027 | 15 | train |
0.028 | 15 | eval |
0.027 | 16 | train |
0.029 | 16 | eval |
0.027 | 17 | train |
0.027 | 17 | eval |
0.026 | 18 | train |
0.027 | 18 | eval |
0.026 | 19 | train |
0.027 | 19 | eval |
0.026 | 20 | train |
0.027 | 20 | eval |
0.026 | 21 | train |
0.026 | 21 | eval |
0.026 | 22 | train |
0.025 | 22 | eval |
0.026 | 23 | train |
0.026 | 23 | eval |
0.025 | 24 | train |
0.026 | 24 | eval |
CPU times: user 13min 10s, sys: 24.9 s, total: 13min 34s
Wall time: 13min 42s
xb, _ = dls.peek()
xb.shape
torch.Size([512, 1, 32, 32])
n_steps = 100
ts = torch.linspace(1 - (1 / n_steps), 0, n_steps).to(xb.device)
x_t = torch.randn(16, 3, 64, 64)
ddpm
ddpm (model, sz=(16, 1, 32, 32), device='cpu', n_steps=100)
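The body of `ddpm` isn’t reproduced here, but the classic DDPM sampling loop is short. A minimal sketch, assuming the model takes `(x_t, t)` with `t` in `[0, 1)` and predicts the added noise, and assuming a cosine ᾱ schedule (both assumptions; the actual `ddpm` above may differ, and its second return value is treated here as the list of intermediate samples):

```python
import math
import torch

@torch.no_grad()
def ddpm_sample(model, sz=(16, 1, 32, 32), device="cpu", n_steps=100):
    # Hypothetical sketch of a sampler like `ddpm`; schedule and model
    # interface are assumptions, not taken from the source.
    abar = lambda t: torch.cos(torch.as_tensor(t) * math.pi / 2).pow(2).clamp(1e-5, 1.0)
    ts = torch.linspace(1 - 1 / n_steps, 0, n_steps).to(device)  # noise levels to step down to
    x = torch.randn(sz, device=device)                           # start from pure noise
    preds = []
    for i, t in enumerate(ts):
        t_cur = torch.tensor(1.0, device=device) if i == 0 else ts[i - 1]
        ab_cur, ab_next = abar(t_cur), abar(t)
        alpha = ab_cur / ab_next                                  # per-step alpha
        eps = model(x, t_cur.repeat(sz[0]))                       # predicted noise
        # Standard DDPM posterior mean for the less-noisy level.
        x = (x - (1 - alpha) / (1 - ab_cur).sqrt() * eps) / alpha.sqrt()
        if i < n_steps - 1:                                       # no noise on the final step
            var = (1 - ab_next) / (1 - ab_cur) * (1 - alpha)      # beta-tilde variance
            x = x + var.sqrt() * torch.randn_like(x)
        preds.append(x.cpu())
    return x, preds
```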
x_0, _ = ddpm(un, (8, 1, 32, 32))
show_images(x_0, imsize=0.8)
100%|████████████████████████████████████████████████████████████| 99/99 [00:00<00:00, 145.43time step/s]
This is not bad! It achieves performance similar to the Hugging Face implementation with a comparable number of parameters.