import random
from functools import partial
from pdb import set_trace as bp
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from einops import rearrange
from IPython.display import display
from torch import nn, tensor
from tqdm import trange
from slowai.learner import (
    DataLoaders,
    Learner,
    MetricsCB,
    ProgressCB,
    TrainCB,
    TrainLearner,
    def_device,
)
from slowai.style_transfer import GramLoss, pt_normalize_imagenet
from slowai.utils import download_image, show_image, show_images
Neural Cellular Automata
Complex behavior can emerge from simple rules. Conway’s game of life is a famous example.
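As a quick aside (a minimal sketch, not part of this lesson's code), one Game of Life step can itself be written as a convolution that counts each cell's live neighbors, foreshadowing the convolutional update rule we build below:
def life_step(board):
    """One Game of Life update; `board` is a float tensor of 0s and 1s, shape (h, w)."""
    kernel = torch.ones(1, 1, 3, 3)
    kernel[0, 0, 1, 1] = 0.0  # count the eight neighbors, not the cell itself
    neighbors = F.conv2d(board[None, None], kernel, padding=1)[0, 0]
    # Birth on exactly 3 neighbors; survival on 2 or 3
    return ((neighbors == 3) | ((neighbors == 2) & (board == 1.0))).float()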
This lesson was inspired by this distill.pub article that demonstrates a self-organizing, self-repairing system. How do we train something like this?
It starts with a neural network that takes each cell's state, along with that of its neighbors, and predicts an update that evolves the grid toward a particular image. Unfortunately, that alone doesn't produce a stable output. We also need to train with random initializations so the system learns to correct errors and maintain its shape.
We’ll train a simple texture restoration model, which allows us to leverage the Gram loss from the previous module.
Defining the model
I = tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]])  # identity
G = tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])  # horizontal gradient
S = tensor([[1.0, 2.0, 1.0], [2.0, -12.0, 2.0], [1.0, 2.0, 1.0]])  # second derivative
filters = torch.stack([I, G, G.T, S]).to(def_device)
channels, _, _ = filters.shape  # number of filters; also used as the grid's channel count
show_images([I, G, G.T, S])
The second and third filters are Sobel operators (horizontal and vertical gradient detectors); the last is a Laplacian.
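To see what the Sobel pair does (a small illustrative check using the tensors above; the image values are made up), a vertical step edge excites G but not its transpose:
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0  # right half bright: a vertical edge
gx = F.conv2d(img, G.view(1, 1, 3, 3))
gy = F.conv2d(img, G.T.reshape(1, 1, 3, 3))
print(gx.abs().sum().item(), gy.abs().sum().item())  # nonzero vs. exactly zero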
def make_grid(n, sz=128):
    return torch.zeros(n, channels, sz, sz).to(def_device)


def apply_filters(x):
    b, c, h, w = x.shape
    y = rearrange(x, "b c h w -> (b c) h w").unsqueeze(1)
    # Circular padding wraps the edges, making the world a torus
    y = F.pad(y, (1, 1, 1, 1), "circular")
    y = F.conv2d(y, filters.unsqueeze(1))
    return y.reshape(b, -1, h, w)
grid = make_grid(1)
grid.shape
torch.Size([1, 4, 128, 128])
Our “world” is \(128 \times 128\) and each position carries four data points or channels.
x = apply_filters(grid)
x.shape
torch.Size([1, 16, 128, 128])
This gives us 16 model inputs per pixel: the four filters applied to each of the four channels.
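A one-line check of that bookkeeping (using the names defined above):
assert x.shape[1] == grid.shape[1] * filters.shape[0]  # 4 channels × 4 filters = 16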
class LinearBrain(nn.Module):
    def __init__(self, grid, nh=8, nc=4, nf=4):
        """NCA update model

        Args:
            grid: grid, needed for shape
            nh: number of hidden dimensions
            nc: number of input channels
            nf: number of filters
        """
        super().__init__()
        layers = [
            # Bias must be true here to break the symmetry of a newly
            # initialized zero-filled grid
            nn.Linear(nc * nf, nh, bias=True),
            nn.ReLU(),
            # The bias is false here because updates should be centered
            # around 0; and, we also want to keep the number of parameters
            # to a minimum
            nn.Linear(nh, nc, bias=False),
        ]
        self.layers = nn.ModuleList(layers)
        self.grid = grid

    def forward(self, x):
        x = rearrange(x, "b c h w -> (b h w) c")
        for layer in self.layers:
            x = layer(x)
        b, c, h, w = self.grid.shape
        # Rearrange back explicitly; a plain reshape would scramble the channel axis
        return rearrange(x, "(b h w) c -> b c h w", b=b, h=h, w=w)
m = LinearBrain(grid)
m.to(def_device)
m.forward(x).shape
torch.Size([1, 4, 128, 128])
An alternative to all this reshaping is to use a convolution with a kernel size of 1.
class Brain(nn.Module):
    def __init__(self, nh=8, nc=4, nf=4):
        """NCA update model

        Args:
            nh: number of hidden dimensions
            nc: number of input channels
            nf: number of filters
        """
        super().__init__()
        layers = [
            nn.Conv2d(nc * nf, nh, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(nh, nc, kernel_size=1, bias=False),
        ]
        self.layers = nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
m = Brain()
m.to(def_device)
m.forward(x).shape
torch.Size([1, 4, 128, 128])
This is quite elegant! It's also highly performant on GPUs, which are designed to apply the same operation to every pixel in parallel.
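A 1×1 convolution is exactly a linear layer applied at every pixel, so the two formulations have identical capacity. A quick check (a sketch, using the classes above):
linear_params = sum(p.numel() for p in LinearBrain(grid).parameters())
conv_params = sum(p.numel() for p in Brain().parameters())
assert linear_params == conv_params  # 16*8 + 8 + 8*4 = 168 parameters either way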
Consolidating the model
Let’s put this all into a class
class NCA(Brain):
    @torch.no_grad()
    def init_(self):
        w2 = self.layers[-1]
        w2.weight.data.zero_()

    def forward(self, grid, update_rate=0.5):
        y = apply_filters(grid)
        for layer in self.layers:
            y = layer(y)
        # Randomly drop out some updates to reflect the non-global
        # update behavior of biological systems (note that F.dropout
        # also rescales the surviving updates by 1 / (1 - p))
        y = F.dropout(y, update_rate)
        return grid + y
m = NCA()
m.to(def_device)
x = m.forward(grid)
x.shape, grid.shape
(torch.Size([1, 4, 128, 128]), torch.Size([1, 4, 128, 128]))
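Why zero the last layer in init_? With zero final weights and no bias, the predicted update is identically zero, so an untrained NCA is the identity map; training starts from "do nothing" rather than from noisy random updates. A quick check (using the objects above):
m.init_()
assert torch.equal(m(grid), grid)  # zero update: the grid passes through unchanged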
Training
= "https://sanctuarymentalhealth.org/wp-content/uploads/2021/03/The-Starry-Night-1200x630-1-979x514.jpg"
starry_night = download_image(starry_night)
target = pt_normalize_imagenet(target)
target ; show_image(target)
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
def to_rgb(x):
    # The first three channels are displayed as RGB, shifted from roughly
    # [-0.5, 0.5] into [0, 1]; any remaining channels are hidden state
    return x[:, :3, :, :] + 0.5
class StyleLoss:
    def __init__(
        self,
        target_img,
        target_layers=(1, 6, 11, 18, 25),
        vgg=None,
    ):
        if vgg is None:
            self.vgg = timm.create_model("vgg16", pretrained=True).to(def_device)
        else:
            self.vgg = vgg
        for p in self.vgg.parameters():
            # No need to train VGG
            p.requires_grad = False
        self.target_layers = target_layers
        with torch.no_grad():
            self.tgt = self.grams(target_img.to(def_device))

    def grams(self, x):
        x = pt_normalize_imagenet(x)
        if x.ndim < 4:
            x = x.unsqueeze(0)
        grams_ = []
        for i, layer in enumerate(self.vgg.features[: max(self.target_layers) + 1]):
            b, c, h, w = x.shape
            x = layer(x)
            if i in self.target_layers:
                f = x.clone()  # Not sure if I need this
                g = torch.einsum("bchw, bdhw -> bcd", f, f) / (h * w)
                grams_.append(g)
        return grams_

    def __call__(self, img):
        src = self.grams(img)
        # Writing MSE out manually here helps by broadcasting the style gram
        # matrices to each of the sample image gram matrices
        return sum((f1 - f2).pow(2).mean() for f1, f2 in zip(src, self.tgt))


loss_f = StyleLoss(torch.randn((3, 64, 64)).to(def_device))
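For reference, the einsum computes the channel-correlation (Gram) matrix \(G_{cd} = \frac{1}{hw}\sum_{i,j} f_{cij}\, f_{dij}\) at each target layer. It measures which feature channels co-activate, capturing texture while discarding spatial layout.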
def train(
    style_img,
    style_loss_scale=0.1,
    n=128,
    sz=256,
    bs=4,
    lr=1e-3,
    train_iterations=1200,
    model_application_iterations=(32, 96),
):
    nca = NCA()
    nca.to(def_device)
    nca.init_()
    pool = make_grid(n, sz=sz).to(def_device)
    loss_f = StyleLoss(style_img)
    opt = torch.optim.Adam(nca.parameters(), lr)

    ipy_output = None
    K = 3.5
    fig, (a0, a1, a2) = plt.subplots(1, 3, figsize=(K * 3, K))
    losses = []

    pbar = trange(train_iterations)
    for i in pbar:
        # Subsample with replacement
        subpool_idxs = torch.randint(0, n, (bs,))
        subpool = pool[subpool_idxs]

        # Randomly zero out samples
        if random.random() > 0.8:
            subpool[:1] = make_grid(1, sz=sz).to(def_device)

        # Apply the model
        min_, max_ = model_application_iterations
        n_iterations = random.randrange(min_, max_ + 1)
        for _ in range(n_iterations):
            subpool = nca(subpool)

        if i > 0:
            assert not (subpool == 0).all()

        # Update the pool
        with torch.no_grad():
            pool[subpool_idxs] = subpool

        # Compute loss
        style_loss = loss_f(to_rgb(subpool)) * style_loss_scale
        overflow_loss = (subpool - subpool.clamp(-1.0, 1.0)).abs().sum()
        loss = style_loss + overflow_loss

        losses.append((loss.item(), style_loss.item(), overflow_loss.item()))
        pbar.set_description(f"{style_loss.item():.2f} {overflow_loss.item():.2f}")
        if i % 100 == 0 and i > 0:
            x = range(0, i + 1)
            combined, style_losses, overflow_losses = zip(*losses)
            for ax, y, label in [
                (a0, style_losses, "style"),
                (a1, overflow_losses, "overflow"),
                (a2, combined, "overall"),
            ]:
                ax.clear()
                ax.scatter(x, y, label=label)
                ax.set_yscale("log")
                ax.legend()
            fig.tight_layout()
            if ipy_output is None:
                ipy_output = display(fig, display_id=True)
            else:
                ipy_output.update(fig)

        # Backprop with gradient normalization
        loss.backward()
        for p in nca.parameters():
            p.grad /= p.grad.norm() + 1e-8
        opt.step()
        opt.zero_grad()
    return nca, pool.detach()
model, pool = train(target)
1258.36 34410.16: 9%|███████████▎ | 108/1200 [00:10<01:28, 12.38it/s]
750.20 28.38: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1200/1200 [01:56<00:00, 10.28it/s]
show_images(to_rgb(pool).clip(0, 1)[:8, ...])
Starting from an empty grid
images = []
x = make_grid(n=1)
for i in range(90):
    x = model(x)
    if i % 10 == 0:
        imgs = to_rgb(x).clip(0, 1).squeeze()
        images.append(imgs)
show_images(images)
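To probe the self-repair behavior mentioned in the introduction (a hypothetical follow-up experiment, not part of the original run), we can damage the grid mid-rollout, keep stepping, and watch whether the texture grows back:
with torch.no_grad():
    x = make_grid(n=1)
    for i in range(200):
        x = model(x)
        if i == 100:
            x[..., 32:96, 32:96] = 0.0  # knock out a square patch mid-rollout
show_image(to_rgb(x).clip(0, 1).squeeze())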