pipe = [
    transforms.Resize((32, 32)),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Lambda(lambda x: x - 0.5),
]

Augmentation with fixed normality
Here we train a classifier on the same dataset as the DDPM models, using the same normalization routine.
The dataloader code below is copied from 13_ddpm.
def get_dls(bs=128):
    return tensorize_images(
        DataLoaders.from_hf("fashion_mnist", bs=bs),
        pipe=pipe,
        normalize=False,
    ).listify()

dls = get_dls()

xb, _ = dls.peek()
xb.min(), xb.max()

(tensor(-0.5000), tensor(0.5000))
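The range check above follows from the last two steps of the pipeline: converting a `uint8` image to `torch.float` rescales pixel values from 0..255 to [0, 1], and the `Lambda` then subtracts 0.5. A minimal sketch of that per-pixel arithmetic in plain Python is below. The helper names `normalize_pixel` and `denormalize_pixel` are hypothetical and used only for illustration; they are not part of the notebook or of torchvision.

```python
def normalize_pixel(p: int) -> float:
    """Map a uint8 pixel value (0..255) to the range [-0.5, 0.5]."""
    if not 0 <= p <= 255:
        raise ValueError(f"expected a uint8 value, got {p}")
    scaled = p / 255      # same rescaling as a uint8 -> float dtype conversion
    return scaled - 0.5   # same shift as the Lambda step in the pipeline


def denormalize_pixel(x: float) -> int:
    """Invert normalize_pixel, e.g. to display a normalized image."""
    return round((x + 0.5) * 255)


# The darkest and brightest pixels land on the ends of the range.
lo, hi = normalize_pixel(0), normalize_pixel(255)
print(lo, hi)  # -0.5 0.5
```

Any input that has been through this mapping can be shifted back with `x + 0.5` before plotting, which is what `denormalize_pixel` does for a single value.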
mz = ResNetWithGlobalPoolingInitialConv.kaiming(nfs=[32, 64, 128, 256, 512, 512])
summarize(mz, [*mz.layers, mz.lin, mz.norm])
train(mz, dls=dls, n_epochs=10)

| Type | Input | Output | N. params | MFlops |
|---|---|---|---|---|
| Conv | (8, 1, 28, 28) | (8, 32, 28, 28) | 864 | 0.6 |
| ResidualConvBlock | (8, 32, 28, 28) | (8, 64, 14, 14) | 57,664 | 11.2 |
| ResidualConvBlock | (8, 64, 14, 14) | (8, 128, 7, 7) | 230,016 | 11.2 |
| ResidualConvBlock | (8, 128, 7, 7) | (8, 256, 4, 4) | 918,784 | 14.7 |
| ResidualConvBlock | (8, 256, 4, 4) | (8, 512, 2, 2) | 3,672,576 | 14.7 |
| ResidualConvBlock | (8, 512, 2, 2) | (8, 512, 1, 1) | 4,983,296 | 5.0 |
| Linear | (8, 512) | (8, 10) | 5,130 | 0.0 |
| BatchNorm1d | (8, 10) | (8, 10) | 20 | 0.0 |
| Total | | | 9,868,350 | |

| MulticlassAccuracy | loss | epoch | train |
|---|---|---|---|
| 0.866 | 0.557 | 0 | train |
| 0.859 | 0.429 | 0 | eval |
| 0.891 | 0.352 | 1 | train |
| 0.872 | 0.352 | 1 | eval |
| 0.903 | 0.282 | 2 | train |
| 0.887 | 0.313 | 2 | eval |
| 0.917 | 0.238 | 3 | train |
| 0.894 | 0.303 | 3 | eval |
| 0.928 | 0.204 | 4 | train |
| 0.891 | 0.321 | 4 | eval |
| 0.939 | 0.173 | 5 | train |
| 0.911 | 0.263 | 5 | eval |
| 0.950 | 0.138 | 6 | train |
| 0.917 | 0.244 | 6 | eval |
| 0.966 | 0.095 | 7 | train |
| 0.928 | 0.239 | 7 | eval |
| 0.982 | 0.054 | 8 | train |
| 0.932 | 0.228 | 8 | eval |
| 0.994 | 0.026 | 9 | train |
| 0.934 | 0.231 | 9 | eval |

# Save this for later
torch.save(mz, "../models/fashion_mnist_classifier.pt")
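
Passing the module itself to `torch.save` pickles the whole object rather than just its weights, so loading it later requires the model's class definition to be importable. Newer PyTorch releases also default `torch.load` to `weights_only=True`, which rejects a pickled module unless `weights_only=False` is passed. The sketch below shows the round trip on a toy `nn.Sequential` written to a temporary file; the toy model and file name are stand-ins for illustration, not the classifier trained above.

```python
import os
import tempfile

import torch
from torch import nn

# Toy stand-in for the trained classifier (hypothetical, for illustration only).
toy = nn.Sequential(nn.Linear(4, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_classifier.pt")

    # Pickle the whole module, as the notebook does with `mz`.
    torch.save(toy, path)

    # weights_only=False allows unpickling arbitrary objects,
    # so only use it on files you created or otherwise trust.
    restored = torch.load(path, weights_only=False)

restored.eval()  # switch to inference mode before using it for evaluation

# Check that every parameter survived the round trip unchanged.
same = all(
    torch.equal(a, b)
    for a, b in zip(toy.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Saving `mz.state_dict()` instead would avoid the pickling dependency on the class, at the cost of having to rebuild the model before loading the weights.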