dls = get_imagenet_dls(bs=64)Scaling up U-net diffusion
In this module, we train a conditional U-net with a larger dataset
Adapted from: - https://www.youtube.com/watch?v=8AgZ9jcQ9v8&list=PLfYUBJiXbdtRUvTUYpLdfHHp9a58nWVXP&index=17
How many classes do we have?
cs = set()
for xb, c in dls["test"]:
cs.update(c.tolist())
max(cs)199
show_images(denorm(xb)[:8, ...], imsize=0.8)
un = ConditionalTAUnet(
n_classes=200,
color_channels=3,
nfs=(32, 64, 128, 256, 384, 512),
n_blocks=(3, 2, 1, 1, 1, 1, 1),
attention_heads=(0, 8, 8, 8, 8, 8, 8),
)
f"{sum(p.numel() for p in un.parameters()):,}"'34,446,784'
un = conditional_train(un, dls, lr=4e-3, n_epochs=1)| loss | epoch | train |
|---|---|---|
| 0.148 | 0 | train |
| 0.107 | 0 | eval |

CPU times: user 6min 15s, sys: 1min 12s, total: 7min 28s
Wall time: 7min 31s
N = 8
x_0, _ = conditional_ddpm(un, torch.arange(N), (N, 3, 64, 64))
show_images(denorm(x_0.cpu()), imsize=0.8);100%|████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:01<00:00, 73.02time step/s]

Not great 😂