Investigating the loss kink while training a diffusion U-Net

Why does the loss hover around 1? First, a baseline run with the default initialization:

aesthetics()
dls = get_fashion_dls(bs=512)
m = TUnet(
    color_channels=1,
    nfs=(32, 64, 128, 256, 384),
    n_blocks=(3, 2, 1, 1, 1, 1),
)
blocks = [*m.downblocks, *m.upblocks]
stats = StoreModuleStatsCB(
    sum([b.convs for b in blocks], nn.ModuleList()),
    hook_kwargs={"periodicity": 1},
)
train(
    m,
    dls,
    lr=4e-3,
    n_epochs=2,
    extra_cbs=[stats],
)
stats.mean_std_plot()
loss   epoch  train
0.121  0      train
0.074  0      eval
0.044  1      train
0.040  1      eval

CPU times: user 19min 6s, sys: 4min 29s, total: 23min 35s
Wall time: 6min 56s
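(`StoreModuleStatsCB` comes from the course library used throughout; a minimal stand-in built on plain PyTorch forward hooks might look like the sketch below. `ActStats` and the toy network are hypothetical names, not part of the original code.)

```python
import torch
import torch.nn as nn

class ActStats:
    """Record per-batch activation mean/std for a set of modules via forward hooks."""
    def __init__(self, modules):
        self.means, self.stds = [], []
        self.handles = [m.register_forward_hook(self._hook) for m in modules]

    def _hook(self, module, inp, out):
        # Detach so stats collection doesn't keep the autograd graph alive
        self.means.append(out.detach().mean().item())
        self.stds.append(out.detach().std().item())

    def remove(self):
        for h in self.handles:
            h.remove()

# Toy usage on a stand-in network (not the real TUnet)
net = nn.Sequential(
    nn.Conv2d(1, 8, 3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)
stats = ActStats([m for m in net.modules() if isinstance(m, nn.Conv2d)])
net(torch.randn(4, 1, 16, 16))
stats.remove()
```

A plot of `stats.means` and `stats.stds` over training steps is essentially what `mean_std_plot()` shows: whether activations stay well-scaled or drift toward zero or blow up.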

Now repeat the experiment with Kaiming initialization, applied via a mixin:

class KTUnet(TUnet, KaimingMixin):
    ...
m = KTUnet.kaiming(
    color_channels=1,
    nfs=(32, 64, 128, 256, 384),
    n_blocks=(3, 2, 1, 1, 1, 1),
)
blocks = [*m.downblocks, *m.upblocks]
stats = StoreModuleStatsCB(
    sum([b.convs for b in blocks], nn.ModuleList()),
    hook_kwargs={"periodicity": 1},
)
train(
    m,
    dls,
    lr=1e-3,
    n_epochs=2,
    extra_cbs=[stats],
)
stats.mean_std_plot()
loss   epoch  train
0.866  0      train
0.754  0      eval
0.128  1      train
0.097  1      eval

CPU times: user 17min 40s, sys: 7.59 s, total: 17min 48s
Wall time: 4min 10s
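(The body of `KaimingMixin` isn't shown above; a plausible sketch, assuming it simply builds the model and then re-initializes weight tensors with Kaiming normal init, is below. The real implementation may differ, and `TinyNet` is just a stand-in for the U-Net.)

```python
import torch.nn as nn

class KaimingMixin:
    """Hypothetical sketch: construct the model normally, then overwrite
    conv/linear weights with Kaiming (He) normal initialization."""
    @classmethod
    def kaiming(cls, *args, **kwargs):
        model = cls(*args, **kwargs)
        for m in model.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)  # std scaled by fan-in
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        return model

# Toy usage on a stand-in model (not the real TUnet)
class TinyNet(nn.Sequential, KaimingMixin):
    def __init__(self):
        super().__init__(nn.Conv2d(1, 4, 3), nn.ReLU(), nn.Conv2d(4, 4, 3))

m = TinyNet.kaiming()
```

Note `kaiming_normal_` defaults to `fan_in` mode with a leaky-ReLU gain, which is tuned for plain ReLU stacks; a U-Net with skip connections and attention may need a different gain, which is one reason the init can misbehave here.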

This experiment demonstrates that the problem lies with Kaiming initialization: the Kaiming-initialized model starts with a loss near 1 (0.866 after the first epoch, versus 0.121 with the default init) and only escapes it during the second epoch. I analyze this further in the fast.ai forums.
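One plausible reading of why the plateau sits at 1 specifically (my interpretation, not stated in the original): in noise-prediction diffusion training the targets are unit-variance Gaussian noise, so a model whose outputs are stuck near zero scores an MSE of about E[ε²] = 1. A quick check:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = torch.randn(100_000)           # unit-variance noise targets
zero_pred = torch.zeros_like(eps)    # a model whose output collapses to ~0
loss = F.mse_loss(zero_pred, eps)    # ≈ E[eps**2] = Var(eps) = 1
```

So a loss hovering around 1 suggests the badly-initialized network spends the first epoch producing near-zero (or otherwise uninformative) outputs before learning to predict the noise.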