dls = get_imagenet_dls(bs=64)
Scaling up U-net diffusion
In this module, we train a conditional U-net on a larger dataset.
Adapted from:
- https://www.youtube.com/watch?v=8AgZ9jcQ9v8&list=PLfYUBJiXbdtRUvTUYpLdfHHp9a58nWVXP&index=17
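get_imagenet_dls (called above) is defined elsewhere in the repo; roughly, it should return a dict of train/test DataLoaders that yield (images, labels) batches of 64x64 images. A hypothetical sketch, assuming an ImageFolder-style layout under a placeholder path and images normalized to [-1, 1] (to roughly match the denorm() used below):

```python
# Hypothetical sketch only: the real get_imagenet_dls lives elsewhere in the repo.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

def get_imagenet_dls(bs=64, root="data/imagenet64"):  # placeholder path
    tfm = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),  # scale to [-1, 1]
    ])
    return {
        split: DataLoader(
            datasets.ImageFolder(f"{root}/{split}", tfm),
            batch_size=bs, shuffle=(split == "train"), num_workers=4,
        )
        for split in ("train", "test")
    }
```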
How many classes do we have?
cs = set()
for xb, c in dls["test"]:
    cs.update(c.tolist())
max(cs)
199
Labels run from 0 to 199, so there are 200 classes, which is the n_classes we pass to the model below.
show_images(denorm(xb)[:8, ...], imsize=0.8)
un = ConditionalTAUnet(
    n_classes=200,
    color_channels=3,
    nfs=(32, 64, 128, 256, 384, 512),
    n_blocks=(3, 2, 1, 1, 1, 1, 1),
    attention_heads=(0, 8, 8, 8, 8, 8, 8),
)
f"{sum(p.numel() for p in un.parameters()):,}"
'34,446,784'
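ConditionalTAUnet's internals aren't shown in this section; the usual way to make a time-conditioned U-net class-conditional is to learn an embedding per class and add it to the timestep embedding that gets fed into every block. A minimal sketch of that pattern (names are illustrative, not the actual implementation):

```python
from torch import nn

class ClassConditioning(nn.Module):
    """Illustrative pattern only: learn one embedding per class and add it
    to the timestep embedding before it is broadcast into the U-net blocks."""
    def __init__(self, n_classes, emb_dim):
        super().__init__()
        self.class_emb = nn.Embedding(n_classes, emb_dim)

    def forward(self, t_emb, c):
        # t_emb: (bs, emb_dim) timestep embedding; c: (bs,) integer class labels
        return t_emb + self.class_emb(c)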
un = conditional_train(un, dls, lr=4e-3, n_epochs=1)
| loss  | epoch | train |
|-------|-------|-------|
| 0.148 | 0     | train |
| 0.107 | 0     | eval  |
CPU times: user 6min 15s, sys: 1min 12s, total: 7min 28s
Wall time: 7min 31s
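conditional_train is also defined elsewhere; assuming it uses the standard DDPM objective, a single training step would look roughly like this (the function name and schedule bookkeeping are assumptions, not the actual code):

```python
import torch
import torch.nn.functional as F

def ddpm_train_step(model, x0, c, alphas_bar, opt):
    """One conditional DDPM training step (illustrative, not conditional_train).
    alphas_bar: (n_steps,) cumulative product of 1 - beta, on x0's device."""
    t = torch.randint(0, len(alphas_bar), (x0.shape[0],), device=x0.device)
    eps = torch.randn_like(x0)
    ab = alphas_bar[t].view(-1, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps  # noisify the clean batch
    loss = F.mse_loss(model(x_t, t, c), eps)      # model predicts the added noise
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()
```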
N = 8
x_0, _ = conditional_ddpm(un, torch.arange(N), (N, 3, 64, 64))
show_images(denorm(x_0.cpu()), imsize=0.8);
100%|████████████████████████████████████████████████████████████████████████████████████| 99/99 [00:01<00:00, 73.02time step/s]
Not great 😂
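For reference, conditional_ddpm presumably runs the usual DDPM reverse process with the class labels passed to the model at every step; a hedged sketch of that kind of ancestral sampler (assuming sigma_t^2 = beta_t, not the actual implementation):

```python
import torch

@torch.no_grad()
def ddpm_sample(model, labels, shape, betas):
    """Class-conditional ancestral DDPM sampling (illustrative only).
    betas: (n_steps,) noise schedule."""
    device = labels.device
    betas = betas.to(device)
    alphas = 1.0 - betas
    alphas_bar = alphas.cumprod(dim=0)
    x = torch.randn(shape, device=device)
    for t in reversed(range(len(betas))):
        tb = torch.full((shape[0],), t, device=device, dtype=torch.long)
        eps = model(x, tb, labels)  # predicted noise for this step and class
        x = (x - betas[t] / (1 - alphas_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = x + betas[t].sqrt() * torch.randn_like(x)  # add noise except at t=0
    return x
```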