Training a Simple Transformer Neural Net on Conway's Game of Life
We create a simplified transformer neural network, and train it compute Conway’s Game of Life from examples of the game.
This exercise presents pretty much the simplest form of a transformer, and Life is an easy source of data for it. (Running the Game of Life on transformers is obviously not of practical use.)
We call the model SingleAttentionNet, because it uses just a single attention block, with singlehead attention.
Before we get into the details, here’s a Life game, computed by a SingleAttentionNet model.
And the following plot shows examples of the SingleAttentionNet model’s attention matrix, over the course of training:
The pattern that emerges is the model learning to attend to just the 8 neighbours of each cell. The attention of the model becomes nearequivalent to a 3x3 average pool, as is used in convolutional neural networks, except unlike an average pool, it excludes the middle cell from the average. It is vastly more efficient to directly use an average pool, rather than an attention layer, but it’s interesting to show that the attention layer can learn to approximate it. (We found that average pooling does also work, even with the middle cell included.)
(For a recap of the rules of Life, check out the appendix.)
Problem formulation and training loop
We formulate the problem as:
next_life_grid = model(life_grid)
Which means the model will take a life_grid
as input,
and the output will be the state of the grid in the next step,
next_life_grid
.
If we usually run Life games using the function, life_step
, e.g.
for _ in range(num_steps):
life_grid = life_step(life_grid)
We could replace that function with our model:
for _ in range(num_steps):
life_grid = model(life_grid)
In order to train our model,
we show it many examples of
(life_grid, next_life_grid)
pairs.
We can generate a practically limitless amount of these,
by randomly initialising grids and running the Game of Life on them.
The following plot shows some examples,
where each row represents a pair.
And here is a simplified version of the training loop we use:
for life_grid, next_life_grid in life_data_generator():
predicted_next_life_grid = model(life_grid)
loss = loss_fn(predicted_next_life_grid, next_life_grid)
run_gradient_descent_step(model, loss)
The model
Our model uses embeddings to represent a Life grid as a set of tokens, with one token per grid cell. These tokens then go through singlehead attention, a hidden layer, and a classifier head, which classifies each token/grid cell as dead or alive in the next step.
This section presents code for the model, then a diagram of the model, and finally a more detailed description of the model.
Model code
class SingleAttentionNet(torch.nn.Module):
def __init__(self, grid_dim: int, ndim: int):
super().__init__()
self.num_grid_cells = grid_dim * grid_dim
self.sqrt_ndim = math.sqrt(ndim)
self.W_state = weight_matrix(h=2, w=ndim, type="embedding")
self.W_positional = weight_matrix(h=self.num_grid_cells, w=ndim, type="embedding")
self.W_q = weight_matrix(h=ndim, w=ndim, type="weight")
self.W_k = weight_matrix(h=ndim, w=ndim, type="weight")
self.W_v = weight_matrix(h=ndim, w=ndim, type="weight")
self.W_h = weight_matrix(h=ndim, w=ndim, type="weight")
self.W_c = weight_matrix(h=ndim, w=1, type="weight")
def forward(self, life_grids: torch.Tensor) > dict[str, torch.Tensor]:
# The input is a batch of grids,
# life_grids.shape = [b, grid_dim, grid_dim],
# Flatten the grids
x = life_grids.reshape(1, self.num_grid_cells) # [b, num_grid_cells]
# Use the embeddings to represent the grids as tokens
x = self.W_state[x] + self.W_positional # [b, num_grid_cells, ndim]
# Singlehead attention
q = x @ self.W_q # [b, num_grid_cells, ndim]
k = x @ self.W_k # [b, num_grid_cells, ndim]
attention_matrix = torch.softmax(
q @ k.transpose(1, 2) / self.sqrt_ndim, dim=1
) # [b, num_grid_cells, num_grid_cells]
v = x @ self.W_v # [b, num_grid_cells, ndim]
x = x + attention_matrix @ v # skip connection, [b, num_grid_cells, ndim]
# # Hidden layer
x = x + torch.nn.functional.silu(x @ self.W_h) # [b, num_grid_cells, ndim]
# # Classifier head
x = x @ self.W_c # [b, num_grid_cells, 1]
return x, attention_matrix
Model diagram
The model in the diagram processes 2by2 Life grids, which means 4 tokens in total per grid. Blue text indicates parameters that are learned via gradient descent. The arrays are labelled with their shape, (with the batch dimension omitted).
Detailed model description
Input tokens
The model represents each grid cell of the Game of Life as a token, (a vector of size ndim
).
A given model instance will be trained on a fixed size Life grid.
The model will construct its input tokens by adding positional embeddings to cell state embeddings.
There will be one positional embedding for each grid cell.
There will be two cell state embeddings — one to represent that a cell is alive, and one to represent that a cell is dead.
The embeddings will be randomly initialised, and then learned through gradient descent.
Singlehead attention
From the paper Attention Is All You Need, single head attention will compute a weighted sum of a projection of the input tokens, for each token, as determined by a square attention matrix that is computed for each before_grid
.
If we call the input tokens, $T$, then we can write singlehead attention as, $x = softmax( \frac{ (T W_q) (T W_k )^T}{√d_k} ) (T W_v)$, where $W_q$, $W_k$ and $W_v$ are weight matrices that will be randomly initialised and then learned through gradient descent. The formula is more commonly written as $x = softmax( \frac{QK^T}{√d_k} )V$, where $Q = T W_q$, $K = T W_k$ and $V = T W_v$. We refer to $Q$, $K$ and $V$ as linear projections of the input tokens. We can break the formula up into an attention matrix, $A = softmax( \frac{QK^T}{√d_k})$, and the attention output, $x = AV$. The attention output will be a set of tokens, of the same shape as the input tokens, where each output token is a weighted sum of a linear projection of the input tokens, (i.e. a weighted sum of all the rows in $V$). The factor, $\frac{1}{√d_k}$, is a constant for a given model instance, (called ndim
in the code below). The softmax is applied to each row of the attention matrix, making the values positive and each row sum to 1, it also tends to make the larger values in each row larger relative to the smaller values in the row.
Hidden layer
This is a single neural network layer, that operates on each token individually. It uses a weight matrix that will be randomly initialised and learned through gradient descent. The hidden layer and the singlehead attention layer that precedes it can together be referred to as an “attention block”; where transformers typically have multiple attention blocks in series.
Classifier layer
This will take each of the tokens above, (recall there’s one for each grid cell), and decide whether it should be dead or alive in the next Life step. It uses a weight matrix that will be randomly initialised and learned through gradient descent.
Training
On a GPU, training the model takes anywhere from a couple of minutes, to 10 minutes, or a seemingly indefinite amount of time, depending on the seed, and other training hyperparameters. The largest grid size we successfully trained on was 16x16.
Notes
If we replaced the attention layer of the model with a manually computed Neighbour Attention matrix, the model learned its task far quicker, and generalised to arbitrary grid sizes. We found that the same was true for replacing the layer with a 3by3 average pool.
We checked that the model worked by looking for 1024
batches with 100% accuracy,
and then testing the model on 100 Life games for 100 steps each.
We found that training it on just the first step after randomly initialising a grid wasn’t enough for it to pass the 100 Life games for 100 steps test, and so randomly introduced pairs with an extra step taken.
The code is here, on GitHub.
Appendix
The rules of Life
Life takes place on a 2D grid with cells that are either dead or alive, (represented by 0 or 1). A cell has 8 neighbours, which are the cells immediately next to it on the grid.
To progress to the next Life step, the following rules are used:
 If a cell has 3 neighbours, it will be alive in the next step, regardless of it’s current state, (alive or dead).
 If a cell is alive and has 2 neighbours, it will stay alive in the next step.
 Otherwise, a cell will be dead in the next step.
These rules are shown in the following plot.
References:

Springer et al  2020  It’s Hard For Neural Networks to Learn the Game of Life  https://arxiv.org/pdf/2009.01398.pdf

McGuigan  2021  Its Easy for Neural Networks To Learn Game of Life  https://www.kaggle.com/code/jamesmcguigan/itseasyforneuralnetworkstolearngameoflife

Vaswani et al  2017  Attention Is All You Need  https://arxiv.org/abs/1706.03762

Conway’s Game of Life  https://en.wikipedia.org/wiki/Conway%27s_Game_of_Life
Citation:
@misc{radcliffe_life_transformer_2024,
title={Training a Simple Transformer Neural Net on Conway's Game of Life},
url={https://sidsite.com/posts/lifetransformer/},
howpublished={Main page: \url{https://sidsite.com/posts/lifetransformer/}, GitHub repository: \url{https://github.com/sradc/trainingasimpletransformeronconwaysgameoflife}},
author={Radclffe, Sidney},
year={2024},
month={July}
}