A Reverse-Engineer Approach to Explain Attention and Memory

In this post, I have collected many resources to trace the history of attention, with code examples and experiments, and I attempt to explain why attention is so powerful.

Stefano Bosisio
19 min read · Dec 8, 2023
Image by Silvio Kundt on Unsplash


All opinions and content expressed in this work are solely my own and do not reflect the views or opinions of my employer.

In our daily lives, attention serves as the compass guiding our focus through the vast landscape of information, enabling us to selectively engage with specific stimuli while filtering out extraneous noise and chaos. It is the foundation upon which our ability to comprehend, learn, and make decisions is built.

However, attention does not exist in isolation; it intertwines intimately with the concept of memory. Memory acts as the custodian of our past experiences, shaping our perceptions and influencing the direction of our attention. Together, attention and memory form a dynamic duo, orchestrating the symphony of cognition that defines human thought.

In the contemporary landscape of artificial intelligence, the introduction of the Transformer architecture has redefined how we think about attention. Transformers show how attention mechanisms can be harnessed to process information in parallel, enabling unprecedented advances in natural language processing, image recognition, and more.

Have you ever thought that attention is in Convolutional Neural Networks too?

Let’s start with a simple model, a convolutional neural network (CNN), trained on the CIFAR10 dataset (Learning Multiple Layers of Features from Tiny Images, MIT License). You can run this code on your laptop and experiment with a larger number of epochs too.

```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms


# Define the CNN architecture
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        x = self.flatten(x)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


# Set random seed for reproducibility
torch.manual_seed(42)

# Define transformations and load the CIFAR dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Initialize the model, loss function, and optimizer
model = SimpleCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training the model
num_epochs = 5

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Forward pass
        outputs = model(images)

        # Calculate the loss
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')

# Testing the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print(f'Accuracy on the test set: {100 * correct / total:.2f}%')
```

This network already has a form of attention, often called implicit attention: it responds more strongly to some parts of the input than to others. This can be defined mathematically through the Jacobian, which gives the sensitivity of the network outputs with respect to its inputs:

$$J_{ij} = \frac{\partial y_i}{\partial x_j}$$

Eq. 1: Jacobian for a neural network. J is the Jacobian for elements i-j, y is the output from the neural network, x the input.

Eq. 1 shows the Jacobian matrix, where the element Jᵢⱼ is the derivative of the output yᵢ with respect to the input xⱼ, obtained in practice by backpropagating yᵢ to the input. We can implement the Jacobian calculation in Python as follows:

```python
import matplotlib.pyplot as plt


# Function to compute the Jacobian matrix
def compute_jacobian(model, input_data):
    """ Function to compute the Jacobian given a model

    Args:
        model: (torch.nn.Module) trained model
        input_data: (torch.Tensor) a single input image of shape (1, C, H, W)

    Return:
        jacobian_matrix: (torch.Tensor) Jacobian of shape (num_classes, 1, C, H, W)
    """
    input_data = input_data.clone().requires_grad_(True)
    model.eval()
    output = model(input_data)
    num_classes = output.size(1)

    jacobian_matrix = torch.zeros(num_classes, *input_data.size())

    for i in range(num_classes):
        model.zero_grad()
        input_data.grad = None  # reset, otherwise the input gradients accumulate across classes
        output[0, i].backward(retain_graph=True)
        jacobian_matrix[i] = input_data.grad.detach()

    return jacobian_matrix


def denormalize(image, mean, std):
    """ Function to denormalize an image for better visualization

    Args:
        image: (torch.Tensor) input image of shape (C, H, W)
        mean: (sequence of float) per-channel mean used in transforms.Normalize
        std: (sequence of float) per-channel std dev used in transforms.Normalize

    Return:
        (np.array) output image of shape (H, W, C)
    """
    image = image.permute(1, 2, 0)  # Channels last
    image = image.detach() * torch.tensor(std) + torch.tensor(mean)  # Detach and denormalize
    image = image.clip(0, 1)  # Clip to ensure valid image
    return image.numpy()  # Convert to NumPy array


# Choose a car image (class 1, "automobile") from the test dataset
for images, labels in test_dataset:
    if labels == 1:
        sample_image, sample_label = images, labels
        break
sample_image = sample_image.unsqueeze(0)

# Compute the Jacobian matrix for the sample image
jacobian_matrix = compute_jacobian(model, sample_image)
sensitivity_map = torch.sum(torch.abs(jacobian_matrix), dim=0)  # sum over classes: (1, C, H, W)

# These must match the values used in transforms.Normalize during training
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
denormalized_image = denormalize(sample_image.squeeze(0), mean, std)

# Plot results
fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharex=True, sharey=True)
axs[0].imshow(denormalized_image)
axs[0].set_title('Original Image')
# Average the sensitivity over the batch and channel dimensions, then normalize it to [0, 1]
average_sensitivity = sensitivity_map.mean(dim=0).mean(dim=0)
normalized_sensitivity = (average_sensitivity - average_sensitivity.min()) / (average_sensitivity.max() - average_sensitivity.min())
normalized_sensitivity_np = normalized_sensitivity.numpy()
im = axs[1].imshow(normalized_sensitivity_np, cmap='gray')
axs[1].set_title('Average Sensitivity Map')
fig.colorbar(im, ax=axs.ravel().tolist(), orientation='vertical')

plt.show()
```

Averaging the sensitivity across the channels of an image (R, G, and B) is a common approach when visualizing the sensitivity of a CNN. It gives a single 2D map showing the areas of the input image to which the network is most sensitive, which is easier to interpret than a separate map for each output class or channel. Fig. 1 shows the sensitivity of the network. As you can see, the implicit attention arising from the convolutional operations makes the network focus on specific points of the image, while the rest of the image is largely ignored in the classification task (e.g. we can see a white spot around the car door).
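
To sanity-check Eq. 1 on something small, here is a minimal sketch, separate from the CIFAR10 experiment, using PyTorch's `torch.autograd.functional.jacobian`. For a single linear layer y = Wx + b the Jacobian is exactly W, whatever the input; once a nonlinearity is added, the Jacobian starts to depend on the input, which is what makes the sensitivity map above image-specific. The layer sizes are arbitrary assumptions.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import jacobian

torch.manual_seed(0)

# A toy "network": one linear layer mapping 5 inputs to 3 outputs (sizes are arbitrary)
layer = nn.Linear(5, 3)
x = torch.randn(5)

# J[i, j] = d y_i / d x_j, as in Eq. 1
J = jacobian(layer, x)
print(J.shape)  # torch.Size([3, 5])

# For y = W x + b the Jacobian is W itself, independently of x
print(torch.allclose(J, layer.weight))  # True

# With a tanh on top, J = diag(1 - tanh(z)^2) W with z = W x + b, so it depends on x
net = nn.Sequential(layer, nn.Tanh())
J_tanh = jacobian(net, x)
z = layer(x)
expected = torch.diag(1 - torch.tanh(z) ** 2) @ layer.weight
print(torch.allclose(J_tanh, expected, atol=1e-6))  # True
```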

Fig. 1: Average sensitivity (Jacobian) of a convolutional neural network trained on the CIFAR10 dataset. The original image, a car on the left, has specific regions on which the network focuses (e.g. the region next to the door and window). Image by the author.

The same concept applies to recurrent neural networks (RNNs), and it was one of the first steps on the road to transformers. In an RNN the sensitivity acts sequentially: a sequential Jacobian measures how much each past input still influences the current output. This is what allows the network to remember past inputs and lets them shape its outputs.
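
The sequential Jacobian can be computed directly. This is a minimal sketch under assumed toy sizes and an untrained `nn.RNN`: we differentiate the output at the last time step with respect to the input at every time step, then summarize each block with its norm, giving one "how much does this past input still matter" score per step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

seq_len, input_size, hidden_size = 6, 3, 4  # arbitrary toy sizes
rnn = nn.RNN(input_size, hidden_size)  # expects input of shape (seq_len, batch, input_size)
x = torch.randn(seq_len, 1, input_size, requires_grad=True)

outputs, _ = rnn(x)
y_last = outputs[-1, 0]  # output at the final time step, shape (hidden_size,)

# Sequential Jacobian: d y_last / d x_t for every time step t, one output unit at a time
rows = []
for i in range(hidden_size):
    (grad_x,) = torch.autograd.grad(y_last[i], x, retain_graph=True)
    rows.append(grad_x[:, 0, :])  # (seq_len, input_size)
seq_jacobian = torch.stack(rows, dim=1)  # (seq_len, hidden_size, input_size)

# One sensitivity score per past time step (Frobenius norm of each Jacobian block)
sensitivity = seq_jacobian.flatten(1).norm(dim=1)
for t, s in enumerate(sensitivity.tolist()):
    print(f"t={t}: sensitivity of the final output = {s:.4f}")
```

In an untrained tanh RNN the scores typically shrink for older inputs, a small-scale view of why plain RNNs struggle to remember far into the past.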

A step forward: associative attention in Neural Machine Translation

(If you want to know something more about Bahdanau’s approach check this previous post of mine)

So far we have treated attention as a fixed-size window that makes the model focus on strategic parts of the input while ignoring the rest of the picture.

The real difference with respect to this “convolutional attention” lies in how the weights change. The weights of a convolutional network change slowly: they evolve over hundreds, if not thousands, of epochs over all the input data. In sequence-to-sequence networks with attention, such as the Neural Machine Translation (NMT) model proposed by Bahdanau et al. in 2014, the attention weights are instead data-dependent: they are recomputed for every input sentence and at every decoding step, so they change quickly, in the middle of the process. This is called associative attention, or soft attention. In this way, the model dynamically focuses on different parts of the input sequence while generating the output sequence.

In a nutshell, attention-based NMT works as follows:

  1. Alignment Scores: a) For each position in the output sequence, the model computes alignment scores that represent the relevance of each position in the input sequence. b) These alignment scores are calculated based on the similarity between the current decoder state and the encoder states.
  2. Softmax and Attention Weights: a) The alignment scores are passed through a softmax function to obtain attention weights. b) The softmax operation ensures that the attention weights sum to 1, effectively creating a probability distribution over the input sequence.
  3. Context Vector: a) The context vector is computed as the weighted sum of the encoder states, where the weights are determined by the attention weights. b) This context vector is then used as additional information when generating the output at the current time step.
  4. Integration with Decoder: a) The context vector is concatenated with the input at the current time step (embedding of the previously generated word). b) This concatenated information is then used as input to the decoder to predict the next word in the sequence.
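
The four steps can be traced for a single decoding step in a minimal sketch. The sizes, the random tensors standing in for the encoder states, decoder state and previous-word embedding, and the bias-free additive (tanh) scoring layers are illustrative assumptions, written in the style of Bahdanau’s additive attention rather than copied from any library.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

src_len, hidden = 5, 8  # toy sizes
encoder_states = torch.randn(src_len, hidden)  # one vector per input position
decoder_state = torch.randn(hidden)  # current decoder hidden state
prev_word_emb = torch.randn(hidden)  # embedding of the previously generated word

# Additive scoring: score_i = v^T tanh(W_h h_i + W_s s)
W_h = nn.Linear(hidden, hidden, bias=False)
W_s = nn.Linear(hidden, hidden, bias=False)
v = nn.Linear(hidden, 1, bias=False)

# 1. Alignment scores: one score per input position
scores = v(torch.tanh(W_h(encoder_states) + W_s(decoder_state))).squeeze(-1)  # (src_len,)

# 2. Softmax turns the scores into attention weights that sum to 1
weights = F.softmax(scores, dim=0)

# 3. Context vector: attention-weighted sum of the encoder states
context = weights @ encoder_states  # (hidden,)

# 4. Concatenate with the previous word embedding to form the decoder input
decoder_input = torch.cat([prev_word_emb, context])  # (2 * hidden,)

print(weights.sum().item())  # ~1.0
print(decoder_input.shape)  # torch.Size([16])
```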

In this approach we find the concept of key vectors, where each key is a representation of the input data. During each decoding step, the decoder produces a query vector. The attention scores are computed by measuring the similarity (usually with a dot product or a learned function) between the query and the keys. These scores are then used to weight the values, creating a context vector that summarizes the relevant information from the input sequence for the current decoding step. The attention weights can be written as:

$$w_i = \frac{\exp\left(S(x_i, k)\right)}{\sum_j \exp\left(S(x_j, k)\right)}$$

Eq. 2: Attention weights in NMT, where S is an MLP or linear operator, k is the key vector, x is the input vector

Jumping into the code, we can implement everything as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import Multi30k
from typing import Iterable, List

# Requires the spaCy models de_core_news_sm and en_core_web_sm
# (python -m spacy download de_core_news_sm; python -m spacy download en_core_web_sm)

# CONSTANTS
BATCH_SIZE = 32
EPOCHS = 10


# Yield lists of tokens for one side (0 = German, 1 = English) of each sentence pair
def yield_tokens(data_iter: Iterable, tokenizer, index: int) -> Iterable[List[str]]:
    for pair in data_iter:
        try:
            yield tokenizer(pair[index])
        except UnicodeDecodeError:
            yield []  # skip sentences that cannot be decoded


def collate_fn(batch):
    de_batch, en_batch = [], []
    for (de_item, en_item) in batch:
        # Wrap every sentence in <bos> ... <eos>
        de_ids = [bos_idx] + [de_vocab[token] for token in de_tokenizer(de_item)] + [eos_idx]
        en_ids = [bos_idx] + [en_vocab[token] for token in en_tokenizer(en_item)] + [eos_idx]
        de_batch.append(torch.tensor(de_ids, dtype=torch.long))
        en_batch.append(torch.tensor(en_ids, dtype=torch.long))
    de_batch = pad_sequence(de_batch, padding_value=pad_idx)  # (src_len, batch)
    en_batch = pad_sequence(en_batch, padding_value=pad_idx)  # (tgt_len, batch)
    return de_batch, en_batch


class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.rnn = nn.GRU(hidden_size, hidden_size)

    def forward(self, input_seq):
        embedded = self.embedding(input_seq)  # (src_len, batch, hidden)
        output, hidden = self.rnn(embedded)  # output: (src_len, batch, hidden), hidden: (1, batch, hidden)
        return output, hidden


class Attention(nn.Module):
    def __init__(self, hidden_size):
        super(Attention, self).__init__()
        self.W_h = nn.Linear(hidden_size, hidden_size)
        self.W_s = nn.Linear(hidden_size, hidden_size)
        self.v = nn.Linear(hidden_size, 1)

    def forward(self, encoder_outputs, decoder_hidden):
        # encoder_outputs: (src_len, batch, hidden); decoder_hidden: (1, batch, hidden)
        # Compute energy scores (additive attention)
        energy = torch.tanh(self.W_h(encoder_outputs) + self.W_s(decoder_hidden))
        attention_scores = self.v(energy)  # (src_len, batch, 1)
        # Attention weights: softmax over the source positions
        attention_weights = F.softmax(attention_scores, dim=0)
        # Context vector: weighted sum of the encoder outputs
        context = (attention_weights * encoder_outputs).sum(dim=0)  # (batch, hidden)
        return context, attention_weights


class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = Attention(hidden_size)
        self.rnn = nn.GRU(hidden_size * 2, hidden_size)  # Input size is doubled due to attention context
        self.fc_out = nn.Linear(hidden_size, output_size)

    def forward(self, input_seq, decoder_hidden, encoder_outputs):
        # input_seq: (batch,) token ids; decoder_hidden: (1, batch, hidden)
        embedded = self.embedding(input_seq)  # (batch, hidden)
        context, attention_weights = self.attention(encoder_outputs, decoder_hidden)
        rnn_input = torch.cat([embedded, context], dim=1)  # (batch, 2 * hidden)
        output, hidden = self.rnn(rnn_input.unsqueeze(0), decoder_hidden)
        output = self.fc_out(output.squeeze(0))  # (batch, target vocab size)
        return output, hidden, attention_weights


class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super(Seq2Seq, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, input_seq, target_seq, teacher_forcing_ratio=0.5):
        max_len = target_seq.size(0)
        batch_size = target_seq.size(1)
        target_vocab_size = self.decoder.fc_out.out_features

        # Initialize tensors to store outputs (position 0 stays empty: <bos> is never predicted)
        outputs = torch.zeros(max_len, batch_size, target_vocab_size).to(self.device)
        encoder_outputs, encoder_hidden = self.encoder(input_seq)
        decoder_hidden = encoder_hidden  # Initialize decoder hidden state with the last encoder hidden state
        decoder_input = target_seq[0, :]  # Start with the <bos> token

        for t in range(1, max_len):
            output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
            outputs[t] = output
            # Teacher forcing: sometimes feed the ground-truth token instead of the prediction
            teacher_force = torch.rand(1).item() < teacher_forcing_ratio
            top1 = output.argmax(1)
            decoder_input = target_seq[t] if teacher_force else top1

        return outputs


# Define tokenizers
de_tokenizer = get_tokenizer('spacy', language='de_core_news_sm')
en_tokenizer = get_tokenizer('spacy', language='en_core_web_sm')
# Define special tokens
specials = ['<unk>', '<pad>', '<bos>', '<eos>']
# Load the Multi30k dataset as (German, English) sentence pairs
train_data, valid_data, test_data = Multi30k(language_pair=('de', 'en'))
de_vocab = build_vocab_from_iterator(yield_tokens(train_data, de_tokenizer, 0),
                                     specials=specials,
                                     min_freq=2)
en_vocab = build_vocab_from_iterator(yield_tokens(train_data, en_tokenizer, 1),
                                     specials=specials,
                                     min_freq=2)
# Map out-of-vocabulary tokens to <unk>
de_vocab.set_default_index(de_vocab['<unk>'])
en_vocab.set_default_index(en_vocab['<unk>'])
# Special tokens come first in both vocabularies, so their indices are shared
pad_idx = de_vocab['<pad>']
bos_idx = de_vocab['<bos>']
eos_idx = de_vocab['<eos>']


# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
valid_iterator = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, collate_fn=collate_fn)

hidden_size = 256
# Create model and define optimizer
encoder = Encoder(len(de_vocab), hidden_size).to(device)
decoder = Decoder(hidden_size, len(en_vocab)).to(device)
model = Seq2Seq(encoder, decoder, device)
optimizer = optim.Adam(model.parameters(), lr=0.001)


def sequence_loss(output_seq, target_seq):
    # Skip position 0 (<bos>, never predicted) and ignore padding tokens
    return F.cross_entropy(output_seq[1:].reshape(-1, len(en_vocab)),
                           target_seq[1:].reshape(-1),
                           ignore_index=pad_idx)


for epoch in range(EPOCHS):
    model.train()
    for batch in train_iterator:
        input_seq, target_seq = batch
        input_seq = input_seq.to(device)
        target_seq = target_seq.to(device)
        output_seq = model(input_seq, target_seq)

        # Compute loss and perform backpropagation
        loss = sequence_loss(output_seq, target_seq)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Print some information during training (loss of the last batch)
    print(f"Epoch: {epoch + 1}, Loss: {loss.item():.4f}")

    # Validation loop
    model.eval()
    with torch.no_grad():
        total_loss = 0
        num_batches = 0
        for batch in valid_iterator:
            input_seq, target_seq = batch
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            output_seq = model(input_seq, target_seq, teacher_forcing_ratio=0.0)  # no teacher forcing
            total_loss += sequence_loss(output_seq, target_seq).item()
            num_batches += 1

    average_val_loss = total_loss / num_batches
    print(f"Epoch: {epoch + 1}, Validation Loss: {average_val_loss:.4f}")
```

This is a fundamental building block for deep learning and transformers. In this example, associative attention shows which source tokens the network attends to while translating from German to English. If you have enough time, you can train the code above for many more epochs and then visualize the role of associative attention, as done in this plot:

Fig. 2: Associative attention of a network translating from German to English on the Multi30k dataset. Image by msarmi9 on Hugging Face. Reference: https://huggingface.co/msarmi9/multi30k/blob/main/attention-heatmap.png.

The attention signal is strong: it drives the network towards specific source words for each word it generates. Fig. 2 shows a very general mechanism that lets the network pick out particular elements of the input data.

A flyover of introspective attention

A subsequent step on the road to the Transformer was taken by Graves et al. in the paper “Neural Turing Machines” (NTM), which introduces the concept of memory. Memory can be seen as attention through time, or introspective attention: rather than deciding where to look in the input text for the next token, the network attends to its own internal state. It picks up a particular event in time and ignores the rest, and it can also modify or enrich that internal information. To achieve this, the authors introduced the concept of heads. The NTM architecture is mainly made of:

  1. Controller: The controller is the neural network that serves as the “brain” of the NTM. It processes input data and generates output based on its internal state and the contents of the external memory. The controller is implemented either as a feedforward network or as a recurrent network such as a long short-term memory (LSTM) network. It interacts with the external memory and performs computations over it.
  2. Heads: The NTM has one or more read-and-write heads that allow it to interact with the external memory. Each head can be thought of as a mechanism for reading from or writing to the memory. Read heads are responsible for retrieving information from specific locations in the memory, and write heads are responsible for storing information at specific locations. Each head has its own set of parameters, allowing the network to learn how to use each head effectively.
  3. Memory: The external memory is a matrix (or grid) of addressable locations, and each location contains a vector. The memory serves as an additional storage space that the controller and heads can read from and write to. During operation, the controller decides which locations in the memory to read from or write to, and the read and write heads carry out these operations.
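
To make the read and write heads concrete, here is a minimal sketch of the read and write operations described in the Neural Turing Machines paper: a read returns the weighted sum of the memory rows, and a write first erases (with an erase vector e) and then adds (with an add vector a), both scaled by the head’s weighting w. The memory size and the hand-picked weighting are illustrative assumptions; in a real NTM the controller emits w, e and a at every step.

```python
import torch

N, W = 4, 3  # 4 memory locations, each holding a vector of size 3 (toy sizes)
memory = torch.arange(N * W, dtype=torch.float32).reshape(N, W)

# Attention weighting over locations (normally produced by a head); sums to 1
w = torch.tensor([0.0, 1.0, 0.0, 0.0])  # focus entirely on location 1


def read(memory, w):
    """Read vector: weighted sum of memory rows, r = sum_i w(i) M(i)."""
    return w @ memory


def write(memory, w, erase, add):
    """Erase then add: M(i) <- M(i) * (1 - w(i) e) + w(i) a."""
    w = w.unsqueeze(1)  # (N, 1) so it broadcasts across the vector dimension
    erased = memory * (1 - w * erase)
    return erased + w * add


print(read(memory, w))  # tensor([3., 4., 5.])

erase = torch.ones(W)  # fully erase the focused location...
add = torch.tensor([10.0, 20.0, 30.0])  # ...and store a new vector there
memory = write(memory, w, erase, add)
print(read(memory, w))  # tensor([10., 20., 30.])
```

Because w is a soft distribution rather than a hard index, a weighting such as [0.5, 0.5, 0, 0] would blend two locations, and the whole operation stays differentiable.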

The NTM is trained end-to-end through backpropagation and is capable of learning to perform algorithmic tasks that involve manipulating and storing information over time. The ability to read from and write to a structured external memory allows the NTM to exhibit more sophisticated behaviors compared to traditional neural networks.

The idea behind the NTM is inspired by the Turing machine, a theoretical model of computation proposed by Alan Turing. The NTM extends the capabilities of standard neural networks by incorporating a flexible external memory, making it more suitable for tasks that require reasoning, memory, and sequential processing.

Mathematically, the attention weights are similar to Bahdanau’s. In this case, however, the controller emits a key vector k, which is compared to the content of each memory location M[i] using a similarity measure S such as the cosine similarity. A sharpness parameter β (called the key strength in the paper) then narrows the focus, so that the weighting concentrates on the memories closest to the key:

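$$w_i = \frac{\exp\left(\beta \, S(M[i], k)\right)}{\sum_j \exp\left(\beta \, S(M[j], k)\right)}$$

Eq. 3: An evolution of Bahdanau’s approach for the NTM model, where the key vector k is compared with each memory location M[i] through a similarity measure S, scaled by the sharpness parameter β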
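
Eq. 3 translates almost line by line into code. In this minimal sketch S is the cosine similarity (`torch.nn.functional.cosine_similarity`), as in the NTM paper; the memory contents, the key and the β values are arbitrary choices meant to show how a larger β sharpens the weighting.

```python
import torch
import torch.nn.functional as F


def content_addressing(memory, key, beta):
    """w_i = softmax_i(beta * cosine(M[i], k)), as in Eq. 3."""
    similarity = F.cosine_similarity(memory, key.unsqueeze(0), dim=1)  # (N,)
    return F.softmax(beta * similarity, dim=0)


memory = torch.tensor([
    [1.0, 0.0, 0.0],  # identical direction to the key
    [0.6, 0.8, 0.0],  # partially similar
    [0.0, 1.0, 0.0],  # orthogonal
    [0.0, 0.0, 1.0],  # orthogonal
])
key = torch.tensor([1.0, 0.0, 0.0])

for beta in (1.0, 10.0, 100.0):
    w = content_addressing(memory, key, beta)
    print(f"beta={beta:>5}: {[round(x, 3) for x in w.tolist()]}")
```

With a small β the weighting is spread over all locations; as β grows, almost all of the weight moves to the row most similar to the key.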

Self-attention, or “attention is simply an average” (performed with softmax)

We arrive in 2017 and the famous paper “Attention is all you need” by Vaswani et al., where the Transformer architecture was published for the very first time. The Transformer uses attention to repeatedly transform a complete sequence. The controller is dropped, and every token of the input sequence now emits a query and a key, which are compared to each other. Moreover, the NTM idea of heads is pushed to the extreme: there are multiple attention heads working in parallel (multi-head attention), each free to attend to different parts of the sequence. Let’s see how mathematics can help us here.

Let’s start by understanding, conceptually and in code, how self-attention works. The input is a sequence of tokens, and we want the tokens to talk to each other so that each one can find the elements (memories) most relevant to it. The simplest form of this communication is, for every i-th token, to compute the average of all tokens j with j ≤ i, that is, the token itself and all the previous ones:

```python
import torch

# Use integers to make the averaging process easy to follow
# 4 = batch size, 8 = number of tokens
input_tokens = torch.randint(0, 9, (4, 8))
result = torch.zeros(4, 8)
for batch in range(4):
    for token in range(8):
        input_previous = input_tokens[batch, :token+1].float()  # Convert to float
        print(input_previous)
        result[batch, token] = torch.mean(input_previous, 0)
```
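
The double loop above is easy to read but slow. A common trick, shown here as a sketch, computes the same running average with a single matrix product: a lower-triangular matrix whose rows are normalized to sum to 1. The second variant obtains the same matrix from a softmax over a masked matrix of zeros, which is exactly the shape of computation self-attention uses, except that there the zeros are replaced by learned query-key scores.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T = 4, 8  # batch size, number of tokens
input_tokens = torch.randint(0, 9, (B, T)).float()

# Reference: explicit loop, average of tokens 0..t (inclusive)
result_loop = torch.zeros(B, T)
for b in range(B):
    for t in range(T):
        result_loop[b, t] = input_tokens[b, : t + 1].mean()

# Version 1: lower-triangular matrix with rows normalized to sum to 1
tril = torch.tril(torch.ones(T, T))
avg_weights = tril / tril.sum(dim=1, keepdim=True)
result_matmul = input_tokens @ avg_weights.T  # (B, T) @ (T, T) -> (B, T)

# Version 2: the same weights via softmax of a masked score matrix (all allowed scores = 0)
scores = torch.zeros(T, T).masked_fill(tril == 0, float("-inf"))
softmax_weights = F.softmax(scores, dim=-1)
result_softmax = input_tokens @ softmax_weights.T

print(torch.allclose(result_loop, result_matmul))  # True
print(torch.allclose(result_loop, result_softmax))  # True
```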

We can turn this plain averaging into a learned, data-dependent mechanism by building on the lessons learned from the previous papers:

  • Key, Query, and Value Projections: For an input x of shape (batch, T, C), each token in the sequence is projected into three vectors: a key (k), a query (q), and a value (v). These projections are linear transformations of the input tokens. The query vector is what the token is looking for, the key vector is what the token contains, and the value is the information it passes on. As we saw above, a dot product between keys and queries becomes a weight.
```python
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x)
q = query(x)
v = value(x)
```
  • Attention Weights Computation: The attention weights are computed as the dot product between the query and key vectors. This reflects the relevance of each token to every other token.
```python
wei = q @ k.transpose(-2, -1)
```
  • Masking and Softmax: To ensure that tokens cannot attend to future tokens in the sequence, a lower-triangular mask is applied to the attention weights, setting the upper-triangular part to -inf. After masking, the weights are normalized with the softmax function, which turns the -inf entries into zeros.
```python
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
```
  • Weighted Aggregation: The final step aggregates the values (v) based on the computed attention weights (wei). Multiplying the softmax-normalized weights with the values sums the weighted values along the sequence dimension.
```python
out = wei @ v
```

What we have obtained is a learned weighted average: the attention weights give each token direct access, a kind of perfect memory, to every earlier token in the input sequence.

  • The softmax-normalized attention weights determine how much each token contributes to the final output. In the case of self-attention, these weights are computed based on the relevance of each token to every other token.
  • When the softmax operation is applied to the dot products of queries and keys, it produces a probability distribution over the tokens each position is allowed to see (with the causal mask, the current token and all previous ones). The resulting weights represent the relative importance, or attention, given to each token.
  • The weighted aggregation step (wei @ v) effectively computes a weighted sum of the values (v) based on these attention weights.
  • Since the attention weights are normalized probabilities, the weighted sum is essentially a weighted average. Tokens that are more relevant or attended to receive higher weights, contributing more to the final output. Tokens with lower weights contribute less, effectively receiving a lower weight in the averaging process.

In summary, the self-attention mechanism allows each token to contribute to the final output in a way that is influenced by the relevance of that token to the entire sequence. The softmax normalization ensures that the weights form a probability distribution, making the aggregation akin to a weighted average over all tokens.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Tokenized sentence
tokens = ["The", "cat", "sat", "on", "the", "mat"]
tokenized_sentence = torch.randn(1, len(tokens), 8)  # Random stand-ins for word embeddings of size 8

C = tokenized_sentence.size(-1)
head_size = 4  # Example head size

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(tokenized_sentence)
q = query(tokenized_sentence)
v = value(tokenized_sentence)

wei = q @ k.transpose(-2, -1)  # (1, T, T): row i = query token, column j = key token

T = len(tokens)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out = wei @ v

# Visualization of attention weights: imshow puts rows (queries) on the y-axis
plt.imshow(wei.squeeze().detach().numpy(), cmap='viridis', aspect='auto')
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
plt.yticks(range(len(tokens)), tokens)
plt.xlabel('Key Token')
plt.ylabel('Query Token')
plt.title('Self-Attention Weights')
plt.colorbar()
plt.tight_layout()
plt.show()
```

The code above implements the self-attention logic, and Fig. 4 shows the final “attention-average” result as a heatmap. Each cell (i, j) holds the attention weight that the query token at position i assigns to the key token at position j.

  • X-axis (bottom): Key Tokens
  • Y-axis (left): Query Tokens
  • Color intensity: Indicates the strength of attention (brighter colors mean higher attention)

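Not only can you see the triangular averaging, you can also follow how attention redistributes each token’s focus: brighter cells indicate higher attention weights, suggesting stronger relationships between those tokens. Note that the embeddings and projections here are random and untrained, so the specific pattern is arbitrary; in a trained model these weights would reflect which earlier tokens matter most for guessing what comes next.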

Fig. 4: Softmax average. Self-attention averaging lets each token weigh the tokens before it, which, once the model is trained, helps predict its most natural next token. Image by the author.
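
One detail the example above leaves out: in “Attention is all you need” the scores are divided by √dₖ (here `head_size`) before the softmax, which keeps their variance from growing with the head size and stops the softmax from saturating. The sketch below adds that scaling and, assuming PyTorch 2.0 or newer, checks the hand-written version against the built-in `torch.nn.functional.scaled_dot_product_attention` with `is_causal=True`. Sizes and weights are random toy values.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C, head_size = 1, 6, 8, 4  # same toy sizes as the example above

x = torch.randn(B, T, C)
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k, q, v = key(x), query(x), value(x)

# Scaled, causally masked attention written out by hand
wei = q @ k.transpose(-2, -1) / math.sqrt(head_size)  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float("-inf"))
wei = F.softmax(wei, dim=-1)
out_manual = wei @ v  # (B, T, head_size)

# Built-in equivalent (PyTorch >= 2.0); its default scale is 1 / sqrt(head_size)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_manual, out_builtin, atol=1e-5))  # True
print(wei.sum(dim=-1))  # every row sums to 1: each output is a weighted average of values
```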

A reverse-engineer approach to understand what makes attention so powerful

Fig. 5 is a bare-bones scheme of a transformer. It starts with word embeddings and positional encodings. Then comes a series of residual blocks, each made of an attention layer and an MLP, which process the input information and work with memory. Finally, an output layer decodes the tokens into logits, to which a softmax is applied.

Fig.5: Skeleton of a transformer model. Image by the author.

In the transformer model, the so-called residual blocks, each consisting of an attention layer and an MLP layer, contribute significantly to the model’s main function. However, we can simplify this view and pick up each piece separately. If we neglect all the attention (and MLP) layers, the input tokens go through the embedding straight to the decoder (output) layer and the final logits. This direct path, the residual stream, is a communication channel made up of linear transformations, and its starting point x₀ is defined by Eq. 4:

$$x_0 = W_E \, t$$

Eq. 4: Embedding weights acting on the input tokens

where Wₑ is the embedding matrix and t represents the input tokens (as one-hot vectors).

Next, we add the attention heads back. Each head reads from the stream and adds its output to it, so for a single attention layer we get:

$$x_1 = x_0 + \sum_{h \in H} h(x_0)$$

Eq. 5: The direct stream of communication between all the layers is enriched by the attention heads h in the layer’s set of heads H
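
Eqs. 4 and 5 can be traced with plain matrices. The following minimal sketch assumes a one-layer, attention-only model with made-up sizes and random weights (no positional encoding, no layer norm, no MLP, tokens as one-hot vectors). The embedding W_E writes the tokens into the residual stream, each head adds its output, and an unembedding matrix W_U, a name introduced here for illustration, reads the stream to produce logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, d_model, d_head, n_heads, T = 10, 16, 4, 2, 5  # toy sizes

W_E = torch.randn(d_model, vocab)  # embedding: one-hot token -> residual stream
W_U = torch.randn(vocab, d_model)  # hypothetical unembedding: residual stream -> logits


def random_head():
    """Random weights for one attention head (illustrative only)."""
    return {
        "W_Q": torch.randn(d_head, d_model),
        "W_K": torch.randn(d_head, d_model),
        "W_V": torch.randn(d_head, d_model),
        "W_O": torch.randn(d_model, d_head),
    }


def head(x, W_Q, W_K, W_V, W_O):
    """One causal attention head h(x) reading from and writing to the stream x: (T, d_model)."""
    q, k, v = x @ W_Q.T, x @ W_K.T, x @ W_V.T  # (T, d_head) each
    scores = q @ k.T / d_head ** 0.5
    mask = torch.tril(torch.ones(len(x), len(x))).bool()
    A = F.softmax(scores.masked_fill(~mask, float("-inf")), dim=-1)  # (T, T) attention pattern
    return (A @ v) @ W_O.T  # back into the residual stream: (T, d_model)


heads = [random_head() for _ in range(n_heads)]
tokens = torch.randint(0, vocab, (T,))
one_hot = F.one_hot(tokens, vocab).float()  # (T, vocab)

x0 = one_hot @ W_E.T  # Eq. 4: x_0 = W_E t, one row per token
x1 = x0 + sum(head(x0, **h) for h in heads)  # Eq. 5: every head adds its output
logits = x1 @ W_U.T  # (T, vocab)

# With the heads removed, only the direct path W_U W_E is left
direct_logits = x0 @ W_U.T
print(torch.allclose(direct_logits, one_hot @ (W_U @ W_E).T, atol=1e-4))  # True
print(logits.shape)  # torch.Size([5, 10])
```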

The core point here is the function h(x), namely the attention head function. Each head function can be written in terms of two main matrices:

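$$h(x)_i = W_O \sum_j A_{ij} \, W_V x_j$$

Eq. 6: Each attention head function can be decomposed into the product of an output weight Wₒ, which writes back into the residual stream, and a value matrix Wᵥ; A is the attention pattern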

From what we learned above, Wᵥxⱼ is the value vector of token j, and the attention weights Aᵢⱼ combine the values of all the previous tokens in the input sequence; Wₒ is the output weight, which writes the result back into the residual stream. This result can be rewritten as:

$$h(x) = \left(A \otimes W_O W_V\right) \cdot x$$

Eq. 7: The attention head function rewritten as a function of the combined weight WₒWᵥ

This is a tensor (Kronecker) product, and in it we can see why the attention head is the real engine of the entire transformer architecture. The two factors act independently: the attention pattern A decides between which tokens information is moved, while WₒWᵥ decides what information is read from the source token and how it is written to the destination. In other words, an attention head moves information from the residual stream of one token to the residual stream of another (later) token: it picks up a subspace of one token’s vector and writes it into a possibly different subspace of another token’s vector. This is a fascinating consequence of the linear nature of the residual stream. The product WₒWᵥ is known as the OV circuit of the head. A closely related notion is that of virtual weights: because every layer reads from and writes to the same residual stream, multiplying the output weights of one layer by the input weights of a later layer gives an implicit weight matrix connecting the two, which describes how extensively the later layer reads the information written by the earlier one.
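
The factorization in Eqs. 6 and 7 is easy to check numerically. In this minimal sketch (random weights, a random row-stochastic matrix standing in for the attention pattern A, arbitrary sizes), the head output computed step by step, Wₒ applied to the A-weighted sum of value vectors, matches (A ⊗ WₒWᵥ) applied to all token vectors stacked into one long vector, with the Kronecker product built by `torch.kron`.

```python
import torch

torch.manual_seed(0)
T, d_model, d_head = 4, 6, 2  # toy sizes

x = torch.randn(T, d_model)  # residual stream: one row per token
W_V = torch.randn(d_head, d_model)
W_O = torch.randn(d_model, d_head)
A = torch.softmax(torch.randn(T, T), dim=-1)  # stand-in attention pattern (rows sum to 1)

# Eq. 6: h(x)_i = W_O sum_j A_ij W_V x_j
values = x @ W_V.T  # (T, d_head): v_j = W_V x_j
h_stepwise = (A @ values) @ W_O.T  # (T, d_model)

# Eq. 7: h(x) = (A ⊗ W_O W_V) x, acting on all tokens stacked into one vector
W_OV = W_O @ W_V  # (d_model, d_model)
h_kron = (torch.kron(A, W_OV) @ x.reshape(-1)).reshape(T, d_model)

print(torch.allclose(h_stepwise, h_kron, atol=1e-5))  # True

# W_OV has rank at most d_head: only the first d_head singular values are non-negligible,
# so the head reads from and writes to a small subspace of the residual stream
print(torch.linalg.svdvals(W_OV))
```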

If we consider the entire architecture, we can extend the result above: different layers can send information to distinct subspaces within the embedding vector space. This becomes particularly crucial for attention heads, which operate on smaller subspaces. The dimensions of the residual stream act as memory or bandwidth, and understanding how they are allocated is key to deciphering the transformer’s inner workings.
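
How heads share the residual stream can be illustrated with a minimal multi-head sketch (toy sizes, random weights, no masking, all chosen only for illustration). As in the usual implementation, each head works in a d_model / n_heads dimensional subspace, and the head outputs are concatenated and passed through one output matrix Wₒ. The check at the end shows this is the same as every head writing back into the stream through its own column block of Wₒ and the results being summed, so each head can be read as an independent reader and writer of the residual stream.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T, d_model, n_heads = 5, 12, 3
d_head = d_model // n_heads  # each head works in a 4-dimensional subspace

x = torch.randn(T, d_model)
W_Q, W_K, W_V = (torch.randn(d_model, d_model) for _ in range(3))
W_O = torch.randn(d_model, d_model)


def split_heads(m):
    """(T, d_model) -> (n_heads, T, d_head)"""
    return m.reshape(T, n_heads, d_head).transpose(0, 1)


q, k, v = split_heads(x @ W_Q.T), split_heads(x @ W_K.T), split_heads(x @ W_V.T)
A = F.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5, dim=-1)  # (n_heads, T, T)
z = A @ v  # (n_heads, T, d_head)

# Standard form: concatenate the heads, then apply a single W_O
concat = z.transpose(0, 1).reshape(T, d_model)
out_concat = concat @ W_O.T

# Equivalent form: each head writes through its own column block of W_O; results are summed
out_sum = sum(z[h] @ W_O[:, h * d_head:(h + 1) * d_head].T for h in range(n_heads))

print(torch.allclose(out_concat, out_sum, atol=1e-4))  # True
```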

Conclusions

As we journeyed through the historical development of attention mechanisms, from convolutional neural networks to the transformative power of transformers, a clear narrative emerged.

Attention, the compass of focus guiding us through the vast landscape of information, entwines itself intimately with memory. Together, they form a duo orchestrating the symphony of cognition. In the realm of artificial intelligence, transformers, with their attention mechanisms, have reshaped how we process information, ushering in unprecedented advancements in natural language processing, image recognition, and beyond.

The reverse-engineering approach undertaken here delves into the mathematical underpinnings of attention mechanisms. From the implicit attention in convolutional neural networks to the associative attention in Neural Machine Translation, the narrative evolves. The introduction of memory, as seen in the Neural Turing Machine, adds an introspective layer to attention, allowing the network to focus on specific events in time.

The self-attention mechanism, epitomized in the revolutionary “Attention is all you need” paper, takes center stage. We tried to understand how each token communicates with the others, following both a qualitative and a reverse-engineering approach. The power of attention lies not just in its ability to focus but in its capacity to distribute importance across a sequence. The attention heads, acting as the engine of transformers, reshape information, enabling the flow of knowledge across layers. The virtual weights, a consequence of linear interactions, emerge as key players, connecting layers and facilitating information flow.

In conclusion, attention emerges as a cornerstone of AI, allowing models to recognize patterns and process information in a way that mirrors human cognition. More and more powerful models will undoubtedly be created in the coming months, and it is great that we are all witnessing this AI evolution and revolution.

Useful references

To write this post I made use of many references. These are the ones that helped the most:

  • A Mathematical Framework for Transformer Circuits: a fantastic effort to understand how attention really works under the hood. My kudos to the entire team behind this wonderful mathematical work.
  • LLM Visualisation: if you want to test and play with attention, I strongly suggest visiting this website. Brendan Bycroft did wonderful work, providing a (finally!) clear visualization of all the steps of the most useful GPT-based models.


Machine Learning Engineer, PhD in Computational Chemistry. My writing covers neuroscience research, coding tutorial and social-media analyses