Attention and Conditionality

In this module, we implement attention and conditionality.

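Attention is a “high-pass” filter that can help with the limitations of a “low-pass” filter such as a convolution: a convolution only mixes information within a small local neighborhood, whereas attention lets each position draw on every other position in the feature map.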

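Hugging Face diffusers implements 1D attention by rasterizing the image, i.e. flattening the \(H \times W\) grid of pixels into a sequence of length \(HW\), and this is how we will implement it too (although, according to Jeremy Howard, this is known to be “suboptimal”). Each pixel then becomes a token with a \(C\)-dimensional embedding, where \(C\) is the number of channels.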
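A minimal sketch of the rasterization step in plain PyTorch, with illustrative shapes; this is not necessarily how diffusers or this module does it internally. Pixel \((i, j)\) ends up as token \(iW + j\), i.e. row-major order.

```python
import torch

# A batch of feature maps: batch B, channels C, height H, width W
B, C, H, W = 8, 32, 16, 16
x = torch.randn(B, C, H, W)

# Rasterize: flatten the H x W grid into a sequence of H*W tokens,
# then move channels last so each token is a C-dimensional embedding
seq = x.flatten(2).transpose(1, 2)  # (B, H*W, C) = (8, 256, 32)

# Undo the rasterization to recover a feature map of the original shape
x_back = seq.transpose(1, 2).reshape(B, C, H, W)
```

The round trip is lossless, so attention can be computed on `seq` and the result folded back into the convolutional layout.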

Johno notes that softmax tends to put almost all of its weight on a single item, which is often undesirable. Multi-headedness compensates for this by splitting the channels into several separate subspaces (heads), each with its own softmax, so that different heads can attend to different positions.
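To illustrate both points, a small sketch with made-up scores: scaling the same scores up makes softmax put nearly all of its weight on the largest one, and splitting \(C\) channels into `nh` heads is just a reshape of the rasterized sequence. In real multi-head attention the split is applied to learned query, key, and value projections; here it is applied to raw tokens only to show the shapes.

```python
import torch

scores = torch.tensor([1.0, 2.0, 3.0])

# The same scores at two scales: larger gaps between scores push softmax
# toward putting almost all of its weight on the largest one
soft = torch.softmax(scores, dim=0)          # roughly [0.09, 0.24, 0.67]
peaked = torch.softmax(scores * 10, dim=0)   # the last weight is above 0.9999

# Multi-head attention splits the C channels into nh groups of C // nh
# channels; each head then computes its own softmax over the sequence
B, N, C, nh = 2, 256, 32, 4
seq = torch.randn(B, N, C)                               # rasterized tokens
heads = seq.reshape(B, N, nh, C // nh).transpose(1, 2)   # (B, nh, N, C // nh)
```

Head `h` sees only channels `h * C // nh` to `(h + 1) * C // nh - 1`, so its softmax can peak on a different position than the other heads.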

 MultiheadSelfAttention1D (nc, nh)

Multi-head self-attention

import torch

attn = MultiheadSelfAttention1D(nc=32, nh=4)
x = torch.randn(8, 32, 16, 16)
attn(x).shape  # same shape as the input
torch.Size([8, 32, 16, 16])
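`MultiheadSelfAttention1D` is only listed by signature here. A minimal sketch of what a multi-head self-attention layer over rasterized pixels can look like, written from scratch as the hypothetical `SketchSelfAttention1D`; the pre-attention `GroupNorm`, the joint `qkv` projection, and the residual connection are design assumptions, not details taken from the module.

```python
import torch
import torch.nn as nn


class SketchSelfAttention1D(nn.Module):
    """Hypothetical multi-head self-attention over rasterized pixels (illustration only)."""

    def __init__(self, nc, nh):
        super().__init__()
        assert nc % nh == 0, "channels must divide evenly among heads"
        self.nh = nh
        self.norm = nn.GroupNorm(1, nc)    # assumption: normalize before attention
        self.qkv = nn.Linear(nc, 3 * nc)   # joint query/key/value projection
        self.proj = nn.Linear(nc, nc)      # output projection

    def forward(self, x):
        B, C, H, W = x.shape
        N, d = H * W, C // self.nh
        # Rasterize: (B, C, H, W) -> (B, N, C)
        seq = self.norm(x).flatten(2).transpose(1, 2)
        q, k, v = self.qkv(seq).chunk(3, dim=-1)
        # Split channels into heads: (B, N, C) -> (B, nh, N, d)
        q, k, v = (t.reshape(B, N, self.nh, d).transpose(1, 2) for t in (q, k, v))
        # Scaled dot-product attention: one softmax per head over all N positions
        scores = q @ k.transpose(-2, -1) / d**0.5   # (B, nh, N, N)
        out = scores.softmax(dim=-1) @ v            # (B, nh, N, d)
        # Merge heads, project, and restore the spatial layout
        out = self.proj(out.transpose(1, 2).reshape(B, N, C))
        out = out.transpose(1, 2).reshape(B, C, H, W)
        return x + out  # residual connection keeps the input's shape
```

Because of the residual connection, the output has the input's shape, matching the `torch.Size([8, 32, 16, 16])` shown for `MultiheadSelfAttention1D`.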

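Let’s use this in a diffusion UNet. Note that we cannot use attention near the beginning of the network or near its head: the feature maps there have the highest resolution, so the rasterized sequences are longest, and attention’s time and memory costs grow quadratically with sequence length. Typically, attention is only used on feature maps of about 32x32 or 16x16, and at the lower-resolution levels below them.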
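To make the quadratic cost concrete, a back-of-the-envelope calculation: it counts only the entries of one head’s attention-score matrix for one image and ignores memory-saving attention kernels. Each halving of the feature map’s side length cuts the count by a factor of 16.

```python
def attention_scores(side: int) -> int:
    """Entries in one head's attention matrix for a side x side feature map."""
    n_tokens = side * side       # length of the rasterized sequence
    return n_tokens * n_tokens   # one score per (query, key) pair


for side in (64, 32, 16, 8):
    n = attention_scores(side)
    mib = n * 4 / 2**20          # float32 scores take 4 bytes each
    print(f"{side}x{side}: {n:,} scores, {mib:g} MiB per head per image")
```

At 64x64 that is 16,777,216 scores (64 MiB in float32) per head per image, versus 1,048,576 (4 MiB) at 32x32, before multiplying by the number of heads and the batch size.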



 TAResBlock (t_embed, c_in, c_out, ks=3, stride=2, nh=None)

Res-block with attention
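The entries here list only signatures. As an illustration of how a res-block might combine a timestep embedding with self-attention, here is a hypothetical `SketchAttnResBlock`; it is not the module’s `TAResBlock`. The assumptions: `t_embed` is the size of a timestep embedding that is added per channel, `nh=None` (or 0, as in the `attention_heads` tuples) disables attention, and `ks`/`stride` are left out, so the block never changes spatial size. Attention uses PyTorch’s built-in `nn.MultiheadAttention` on the rasterized feature map.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SketchAttnResBlock(nn.Module):
    """Hypothetical time-conditioned res-block with optional self-attention (stride 1)."""

    def __init__(self, t_embed, c_in, c_out, nh=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(1, c_in)
        self.conv1 = nn.Conv2d(c_in, c_out, 3, padding=1)
        self.t_proj = nn.Linear(t_embed, c_out)  # time embedding -> per-channel shift
        self.norm2 = nn.GroupNorm(1, c_out)
        self.conv2 = nn.Conv2d(c_out, c_out, 3, padding=1)
        # 1x1 conv on the skip path only when the channel count changes
        self.skip = nn.Conv2d(c_in, c_out, 1) if c_in != c_out else nn.Identity()
        # Assumption: nh=None or nh=0 means "no attention in this block"
        self.attn = nn.MultiheadAttention(c_out, nh, batch_first=True) if nh else None
        self.attn_norm = nn.GroupNorm(1, c_out)

    def forward(self, x, t_emb):
        h = self.conv1(F.silu(self.norm1(x)))
        h = h + self.t_proj(t_emb)[:, :, None, None]  # broadcast over H and W
        h = self.conv2(F.silu(self.norm2(h)))
        h = h + self.skip(x)                          # residual connection
        if self.attn is not None:
            B, C, H, W = h.shape
            seq = self.attn_norm(h).flatten(2).transpose(1, 2)  # rasterize to (B, H*W, C)
            out, _ = self.attn(seq, seq, seq, need_weights=False)
            h = h + out.transpose(1, 2).reshape(B, C, H, W)     # attention as a second residual
        return h
```

Putting attention after the convolutions lets the block first mix local detail and then exchange information globally, while the residual paths keep the block easy to train.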



 TADownblock (t_embed, c_in, c_out, downsample=True, n_layers=1, nh=None)

Down-block of res-blocks with attention



 TAUpblock (t_embed, c_in, c_out, upsample=True, n_layers=1, nh=None)

Up-block of res-blocks with attention



 TAUnet (nfs=(224, 448, 672, 896), attention_heads=(0, 8, 8, 8),
         n_blocks=(3, 2, 2, 1, 1), color_channels=3)

U-net with attention up/down-blocks

dls = get_fashion_dls(512)
un = train(
    TAUnet(
        nfs=(32, 64, 128, 256, 384),
        n_blocks=(3, 2, 1, 1, 1, 1),
        attention_heads=(0, 8, 8, 8, 8, 8),
        color_channels=1,  # Fashion-MNIST images are grayscale
    ),
    dls,
)
x_0, _ = ddpm(un, (8, 1, 32, 32), n_steps=100)
show_images(x_0, imsize=0.8)
del un

Adding conditionality
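One common way to add class conditioning, sketched here under the assumption that the blocks already receive a timestep embedding: learn one embedding vector per class with `nn.Embedding` and add it to the timestep embedding, so every block that reads the time embedding also sees the label. `SketchConditioning` is a hypothetical name; `ConditionalTAUnet` may combine the label differently.

```python
import torch
import torch.nn as nn


class SketchConditioning(nn.Module):
    """Hypothetical: fold an integer class label into the timestep embedding."""

    def __init__(self, n_classes, t_embed):
        super().__init__()
        self.class_emb = nn.Embedding(n_classes, t_embed)  # one learned vector per class

    def forward(self, t_emb, c):
        # t_emb: (B, t_embed) timestep embedding; c: (B,) integer class labels
        return t_emb + self.class_emb(c)  # downstream blocks see time and class together
```

Adding rather than concatenating keeps the embedding size unchanged, so the rest of the U-net needs no modification.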



 ConditionalTAUnet (n_classes, nfs=(224, 448, 672, 896),
                    attention_heads=(0, 8, 8, 8), n_blocks=(3, 2, 2, 1,
                    1), color_channels=3)

Class-conditional U-net with attention up/down-blocks



 ConditionalFashionDDPM ()

Training-specific behaviors for the Learner
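A minimal sketch of the loss a conditional DDPM training step computes, assuming the standard DDPM noise-prediction objective, a linear \(\beta\) schedule over 1000 steps, and a model called as `model(x_t, t, c)`; the schedule and the call signature are assumptions, not `ConditionalFashionDDPM`’s actual code. The only change from unconditional training is that the class labels `c` are passed to the model; the MSE matches `conditional_train`’s default `loss_fn`.

```python
import torch
import torch.nn.functional as F

# Assumed linear noise schedule (the module's actual schedule is not shown)
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alpha_bar = torch.cumprod(1.0 - betas, dim=0)


def conditional_ddpm_loss(model, x0, c):
    """Noise-prediction loss for a batch of clean images x0 with labels c."""
    B = x0.shape[0]
    t = torch.randint(0, T, (B,))                  # random timestep per image
    eps = torch.randn_like(x0)                     # noise the model must recover
    ab = alpha_bar[t].view(B, 1, 1, 1)
    x_t = ab.sqrt() * x0 + (1 - ab).sqrt() * eps   # closed-form forward (noising) process
    return F.mse_loss(model(x_t, t, c), eps)       # labels are just an extra model input
```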



 conditional_train (model, dls, lr=0.004, n_epochs=25, extra_cbs=[],
                    loss_fn=<function mse_loss>)
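un = conditional_train(
    ConditionalTAUnet(
        n_classes=10,  # Fashion-MNIST has 10 classes
        nfs=(32, 64, 128, 256, 384),
        n_blocks=(3, 2, 1, 1, 1, 1),
        attention_heads=(0, 8, 8, 8, 8, 8),
        color_channels=1,  # grayscale images
    ),
    dls,
)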
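
 conditional_ddpm (model, c, sz=(16, 1, 32, 32), device='cpu',

c = torch.arange(0, 9)  # nine class labels, 0 through 8
x_0, _ = conditional_ddpm(un, c, (9, 1, 32, 32))  # one 1x32x32 sample per label
show_images(x_0, imsize=0.8);
# Sampling starts from fresh random noise, so a second call yields new images
x_0, _ = conditional_ddpm(un, c, (9, 1, 32, 32))
show_images(x_0, imsize=0.8);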
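For reference, a hedged sketch of ancestral DDPM sampling with class labels, using an assumed linear schedule over 1000 steps and the hypothetical `model(x_t, t, c)` signature; the module’s `conditional_ddpm` may use a different schedule or number of steps. Because both the starting noise and the noise added at each step are random, repeated calls with the same labels give different images.

```python
import torch

# Assumed linear noise schedule, as in standard DDPM
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)


@torch.no_grad()
def sketch_conditional_sample(model, c, sz):
    """Ancestral DDPM sampling where every denoising step sees the labels c."""
    x = torch.randn(sz)                                # start from pure noise
    for t in reversed(range(T)):
        t_batch = torch.full((sz[0],), t, dtype=torch.long)
        eps = model(x, t_batch, c)                     # predicted noise, conditioned on c
        # Mean of p(x_{t-1} | x_t): remove the scaled noise estimate, then rescale
        mean = (x - betas[t] / (1 - alpha_bar[t]).sqrt() * eps) / alphas[t].sqrt()
        if t > 0:
            x = mean + betas[t].sqrt() * torch.randn_like(x)  # variance choice sigma_t^2 = beta_t
        else:
            x = mean                                   # no noise on the final step
    return x
```

With a trained model, `sketch_conditional_sample(model, torch.arange(0, 9), (9, 1, 32, 32))` would return one image per label, mirroring the `conditional_ddpm` calls above.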
