pipe = [
    transforms.Resize((32, 32)),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Lambda(lambda x: x - 0.5),
]
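A quick way to see what this pipeline produces (a sketch only; the dummy image and the explicit imports are assumptions, not part of the original notebook): compose `pipe` and push one Fashion-MNIST-sized image through it.

import torch
from PIL import Image
from torchvision import transforms

tfm = transforms.Compose(pipe)
img = Image.new("L", (28, 28), color=128)   # dummy grayscale image, Fashion-MNIST sized
x = tfm(img)
x.shape, x.dtype                            # (torch.Size([1, 32, 32]), torch.float32)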
Augmentation with fixed normalization
Training a classifier that uses the same dataset as the DDPM models, with the same normalization routine. This dataloader code is copied from 13_ddpm.
def get_dls(bs=128):
    return tensorize_images(
        DataLoaders.from_hf("fashion_mnist", bs=bs),
        pipe=pipe,
        normalize=False,
    ).listify()
dls = get_dls()
xb, _ = dls.peek()
xb.min(), xb.max()
(tensor(-0.5000), tensor(0.5000))
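`DataLoaders.from_hf` and `tensorize_images` are helpers carried over from the earlier notebooks. As a rough sanity check that does not depend on them, the same pipeline and range check can be reproduced with plain torchvision (a sketch only; the `../data` download path is an assumption):

import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

tfm = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
    transforms.Lambda(lambda x: x - 0.5),
])
# Same Fashion-MNIST data, same fixed shift by 0.5, so pixel values land in [-0.5, 0.5]
ds = datasets.FashionMNIST("../data", download=True, transform=tfm)
xb, _ = next(iter(DataLoader(ds, batch_size=128)))
xb.min(), xb.max()  # expect roughly (tensor(-0.5000), tensor(0.5000))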
mz = ResNetWithGlobalPoolingInitialConv.kaiming(nfs=[32, 64, 128, 256, 512, 512])
summarize(mz, [*mz.layers, mz.lin, mz.norm])
Type | Input | Output | N. params | MFlops |
---|---|---|---|---|
Conv | (8, 1, 28, 28) | (8, 32, 28, 28) | 864 | 0.6 |
ResidualConvBlock | (8, 32, 28, 28) | (8, 64, 14, 14) | 57,664 | 11.2 |
ResidualConvBlock | (8, 64, 14, 14) | (8, 128, 7, 7) | 230,016 | 11.2 |
ResidualConvBlock | (8, 128, 7, 7) | (8, 256, 4, 4) | 918,784 | 14.7 |
ResidualConvBlock | (8, 256, 4, 4) | (8, 512, 2, 2) | 3,672,576 | 14.7 |
ResidualConvBlock | (8, 512, 2, 2) | (8, 512, 1, 1) | 4,983,296 | 5.0 |
Linear | (8, 512) | (8, 10) | 5,130 | 0.0 |
BatchNorm1d | (8, 10) | (8, 10) | 20 | 0.0 |
Total | | | 9,868,350 | |
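The exact `ResidualConvBlock` definition is not shown in this notebook, but each row above halves the spatial resolution while (usually) doubling the channel count. A generic sketch of a stride-2 residual block with a strided 1x1 convolution on the identity path reproduces the same shape behaviour, although its parameter counts will not match the summary exactly:

import torch
from torch import nn

class ResBlockSketch(nn.Module):
    # Illustrative only: two 3x3 convs on the main path, a strided 1x1 conv on the shortcut
    def __init__(self, ni, nf, stride=2):
        super().__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(ni, nf, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(nf), nn.ReLU(),
            nn.Conv2d(nf, nf, 3, padding=1, bias=False),
            nn.BatchNorm2d(nf))
        self.idconv = nn.Sequential(
            nn.Conv2d(ni, nf, 1, stride=stride, bias=False), nn.BatchNorm2d(nf))
        self.act = nn.ReLU()
    def forward(self, x):
        return self.act(self.convs(x) + self.idconv(x))

ResBlockSketch(32, 64)(torch.randn(8, 32, 28, 28)).shape  # torch.Size([8, 64, 14, 14])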
train(mz, dls=dls, n_epochs=10)
MulticlassAccuracy | loss | epoch | train |
---|---|---|---|
0.866 | 0.557 | 0 | train |
0.859 | 0.429 | 0 | eval |
0.891 | 0.352 | 1 | train |
0.872 | 0.352 | 1 | eval |
0.903 | 0.282 | 2 | train |
0.887 | 0.313 | 2 | eval |
0.917 | 0.238 | 3 | train |
0.894 | 0.303 | 3 | eval |
0.928 | 0.204 | 4 | train |
0.891 | 0.321 | 4 | eval |
0.939 | 0.173 | 5 | train |
0.911 | 0.263 | 5 | eval |
0.950 | 0.138 | 6 | train |
0.917 | 0.244 | 6 | eval |
0.966 | 0.095 | 7 | train |
0.928 | 0.239 | 7 | eval |
0.982 | 0.054 | 8 | train |
0.932 | 0.228 | 8 | eval |
0.994 | 0.026 | 9 | train |
0.934 | 0.231 | 9 | eval |
# Save this for later
torch.save(mz, "../models/fashion_mnist_classifier.pt")