In this blog we’re going to walk through creating and training a transformer from scratch. We’ll go through each foundational element step by step and explain what is happening along the way. This blog is written in a Jupyter notebook which you can download and run yourself as you follow along. Running the code and changing it to see how the output changes will help you learn the concepts better than reading alone. While this is a lengthy topic, please don’t be too alarmed by the length of the notebook or the amount of code. Most of it is copied down from previous cells as we build up the transformer. Rather than showing only the code that changed, which would have shortened things considerably, I chose to copy all required code down to the next cell so the entire notebook can be run from top to bottom. This should make it easier to run as well as allow you to experiment with each new concept as we go.
Let’s open it up and take a peek at what’s inside.
with open('input.txt') as f:
    text = f.read()
print('Length of input.txt (characters):', len(text))
print('First 500 characters:', text[:500])
Length of input.txt (characters): 1115394
First 500 characters: First Citizen:
Before we proceed any further, hear me speak.
All:
Speak, speak.
First Citizen:
You are all resolved rather to die than to famish?
All:
Resolved. resolved.
First Citizen:
First, you know Caius Marcius is chief enemy to the people.
All:
We know't, we know't.
First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?
All:
No more talking on't; let it be done: away, away!
Second Citizen:
One word, good citizens.
First Citizen:
We are accounted poor
2 Tokenization
Now that we have our dataset, the first thing we need to do is break it down into tokens which will be fed into the model. This is known as tokenization. There are various techniques for tokenization, with sub-word tokenization being used in most modern LLMs. Because our dataset is small for this toy problem, we are instead going to tokenize by individual character. This makes tokenization much simpler to implement and also significantly reduces the number of unique tokens. The first thing we need to do is go through the dataset and get a list of all unique characters, or tokens, which is known as the vocab.
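Building the vocab is a one-liner: take the set of unique characters in the text and sort it. A minimal version of that cell (the same pattern we reuse later when we re-write this code) looks like:

vocab = sorted(list(set(text)))  #All unique characters in the dataset
vocab_size = len(vocab)
print('Vocab size:', vocab_size, ', Vocab:', vocab)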
Neural networks are unable to take in characters directly. We must first convert the tokens to integers, which in our case means converting each individual character into an integer. Those integers will then be used to index into the set of token embeddings. Token embeddings are learned vectors that represent each token and are what will be passed into the model. Let’s create our character to index (integer) mapping.
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for char, idx in char2idx.items()}
encode = lambda x: [char2idx[char] for char in x]
decode = lambda idxs: ''.join([idx2char[idx] for idx in idxs])
print('Character to index:', char2idx)
print('Index to character:', idx2char)
print('Tokenization of `Hello World!`:', encode('Hello World!'))
print('String for token sequence `[20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]`:',
      decode([20, 43, 50, 50, 53, 1, 35, 53, 56, 50, 42, 2]))
Now that we have our tokenizer we can go through and tokenize the dataset. With large datasets this step is sometimes done on the fly during training and inference instead of ahead of time, but for our small dataset we can easily and quickly do it up front without significant memory utilization.
We are going to import PyTorch and create a tensor for our encoded dataset. PyTorch is the library we’ll be using to create and train the model.
import torch

encoded_text = torch.tensor(encode(text))
print('Encoded Text shape:', encoded_text.shape, 'Encoded Text Dtype:', encoded_text.dtype)
encoded_text
Encoded Text shape: torch.Size([1115394]) Encoded Text Dtype: torch.int64
tensor([18, 47, 56, ..., 45, 8, 0])
Now that we have the encoded text, we need to split it into a training and validation set. Creating a validation set will allow us to test the performance of our model after each epoch, or at another interval of our choosing, to ensure that it’s training correctly and improving over time. A typical split between training and validation is 90/10, which is what we’ll use.
train_split_idx = int(len(encoded_text) * 0.9)  #Index marking the 90% train / 10% validation split
train_data = encoded_text[:train_split_idx]
valid_data = encoded_text[train_split_idx:]
print('Train data length:', len(train_data), 'Valid data length:', len(valid_data),
      'Train percentage:', len(train_data)/len(encoded_text))
Train data length: 1003854 Valid data length: 111540 Train percentage: 0.8999994620734916
Next we need to choose our context length. The context length is the maximum length of the sequence used when training the transformer. This is sometimes also referred to as the block size, which is how Andrej refers to it. When the transformer is trained, it is trained on every prefix of the sequence up to the maximum context length. For example, if the context length was 5, the model would see the first 1, 2, 3, 4 and 5 tokens as inputs, each paired with the token that immediately follows as the target. Let’s look at a more concrete example in code.
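The cell that produced the output below isn’t reproduced here; a minimal sketch of it, using the first tokens of train_data and the decode helper from earlier, would be:

x = train_data[:8]     #The first 8 tokens are the inputs
y = train_data[1:8+1]  #The targets are the same tokens shifted by one
for idx in range(8):
    print('idx:', idx, 'x:', x[:idx+1], 'y:', y[idx],
          '| decoded version: x:', decode(x[:idx+1].tolist()), 'y:', decode([y[idx].item()]))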
idx: 0 x: tensor([18]) y: tensor(47) | decoded version: x: F y: i
idx: 1 x: tensor([18, 47]) y: tensor(56) | decoded version: x: Fi y: r
idx: 2 x: tensor([18, 47, 56]) y: tensor(57) | decoded version: x: Fir y: s
idx: 3 x: tensor([18, 47, 56, 57]) y: tensor(58) | decoded version: x: Firs y: t
idx: 4 x: tensor([18, 47, 56, 57, 58]) y: tensor(1) | decoded version: x: First y:
idx: 5 x: tensor([18, 47, 56, 57, 58, 1]) y: tensor(15) | decoded version: x: First y: C
idx: 6 x: tensor([18, 47, 56, 57, 58, 1, 15]) y: tensor(47) | decoded version: x: First C y: i
idx: 7 x: tensor([18, 47, 56, 57, 58, 1, 15, 47]) y: tensor(58) | decoded version: x: First Ci y: t
TORCH_SEED = 1337  #Setting a manual torch seed for reproducible results
torch.manual_seed(TORCH_SEED)  #Used to compare against @karpathy's lecture
context_length = 8  #Maximum number of tokens used in each training sequence
batch_size = 4  #Number of sequences that will be trained on in parallel
3 Data Loader
Now we’ll implement a function to get a batch of data from our training and validation datasets. We’ll specify which dataset to pull from as a parameter and return the inputs and targets which we’re naming x and y respectively. We’ll go ahead and print out everything through the process so you can see what’s going on.
def get_batch(train_valid):
    data = train_data if train_valid == 'train' else valid_data
    data_len = len(data)
    start_idxs = torch.randint(high=len(data) - context_length, size=(batch_size,))  #tensor([ 76049, 234249, 934904, 560986])
    # print('start_idxs:', start_idxs)
    x = torch.stack([data[i:i+context_length] for i in start_idxs])
    y = torch.stack([data[i+1:i+context_length+1] for i in start_idxs])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print('shape:', xb.shape)
print(xb)
print('targets:')
print('shape:', yb.shape)
print(yb, '\n-----------------------------------------------')
for batch_idx in range(batch_size):
    for sequence_idx in range(context_length):
        context = xb[batch_idx, :sequence_idx+1]
        target = yb[batch_idx, sequence_idx]
        print(f"Given input context ({context.tolist()}) the target is: {target}")
inputs:
shape: torch.Size([4, 8])
tensor([[24, 43, 58, 5, 57, 1, 46, 43],
[44, 53, 56, 1, 58, 46, 39, 58],
[52, 58, 1, 58, 46, 39, 58, 1],
[25, 17, 27, 10, 0, 21, 1, 54]])
targets:
shape: torch.Size([4, 8])
tensor([[43, 58, 5, 57, 1, 46, 43, 39],
[53, 56, 1, 58, 46, 39, 58, 1],
[58, 1, 58, 46, 39, 58, 1, 46],
[17, 27, 10, 0, 21, 1, 54, 39]])
-----------------------------------------------
Given input context ([24]) the target is: 43
Given input context ([24, 43]) the target is: 58
Given input context ([24, 43, 58]) the target is: 5
Given input context ([24, 43, 58, 5]) the target is: 57
Given input context ([24, 43, 58, 5, 57]) the target is: 1
Given input context ([24, 43, 58, 5, 57, 1]) the target is: 46
Given input context ([24, 43, 58, 5, 57, 1, 46]) the target is: 43
Given input context ([24, 43, 58, 5, 57, 1, 46, 43]) the target is: 39
Given input context ([44]) the target is: 53
Given input context ([44, 53]) the target is: 56
Given input context ([44, 53, 56]) the target is: 1
Given input context ([44, 53, 56, 1]) the target is: 58
Given input context ([44, 53, 56, 1, 58]) the target is: 46
Given input context ([44, 53, 56, 1, 58, 46]) the target is: 39
Given input context ([44, 53, 56, 1, 58, 46, 39]) the target is: 58
Given input context ([44, 53, 56, 1, 58, 46, 39, 58]) the target is: 1
Given input context ([52]) the target is: 58
Given input context ([52, 58]) the target is: 1
Given input context ([52, 58, 1]) the target is: 58
Given input context ([52, 58, 1, 58]) the target is: 46
Given input context ([52, 58, 1, 58, 46]) the target is: 39
Given input context ([52, 58, 1, 58, 46, 39]) the target is: 58
Given input context ([52, 58, 1, 58, 46, 39, 58]) the target is: 1
Given input context ([52, 58, 1, 58, 46, 39, 58, 1]) the target is: 46
Given input context ([25]) the target is: 17
Given input context ([25, 17]) the target is: 27
Given input context ([25, 17, 27]) the target is: 10
Given input context ([25, 17, 27, 10]) the target is: 0
Given input context ([25, 17, 27, 10, 0]) the target is: 21
Given input context ([25, 17, 27, 10, 0, 21]) the target is: 1
Given input context ([25, 17, 27, 10, 0, 21, 1]) the target is: 54
Given input context ([25, 17, 27, 10, 0, 21, 1, 54]) the target is: 39
You may have noticed that the inputs and targets appear to be the same, just shifted by one, which is correct. To help give you a better understanding of how the inputs and targets are combined, we printed out every combination of input and target. As you can see, each input sequence is really multiple training examples: the first token as the input with the second token as the target, and so on, all the way up to the full input sequence as the input with the subsequent character as the target.
4 Bigram Model
Before jumping into using a transformer we’ll start with a bigram model. A bigram model predicts the probability of one token following another. For example, given the token for the letter ‘a’, what is the probability that each token in the vocab will be the next token?
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(TORCH_SEED)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        #Each token reads off the logits (~probabilities) for the subsequent token from the lookup table
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.vocab_size)

    def forward(self, idx, targets):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        logits = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel
        return logits

bigram_model = BigramLanguageModel(vocab_size=vocab_size)
out = bigram_model(xb, yb)
print('Bigram Model Output Shapes out:', out.shape, 'xb:', xb.shape, 'yb:', yb.shape,
      'embeddings:', bigram_model.token_embedding_table)
Now we’ll add a loss function to the forward method. We will use the negative log likelihood, which is implemented in PyTorch as the cross entropy loss. To be able to use the cross entropy loss we’ll need to reshape the output and targets to match the format that it expects: the model output should be a 2D tensor (B*T x C) and the targets should be a 1D tensor (B*T). We do this by squashing the batch and time dimensions together on both the model output and the targets using the .view method.
torch.manual_seed(TORCH_SEED)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        #Each token reads off the logits for the subsequent token from the lookup table
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.vocab_size)

    def forward(self, idx, targets):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        logits = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel
        B,T,C = logits.shape
        logits_reshaped = logits.view(B*T, C)
        targets_reshaped = targets.view(B*T)
        loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        return logits, loss

bigram_model = BigramLanguageModel(vocab_size=vocab_size)
out, loss = bigram_model(xb, yb)
print('Bigram Model Output Shapes out:', out.shape, 'xb:', xb.shape, 'yb:', yb.shape)
print('The calculated loss is:', loss)
Bigram Model Output Shapes out: torch.Size([4, 8, 65]) xb: torch.Size([4, 8]) yb: torch.Size([4, 8])
The calculated loss is: tensor(4.8786, grad_fn=<NllLossBackward0>)
Now we’re going to add a generate method to perform character generation from our model. Instead of explaining here what each generation step is doing, detailed code comments have been added before each step. There are also commented out print statements that you can uncomment and run to help you understand what’s going on as well.
One PyTorch function you may not have seen before is torch.multinomial. It draws num_samples samples (in our case 1) from the weighted probability distribution given by the predictions for each token.
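As a quick illustration (the probabilities here are made up purely for demonstration):

probs = torch.tensor([0.1, 0.2, 0.7])  #Probability of sampling index 0, 1 and 2 respectively
samples = torch.multinomial(probs, num_samples=10, replacement=True)
print(samples)  #Mostly 2s, some 1s and occasionally a 0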
torch.manual_seed(TORCH_SEED)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        #Each token reads off the logits for the subsequent token from the lookup table
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.vocab_size)

    def forward(self, idx, targets=None):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        logits = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel
        if targets is not None:
            B,T,C = logits.shape
            logits_reshaped = logits.view(B*T, C)
            targets_reshaped = targets.view(B*T)
            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        #idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Get predictions - calling `forward`
            logits, loss = self(idx)
            #Get the last time step from logits where the dimensions of the logits are (B,T,C)
            logits_last_timestep = logits[:, -1, :]  #Becomes (B,C)
            # print('Shape of logits_last_timestep:', logits_last_timestep.shape)  #confirming shape
            #Apply softmax to get probabilities
            probs = F.softmax(input=logits_last_timestep, dim=-1)  #(B,C)
            # print('Shape of probs:', probs.shape)  #confirming shape
            #Sample from the probs distribution.
            idx_next = torch.multinomial(input=probs, num_samples=1)  #(B,1) Returns (B,idxs) where idxs are random integer indices.
            # print('Shape of idx_next:', idx_next.shape, 'and contents:', idx_next)  #look at the shape and contents of idx_next
            #Append the sampled indexes idx_next to idx
            idx = torch.cat((idx, idx_next), dim=1)  #(B, T+1)
        return idx

bigram_model = BigramLanguageModel(vocab_size=vocab_size)
logits, loss = bigram_model(xb, yb)
print('Loss:', loss)
#Creating a single batch with a single time step with index of 0,
#which is the `\n` new line char, to be our starting point.
idx = torch.zeros((1,1), dtype=torch.long)
#Now let's generate a character sequence to see what it looks like.
print('100 Generated Tokens:', decode(bigram_model.generate(idx, max_new_tokens=100)[0].tolist()))
For more info on the Adam optimizer or optimizers in general, check out this great video from fast.ai.
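The cell that creates the optimizer isn’t reproduced here; a minimal version, assuming AdamW and the 1e-3 learning rate used in the lecture, would be:

optimizer = torch.optim.AdamW(params=bigram_model.parameters(), lr=1e-3)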
Next, we’ll create a basic training loop.
batch_size = 32
for steps in range(100):
    #Sample a batch of data
    xb, yb = get_batch('train')
    #Evaluate the loss
    logits, loss = bigram_model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('Loss:', round(loss.item(), 3), end=' ')
As you can see the loss is improving, but only a little bit. Next we’ll train for longer.
batch_size = 32
for steps in range(1000):
    #Sample a batch of data
    xb, yb = get_batch('train')
    #Evaluate the loss
    logits, loss = bigram_model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if steps % 100 == 0:
        print('Step:', steps, 'Loss:', round(loss.item(), 3), end=' ')
The loss seems to be steadily decreasing. Let’s train for a lot longer and then see what the results look like.
batch_size = 32
for steps in range(10000):
    #Sample a batch of data
    xb, yb = get_batch('train')
    #Evaluate the loss
    logits, loss = bigram_model(xb, yb)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if steps % 100 == 0:
        print('Step:', steps, 'Loss:', round(loss.item(), 3), end=' ')
weangond
OMave wap
I RO:
Banleenoalit-blt
INRon
UM: nd kngonesll;
O: pa heore 'ga llis?-sur inidind;
t me rthay n thavee
Sw s serer Fofow.
Houspathe t:
Mind fit.
DUKINoceamy hun.
CKIUShorst onre t ache bar, simed?
And me theluse BHENurind hesto f w m CK:
YCESI fatass mbre lious ave
Wer'dor' wod y:
Henkns ges wise we me y to elil'doug p in t her spalisusin t wndalu?Y ber lishms vekeang-lumod n odas ine a! thayayor hannd t; frat.
OLArZAUSum,
s I f pin hondecharvyouke helldid t we ke,
HOShe lll
The results are still not great, but they’re looking more reasonable than they were at first. There are now things that look like words and sentences. Keep in mind this model is predicting the next token or character solely based on the previous token, so it doesn’t have a lot to go by. Learning about the dataset and training loop and getting a baseline with a simplistic model is always a good first step in any AI project. It gives you something to compare your model against so you can see whether further improvements are actually helping.
Before moving on, let’s do a quick loop through the validation set to see what its loss is so we can use it for comparison in the future.
print('Length of Validation Set: ', len(valid_data), 'Shape of Batch:', xb.shape,
      '~Steps needed to cover validation set:', len(valid_data)//batch_size//8)
Length of Validation Set: 111540 Shape of Batch: torch.Size([32, 8]) ~Steps needed to cover validation set: 435
batch_size = 32
torch.manual_seed(TORCH_SEED)
losses = []
for steps in range(len(valid_data)//batch_size//8):
    #Sample a batch of data
    xb, yb = get_batch('valid')
    with torch.no_grad():
        #Evaluate the loss
        logits, loss = bigram_model(xb, yb)
    losses.append(loss)
    if steps % 10 == 0:
        print('Step:', steps, 'Loss:', round(loss.item(), 3), end=' ')
print('\n\nFinal Validation Loss:', torch.stack(losses, dim=0).mean())
The final validation loss appears similar to the training loss, which makes sense given how basic this model is. Now we’ll move on to creating a transformer and see how it does compared to our basic bigram model.
Clearing out variables from above to start fresh.
%reset -f
5 Code re-write in preparation for Transformers
We’re going to re-write some of our code from before to clean things up before jumping into building the transformer.
import torch
import torch.nn as nn
from torch.nn import functional as F
#Hyperparameters
batch_size = 32       #Number of token chunks per batch
context_length = 8    #Length of the token chunks. Andrej called this block size
learning_rate = 1e-2
max_iters = 3000      #Number of training iterations or steps. Typically we specify the number of epochs but since we're randomly sampling...
                      #...from the training set we don't necessarily know exactly when we've seen all the text from the training set so we'll use this instead.
eval_interval = 300   #Number of steps between evaluating the validation set to see how our validation loss is doing.
eval_iters = 200      #Number of steps to do on the validation set per each interval. We do more than 1 to get a more accurate overall valid loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'  #Instead of using the cpu, we'll use the GPU if it's available.
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)
<torch._C.Generator>
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
print('Length of text:', len(text))
Length of text: 1115394
vocab = sorted(list(set(text)))  #Called chars in the video, but vocab is a more generic term. Both are correct.
vocab_size = len(vocab)
print('Vocab size:', vocab_size, ', Vocab:', vocab)
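Since we reset the environment above, the tokenizer and the tokenized dataset also need to be re-created. That cell isn’t reproduced here, but a minimal sketch following the same pattern as before would be:

char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for char, idx in char2idx.items()}
encode = lambda x: [char2idx[char] for char in x]
decode = lambda idxs: ''.join([idx2char[idx] for idx in idxs])
tokenized_text = torch.tensor(encode(text), dtype=torch.long)  #The full dataset as a tensor of token indices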
#The first 90% of data will be the training set, the rest will be the validation set.
train_test_split_idx = int(len(tokenized_text) * 0.9)
print('Train test split index:', train_test_split_idx)
train_data = tokenized_text[:train_test_split_idx]
valid_data = tokenized_text[train_test_split_idx:]
print('Length of training data:', len(train_data), 'Length of validation data:', len(valid_data))
Train test split index: 1003854
Length of training data: 1003854 Length of validation data: 111540
Next we’ll set up a basic data loader to get data in batches.
def get_batch(split:str, batch_size:int=batch_size, context_length:int=context_length):
    #Function to get a batch of data from the train or valid dataset
    data = train_data if split == 'train' else valid_data
    idxs = torch.randint(low=0, high=len(data)-context_length, size=(batch_size,))
    x = torch.stack([data[idx:idx+context_length] for idx in idxs])
    y = torch.stack([data[idx+1:idx+context_length+1] for idx in idxs])
    x, y = x.to(device), y.to(device)  #Send data to the GPU if available
    return x, y

bx, by = get_batch('train')
print('Batch of x shape:', bx.shape, 'Batch of y shape:', by.shape, ' dimensions are [batch x context_length]')
del bx, by
Batch of x shape: torch.Size([32, 8]) Batch of y shape: torch.Size([32, 8]) dimensions are [batch x context_length]
Next we’ll create a function to estimate the loss for our model. Typically this is calculated against the training set at each training step and against the validation set at the end of each epoch, but to keep things simple we’ll just calculate it when called, running the number of steps specified by eval_iters and taking the mean for the training and validation sets respectively. This also helps smooth out the loss values.
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'valid']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x_batch, y_batch = get_batch(split)
            logits, loss = model(x_batch, y_batch)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
We’re going to keep the model framework from the Bigram model and add in the transformer parts shortly. For now, we just need to make sure that our updated code still works.
torch.manual_seed(TORCH_SEED)

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        #Each token reads off the logits for the subsequent token from the lookup table
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.vocab_size)

    def forward(self, idx, targets=None):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        logits = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel
        if targets is not None:
            B,T,C = logits.shape
            logits_reshaped = logits.view(B*T, C)
            targets_reshaped = targets.view(B*T)
            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        #idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Get predictions
            logits, loss = self(idx)
            #Get the last time step from logits where the dimensions of the logits are (B,T,C)
            logits_last_timestep = logits[:, -1, :]  #Becomes (B,C)
            # print('Shape of logits_last_timestep:', logits_last_timestep.shape)  #confirming shape
            #Apply softmax to get probabilities
            probs = F.softmax(input=logits_last_timestep, dim=-1)  #(B,C)
            # print('Shape of probs:', probs.shape)  #confirming shape
            #Sample from the probs distribution.
            idx_next = torch.multinomial(input=probs, num_samples=1)  #(B,1) Returns (B,idxs) where idxs are random integer indices.
            # print('Shape of idx_next:', idx_next.shape, 'and contents:', idx_next)  #look at the shape and contents of idx_next
            #Append the sampled indexes idx_next to idx
            idx = torch.cat((idx, idx_next), dim=1)  #(B, T+1)
        return idx

model = BigramLanguageModel(vocab_size=vocab_size)
model = model.to(device)
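The training loop that produced the output below isn’t reproduced here; it follows the same pattern we use for the transformer later: create an AdamW optimizer and, every eval_iters steps, call estimate_loss to report the training and validation loss. A minimal sketch would be:

optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
for step in range(max_iters):
    if step % eval_iters == 0 or step == max_iters-1:
        losses = estimate_loss()
        print('Step:', step, 'Training Loss:', losses['train'], 'Validation Loss:', losses['valid'])
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))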
Step: 0 Training Loss: tensor(4.7305) Validation Loss: tensor(4.7241)
Step: 200 Training Loss: tensor(3.1193) Validation Loss: tensor(3.1313)
Step: 400 Training Loss: tensor(2.6480) Validation Loss: tensor(2.6667)
Step: 600 Training Loss: tensor(2.5434) Validation Loss: tensor(2.5689)
Step: 800 Training Loss: tensor(2.5155) Validation Loss: tensor(2.5248)
Step: 1000 Training Loss: tensor(2.5008) Validation Loss: tensor(2.5115)
Step: 1200 Training Loss: tensor(2.4787) Validation Loss: tensor(2.4992)
Step: 1400 Training Loss: tensor(2.4747) Validation Loss: tensor(2.4983)
Step: 1600 Training Loss: tensor(2.4739) Validation Loss: tensor(2.4940)
Step: 1800 Training Loss: tensor(2.4734) Validation Loss: tensor(2.4947)
Step: 2000 Training Loss: tensor(2.4608) Validation Loss: tensor(2.4934)
Step: 2200 Training Loss: tensor(2.4664) Validation Loss: tensor(2.4994)
Step: 2400 Training Loss: tensor(2.4763) Validation Loss: tensor(2.4762)
Step: 2600 Training Loss: tensor(2.4708) Validation Loss: tensor(2.4858)
Step: 2800 Training Loss: tensor(2.4578) Validation Loss: tensor(2.4950)
Step: 2999 Training Loss: tensor(2.4621) Validation Loss: tensor(2.4919)
CEThik brid owindakis by ble
HAPORDurayou e.
S:
O:3 my d?
LUCous:
Wanthar u qur, vet?
F dXENDoate awice my.
Hastacom oroup
Yowhthetof isth ble mil ndill, ath ireeesengmin lat Heriliovets, and Win nghirileranousel lind me l.
MAshe ce hiry:
Supr aisspllw y.
Hentoul n Boopetelaves
MPOLI s, d mothakleo Windo whth eisbyo wie m dourive we higend t so mower; te
AN ad nterupirf s ar iris! m:
Thiny aleronth,
Mad
RD:
WISo myrangoube!
KENob&y arardsal thes ghesthidin cour ay aney RDUSts I&fr t ce.
J
6 Previous Token Averages - Building Intuition for Self Attention
Attention was the key discovery that enabled the transformer architecture. The idea is that each token should be able to communicate with, or look at, each previous token in the sequence but not future tokens. For example, given token number 4 in a sequence of 8 tokens, token 4 should be able to access tokens 1, 2 and 3, but not tokens 5 through 8. We will demonstrate this with a for loop implementation to cement the concept and then show the equivalent calculation using matrix multiplication, which is how transformers are implemented in practice because matrix multiplication is orders of magnitude faster than basic nested for loops.
torch.manual_seed(TORCH_SEED)
B,T,C = 4, 8, 2  #Batch, Time, Channel - Time is each token (char) in the sequence and channel is the embedding dimension
x = torch.randn((B,T,C))
x.shape
torch.Size([4, 8, 2])
In this example we’re going to take the average of the previous tokens, just for illustration purposes, not because the average is a good way to represent the data from previous tokens.
x_bag_of_words = torch.zeros((B,T,C))
for batch_idx in range(B):
    for token_idx in range(T):
        x_previous = x[batch_idx, :token_idx+1]  # (T,C)
        x_bag_of_words[batch_idx, token_idx] = torch.mean(x_previous, dim=0)
print('Testing if mean is being calculated correctly. Should be True, True, False:',
      torch.allclose(x[0,0], x_bag_of_words[0,0]),
      torch.allclose(x[0,:2].mean(dim=0), x_bag_of_words[0,1]),
      torch.allclose(x[0,0], x_bag_of_words[0,1]))
Testing if mean is being calculated correctly. Should be True, True, False: True True False
Each item in x_bag_of_words should be the cumulative mean of all values in x up to that index. For index 0 you can see the results are the same and for index 1 you can quickly recognize that in fact x_bag_of_words is the cumulative mean of x at index 0 and 1.
Next, we’ll delve into the basic operations of matrix multiplication. Specifically, we’re focusing on the multiplication of two matrices, denoted as a and b. In Python, matrix multiplication is symbolized using the @ operator. The process of matrix multiplication entails multiplying the rows of the first matrix (in this case, matrix a) by the columns of the second matrix (here, matrix b) when dealing with 2-dimensional matrices. After multiplying, the results are summed up to give the final outcome.
Let’s consider a practical example to illustrate this: calculating the value of c[0,0]. To achieve this, we need to multiply the first row of matrix a with the first column of matrix b, and then sum the results.
In mathematical terms, it looks like this: (a[0,0] * b[0,0]) + (a[0,1] * b[1,0]) + (a[0,2] * b[2,0]). Using the example values below, the equation becomes (1 * 1) + (2 * 3) + (3 * 5) = 22.
To enhance your understanding and intuition of matrix multiplication, I recommend the website http://matrixmultiplication.xyz/. It provides interactive visualizations of the matrix multiplication process, which can help you to understand the concept.
a = torch.tensor([[1,2,3],[4,5,6],[7,8,9]])
b = torch.tensor([[1,2],[3,4],[5,6]])
c = a @ b
print('a =\n', a, '\n-----')
print('b =\n', b, '\n-----')
print('c = a @ b =\n', c, '\n-----')
a =
tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
-----
b =
tensor([[1, 2],
[3, 4],
[5, 6]])
-----
c = a @ b =
tensor([[ 22, 28],
[ 49, 64],
[ 76, 100]])
-----
The next function we need to learn about is the PyTorch tril function. It zeros out the upper right portion of a matrix. Let’s look at a few quick examples to help visualize the concept.
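For example (these particular matrices are just for illustration):

print(torch.tril(torch.ones(3, 3)))           #A lower triangular matrix of ones
print(torch.tril(torch.tensor([[1, 2, 3],
                               [4, 5, 6],
                               [7, 8, 9]])))  #The values above the diagonal are zeroed out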
Now let’s see what happens when we use the tril function with our matrix multiply from before:
a = torch.tril(torch.tensor([[1,2,3],[4,5,6],[7,8,9]]))
b = torch.tensor([[1,2],[3,4],[5,6]])
c = a @ b
print('a =\n', a, '\n-----')
print('b =\n', b, '\n-----')
print('c = a @ b =\n', c, '\n-----')
a =
tensor([[1, 0, 0],
[4, 5, 0],
[7, 8, 9]])
-----
b =
tensor([[1, 2],
[3, 4],
[5, 6]])
-----
c = a @ b =
tensor([[ 1, 2],
[ 19, 28],
[ 76, 100]])
-----
Next we’ll switch out our first matrix with a matrix of ones instead.
a = torch.tril(torch.ones((3,3), dtype=torch.long))
b = torch.tensor([[1,2],[3,4],[5,6]])
c = a @ b
print('a =\n', a, '\n-----')
print('b =\n', b, '\n-----')
print('c = a @ b =\n', c, '\n-----')
a =
tensor([[1, 0, 0],
[1, 1, 0],
[1, 1, 1]])
-----
b =
tensor([[1, 2],
[3, 4],
[5, 6]])
-----
c = a @ b =
tensor([[ 1, 2],
[ 4, 6],
[ 9, 12]])
-----
Let’s examine the first row in c which contains the values [1, 2]. Notice that it is the cumulative sum of the first row of b and rows 2 and 3 in b are effectively masked out. Row 2 in c is the cumulative sum of rows 1 and 2 from b and so on. Hopefully you can see that, like in our for loop, we are effectively cumulatively summing up b.
Now we’re going to reproduce our for loop results using matrix multiplication.
x.shape, x_bag_of_words.shape
(torch.Size([4, 8, 2]), torch.Size([4, 8, 2]))
mask = torch.tril(torch.ones((1,8,8), dtype=torch.float))
a = mask
b = x
c = a @ b
print('Shape of c:', c.shape)
c
Comparing to x_bag_of_words, we can see the values are different. This is because we have calculated the cumulative sum, but not the mean. The exception is index 0 of each batch, where the answer is correct: only one item is being summed, and a sum of one item divided by 1 is both the sum and the mean.
Now we’ll divide by the number of items that were summed to get the mean. Just like before we’re going to use a matrix operation to keep things fast. First we’ll create the matrix:
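That cell isn’t reproduced here, but based on the timed version further below it builds a matrix holding the number of tokens summed at each position and divides by it:

num_toks_summed = torch.Tensor([[i+1]*C for i in range(T)])  #Row i was a sum of i+1 tokens
c2 = a @ b / num_toks_summed  #Cumulative sum divided by the count gives the cumulative mean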
And check whether it matches our previously calculated x_bag_of_words, which it does.
torch.allclose(x_bag_of_words, c2)
True
As an aside, let’s look at how long the x_bag_of_words took to calculate via a for loop vs our matrix version. As the number of calculations is tiny in this case, we’re going to increase the complexity so you can more easily see the time differences. First we’ll do the for loop version:
torch.manual_seed(TORCH_SEED)
# B,T,C = 4, 8, 2  #Batch, Time, Channel - Time is each token (char) in the sequence and channel is the embedding dimension
B,T,C = 256, 512, 64
x = torch.randn((B,T,C))
x.shape
x_bag_of_words = torch.zeros((B,T,C))
%%timeit -n 2 -r 2
#| output: true
for batch_idx in range(B):
    for token_idx in range(T):
        x_previous = x[batch_idx, :token_idx+1]  # (T,C)
        x_bag_of_words[batch_idx, token_idx] = torch.mean(x_previous, dim=0)
print('Testing if mean is being calculated correctly. Should be True, True, False:',
      torch.allclose(x[0,0], x_bag_of_words[0,0]),
      torch.allclose(x[0,:2].mean(dim=0), x_bag_of_words[0,1]),
      torch.allclose(x[0,0], x_bag_of_words[0,1]))
Testing if mean is being calculated correctly. Should be True, True, False: True True False
Testing if mean is being calculated correctly. Should be True, True, False: True True False
Testing if mean is being calculated correctly. Should be True, True, False: True True False
Testing if mean is being calculated correctly. Should be True, True, False: True True False
4.84 s ± 22.3 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
The for loop version on my PC took 4.3 seconds: 4.33 s ± 24.2 ms per loop (mean ± std. dev. of 2 runs, 2 loops each). Now we’ll test the matrix version:
mask = torch.tril(torch.ones((1,T,T), dtype=torch.float))
num_toks_summed = torch.Tensor([[i+1]*C for i in range(T)])
a = mask
b = x
%%timeit -n 2 -r 2
#| output: true
c = a @ b / num_toks_summed
# print('Shape of c:', c.shape)
23.6 ms ± 1.17 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
The matrix multiplication version, on the other hand, took about 20 ms: 19.8 ms ± 872 µs per loop (mean ± std. dev. of 2 runs, 2 loops each). Now we’ll confirm the results are the same.
c = a @ b / num_toks_summed
torch.allclose(x_bag_of_words, c, atol=1e-5)
True
4.33/(19.8/1000) # 4.33 seconds / 19.8 ms
218.68686868686868
As you can see the results were the same, but the matrix multiplication version finished in ~20 ms vs ~4.3 s for the for loop version (numbers may vary between runs and machines), which is ~220x faster. And that is running the matmul on the CPU rather than the GPU, which would speed it up even further.
Alternatively we can use Andrej’s method, which should be even more efficient: the mask elements themselves are divided by the number of elements to be summed, which effectively moves the division by num_toks_summed out of our final timed calculation.
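A sketch of the cell that normalizes the mask (each row now sums to 1, so the matrix multiply produces the mean directly):

a = mask / mask.sum(dim=-1, keepdim=True)  #Divide each row of the mask by the number of ones in that row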
%%timeit -n 2 -r 2
#| output: true
c = a @ b
# print('Shape of c:', c.shape)
18.3 ms ± 289 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
On my PC this took 20.7 ms ± 1.21 ms per loop (mean ± std. dev. of 2 runs, 2 loops each). In this case the numbers are very close. This could be influenced by a number of factors, but in general it’s best to compute things once rather than on every iteration if possible.
c = a @ b
torch.allclose(x_bag_of_words, c, atol=1e-5)
True
Now let’s look at the 3rd version of calculating this using softmax. It should produce an identical result.
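A sketch of that version, masking future positions with -inf so that softmax turns each row into the same uniform averaging weights:

tril = torch.tril(torch.ones(T, T))
weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))  #Future positions get -inf
weights = F.softmax(weights, dim=-1)  #-inf becomes 0 probability and each row sums to 1
c = weights @ x  #Same cumulative mean as the previous versions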
18.8 ms ± 798 µs per loop (mean ± std. dev. of 2 runs, 2 loops each)
20.5 ms ± 1.97 ms per loop (mean ± std. dev. of 2 runs, 2 loops each)
torch.allclose(x_bag_of_words, c, atol=1e-5)
True
Softmax is used to create probabilities that previous tokens will interact with the current token. Intuitively, not all previous tokens in a sequence carry equal weight or importance, so this allows the model to assign a weighting to all previous tokens. All probabilities output from softmax add up to 1. -inf is used to mask out future tokens which should not be accessible to the current token; running -inf through softmax yields a 0 probability. In the transformer, instead of the mask values being either identical or 0 (ex: [0.5, 0.5, 0, 0]), the weighting of past tokens will be learned, i.e. certain tokens can communicate with certain other tokens with a stronger weight based on values learned from the data. Andrej calls this wei for ‘weights’, which is the weighting used to determine which tokens communicate with each other. This is different from what we’re doing now, which is forcing all non-zero weights to be the same. What we’re doing now is just building up the math and intuition on how attention works.
7 Self attention
7.1 Initial Code Setup
To start with we’re going to modify our BigramLanguageModel to be a TransformerLanguageModel class.
embedding_dim = 32  #The vector size of the token embeddings. Andrej used n_embed as the variable name.
We’re going to add an embedding dimension, change up the token embedding table and modify the token embedding lookup and logits calculation as we work towards modifying this class into a true transformer. We’ll iteratively test as we go to make sure it is still able to train correctly. Please read through the below code taking note of the comments explaining the changes being made.
%%time
#| output: true
torch.manual_seed(TORCH_SEED)

class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size=vocab_size, embedding_dim=embedding_dim):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        #A basic linear layer to pass our token embeddings through. This is a preliminary step, not the final network.
        #Note the input size is the embedding_dim and the output size is the number of tokens or vocab size.
        #This is because we are going to be predicting the probability for every token in the vocab that it is the next token.
        self.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)
        #This will be our lookup table for all of the token embeddings. We'll have an entry for each token (aka vocab size)...
        #...and each embedding will be a vector of dimension embedding_dim.
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)

    def forward(self, idx, targets=None):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        token_embeddings = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel - Lookup token embeddings
        logits = self.language_model_head_linear_layer(token_embeddings)  #(B,T,C) Where C is now token logits of size vocab_size
        if targets is not None:
            B,T,C = logits.shape
            logits_reshaped = logits.view(B*T, C)
            targets_reshaped = targets.view(B*T)
            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        #idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Get predictions
            logits, loss = self(idx)  #This is calling `forward`
            #Get the last time step from logits where the dimensions of the logits are (B,T,C)
            logits_last_timestep = logits[:, -1, :]  #Becomes (B,C)
            # print('Shape of logits_last_timestep:', logits_last_timestep.shape)  #confirming shape
            #Apply softmax to get probabilities
            probs = F.softmax(input=logits_last_timestep, dim=-1)  #(B,C)
            # print('Shape of probs:', probs.shape)  #confirming shape
            #Sample from the probs distribution.
            idx_next = torch.multinomial(input=probs, num_samples=1)  #(B,1) Returns (B,idxs) where idxs are the token indices (int).
            # print('Shape of idx_next:', idx_next.shape, 'and contents:', idx_next)  #look at the shape and contents of idx_next
            #Append the sampled indexes idx_next to idx
            idx = torch.cat((idx, idx_next), dim=1)  #(B, T+1)
        return idx

model = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim)
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
for step in range(max_iters):
    if step % eval_iters == 0 or step == max_iters-1:
        losses = estimate_loss()
        print('Step:', step, 'Training Loss:', losses['train'], 'Validation Loss:', losses['valid'])
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.3278) Validation Loss: tensor(4.3231)
Step: 200 Training Loss: tensor(2.5421) Validation Loss: tensor(2.5626)
Step: 400 Training Loss: tensor(2.4982) Validation Loss: tensor(2.5163)
Step: 600 Training Loss: tensor(2.4936) Validation Loss: tensor(2.5354)
Step: 800 Training Loss: tensor(2.4983) Validation Loss: tensor(2.5067)
Step: 1000 Training Loss: tensor(2.5025) Validation Loss: tensor(2.5045)
Step: 1200 Training Loss: tensor(2.4831) Validation Loss: tensor(2.5028)
Step: 1400 Training Loss: tensor(2.4866) Validation Loss: tensor(2.5157)
Step: 1600 Training Loss: tensor(2.4927) Validation Loss: tensor(2.5120)
Step: 1800 Training Loss: tensor(2.4899) Validation Loss: tensor(2.5120)
Step: 2000 Training Loss: tensor(2.4804) Validation Loss: tensor(2.5071)
Step: 2200 Training Loss: tensor(2.4841) Validation Loss: tensor(2.5178)
Step: 2400 Training Loss: tensor(2.4940) Validation Loss: tensor(2.4883)
Step: 2600 Training Loss: tensor(2.4956) Validation Loss: tensor(2.5065)
Step: 2800 Training Loss: tensor(2.4799) Validation Loss: tensor(2.5138)
Step: 2999 Training Loss: tensor(2.4870) Validation Loss: tensor(2.5165)
CExthy bridcowindakis s, bth
HAPORDurayoule.
S:
O:
IS:
Thachangs ar bthar usqur, vethar dilasoate arche my.
HD:
Yom o mur
Yowhthetof isth bot mil ndill, bes ireeesenghin lat Heridrovets, and Win nghir.
Thabousel lind me l.
HAser ce wiry ptupr aisspllwhy.
HAntoul noroopetelaves
MPOLI swod mothakleo Windo whth eiiby we ath dourive wee, ired t so mo she te
AN ad nterurt f sor; irist m:
Thiny aleronth, af Pre?
WISo myay INouk!
KENoby sarardsal thes ghesthinin cour ay aney RDUES:
I fr t ce.
J
CPU times: user 17 s, sys: 438 ms, total: 17.4 s
Wall time: 17.4 s
We also need to encode each token’s position, so we’ll add another embedding table for that, which will be learned as well.
7.2 Building Up To Self Attention
We’ll go through the simple cumulative token average again using matrix multiplication and modify it over time to be self attention.
When using the cumulative mean the weights are fixed, but ideally we want the weights to be variable and learned, so each token can interact with other tokens a varying amount based on learned parameters of what is most important. Some tokens will find other tokens more interesting than others and we want that to be data dependent and learned through training.
The example Andrej gives is “maybe if I’m a vowel token I am more interested in past consonant tokens and I want the [consonant] data to flow to me, this is the problem that self attention solves”. The way that self attention solves this is that every single node or token will emit 2 vectors, a query and key vector. The query vector roughly represents “what am I looking for” and the key vector roughly represents “what do I contain”. The way we then get the affinities between each token is by performing a dot product (matrix multiplication) of all the query vectors against all of the key vectors. So for a given query vector, the dot product is calculated against all of the key vectors and the results of that become the weights that are used for each token. This is the weight variable we used above except now instead of being a fixed average, it varies per token and is learned. If the key and query are aligned they will produce a larger value when the dot product is taken between them which leads to a larger value in the weight matrix.
Let’s take the above example and modify it to implement self attention.
First we need to define our head size. We will use 16. This is the length of the query and key vectors whose dot products form the attention weight matrix. To get the query and key vectors from the token embeddings we run the token embedding through a linear layer for the query and another for the key, each of which produces a vector of size head_size.
#version 4: self attention
#setup
torch.manual_seed(TORCH_SEED)
B,T,C = batch_size, context_length, embedding_dim
x = torch.randn((B,T,C))
print('Batch Size (B):', B, 'Context Length (T):', T, 'Embedding Dimension (C):', C)

head_size = 16  #Self attention head size
#Learned layer to extract the key vector from the token embedding vector
key_layer = nn.Linear(in_features=C, out_features=head_size, bias=False)
#Learned layer to extract the query vector from the token embedding vector
query_layer = nn.Linear(in_features=C, out_features=head_size, bias=False)
#Extract query and key values for every token in the batch in parallel
key = key_layer(x)      # (B,T,head_size)
query = query_layer(x)  # (B,T,head_size)
Now we will calculate the affinities (weights) between each token in each sequence by matrix multiplying all of the queries and keys. If we simply try to calculate query @ key it will fail because the shapes are not compatible for matrix multiplication: both key and query are of shape (B,T,head_size). We need to transpose the key in the T and head_size dimensions so they can be matrix multiplied. We cannot simply use .T because that would transpose the batch dimension as well, which we do not want. Instead we specify which dimensions to transpose by calling key.transpose(-2, -1), which transposes the last 2 dimensions.
#Calculate affinity (weights)
#(B,T,head_size) @ (B,head_size,T) which is (32,8,16) @ (32,16,8) -> (B,T,T) which is (32,8,8)
weights = query @ key.transpose(-2, -1)
weights.shape
Now we have weights that are calculated based on each token’s affinity to every other token. We then apply the same filtering that we did previously with our cumulative mean so we simply remove the line where the weights were set to zero. This will allow us to finally apply a learned weighting to each previous token embedding.
tril = torch.tril(torch.ones(T, T, dtype=torch.long))
# weights = torch.zeros(T,T,dtype=torch.float)  #Removed now that weights are calculated
weights = weights.masked_fill(tril == 0, float('-inf'))  #Masks future tokens
weights = torch.softmax(weights, dim=-1)  #Normalizes each row into a probability distribution (weights that add up to 1)
out = weights @ x
out.shape
torch.Size([32, 8, 32])
You can see the weights below. Notice they are no longer uniform. They can now be individual and learned from the data.
Now each token will be able to calculate its affinity to all other tokens. You can see in the example by looking at the bottom row, that the 8th token has a high affinity for the 6th token because it has the highest value:
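The output isn’t reproduced here, but viewing the weights is as simple as printing the first item in the batch:

weights[0]  #Attention weights for the first sequence; each row sums to 1 and future tokens are masked to 0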
There is one more part of self attention we need to look at: when we aggregate the tokens with out = weights @ x we don’t aggregate the token embeddings directly, we aggregate the value. So in the same way that we calculate key and query by passing the token embedding through a linear layer, we will do the same to get the value.
#version 4: self attention
#setup
torch.manual_seed(TORCH_SEED)
B,T,C = batch_size, context_length, embedding_dim
x = torch.randn((B,T,C))
print('Batch Size (B):', B, 'Context Length (T):', T, 'Embedding Dimension (C):', C)

head_size = 16  #Self attention head size
#Learned layer to extract the key vector from the token embedding vector
key_layer = nn.Linear(in_features=C, out_features=head_size, bias=False)
#Learned layer to extract the query vector from the token embedding vector
query_layer = nn.Linear(in_features=C, out_features=head_size, bias=False)
#Learned layer to extract the value vector from the token embedding vector
value_layer = nn.Linear(in_features=C, out_features=head_size, bias=False)  #NEW
#Extract query, key and value for every token in the batch in parallel
key = key_layer(x)      # (B,T,head_size)
query = query_layer(x)  # (B,T,head_size)
value = value_layer(x)  # (B,T,head_size) #NEW

#Calculate affinity (weights)
#(B,T,head_size) @ (B,head_size,T) which is (32,8,16) @ (32,16,8) -> (B,T,T) which is (32,8,8)
weights = query @ key.transpose(-2, -1)
weights.shape
And now instead of calculating the output by matrix multiplying the weights by x we multiply the weights by value.
tril = torch.tril(torch.ones(T, T, dtype=torch.long))
weights = weights.masked_fill(tril == 0, float('-inf'))  #Masks future tokens
weights = torch.softmax(weights, dim=-1)  #Normalizes each row into a probability distribution (weights that add up to 1)
# out = weights @ x
out = weights @ value  #NEW (B,T,T) @ (B,T,head_size) = (B,T,head_size)
out.shape
torch.Size([32, 8, 16])
Notice how the shape of out has changed from torch.Size([32, 8, 32]) to torch.Size([32, 8, 16]) now that we are using value, which has a length of 16, instead of the token embedding x, which has a length of 32.
You can think of the token embedding x as private information of the token; it must be passed through the linear layers to get the query, key and value. The token embedding x has all the information about the token, and:
query: represents the things that the token is interested in or wants.
key: represents the things the token has.
value: represents, if you find the token interesting, the information the token wants to communicate.
Additional Notes on Attention: 1) Attention is a communication mechanism. You can think of it as if you had nodes in a directed graph:
Each node has a vector of information (token embedding) and it gets a weighted sum of all of the nodes that point to it. This is done in a data dependent manner, so it depends on what data is stored in each node at any point in time. Our graph does not look quite like the example. Instead, our graph has 8 nodes: the first node is pointed to only by itself, the second node is pointed to by the first node and itself, the third node is pointed to by the first and second nodes as well as itself, and so on. This structure is common in auto-regressive scenarios.
Auto-regressive in this context refers to a type of model that generates sequences by modeling the probability of each item in the sequence given its preceding items. In other words, autoregressive language models generate predictions step by step, where each prediction is dependent on the ones that came before it.
In principle attention can be applied to any arbitrary directed graph as it is just a communication mechanism between nodes.
There is no notion of space or position. Attention simply acts over a set of vectors in this graph. The nodes have no idea where they are positioned in space, which is why we need to encode them positionally, giving them information that anchors them to a specific position. In other words, the nodes, representing characters in our example, don’t inherently know what position they occur in relative to the other nodes. You can contrast this with convolutional neural networks, where the data and network are inherently modeled spatially. For example, CNNs are regularly used for computer vision applications, where adjacent pixels are fed into the CNN and convolutional filters act in space, preserving the spatial information about the data.
Attention, in contrast with CNNs, has no notion of space, so position or location needs to be encoded into the nodes through some other mechanism, which in our case is a positional encoding vector. This position vector will be added to the token embedding prior to it being processed through the linear layers.
Additional Notes:
* Each example across batch dimensions is processed completely independently. Information from one item in a batch does not affect information in another item within the batch; different items within a batch never talk to each other.
* In an encoder network (block), you do not filter out future tokens; that is only done in a decoder network. This means that in an encoder network, these lines from our previous example would be removed:
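i.e. the masking lines from the earlier cells:

tril = torch.tril(torch.ones(T, T, dtype=torch.long))
weights = weights.masked_fill(tril == 0, float('-inf'))  #Masks future tokens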
There are many instances where you want all of the nodes to talk to each other, such as in sentiment analysis, because later on in the network you are making a simple prediction on whether the text is positive or negative. Another example would be vision transformers, where you want all image patches to talk to each other. In these instances you use an encoder block, which does not have masking, in contrast to the decoder block we have been focusing on here.
* There are different types of attention. What we’re looking at now is self attention. The reason it is called self attention is that the queries, keys and values all come from the same source (x). Attention can be much more general than self attention in that the data can come from different sources. For example, in encoder-decoder networks the queries could be produced from x while the keys and values come from a completely different source, such as encoder blocks that we want to condition on. A real world example of this is translating from one language to another, where the original or input language is processed by a separate encoder network. The encoder network provides the keys and values and the decoder network provides the queries. This is called cross attention: there is a separate set of nodes we would like to pull information from into our nodes. Self attention, again, is where we pull the keys, queries and values from one set of nodes.
Attention(Q, K, V) = softmax((Q @ K^T) / sqrt(d_k)) @ V, where Q = Query, K = Key, V = Value, and d_k = the dimension of the Q and K vectors, or ‘head’.
The piece we are missing is dividing by sqrt(d_k), which makes this ‘scaled self attention’. To do this we need to divide weights by sqrt(d_k), the dimension of the Q and K head. This makes it so that when Q and K are unit variance, weights will be unit variance too, which is important so the softmax remains diffuse and doesn’t saturate. Without scaling, the dot products between Q and K can become very large, which pushes the gradients through the softmax to become very small and negatively impacts training. This is why we want to scale them before taking the softmax.
Let’s look at a real example of this:
Here q and k are drawn from a Gaussian (normal) distribution, so the mean of the values is 0 and the standard deviation is 1. When you compute the matrix multiplication between them you will notice that the variance of weights is quite high.
torch.manual_seed(TORCH_SEED)
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)
print('Mean of q:', q.mean(), 'Variance of q:', q.var(), 'Mean of k:', k.mean(), 'Variance of k:', k.var())
weights = q @ k.transpose(-2, -1)
print('Shape of weights:', weights.shape, 'Mean of weights:', weights.mean(), 'Variance of weights:', weights.var(),
      '\nMin of weights:', weights.min(), 'Max of weights:', weights.max())
Mean of q: tensor(0.0021) Variance of q: tensor(0.9985) Mean of k: tensor(-0.0129) Variance of k: tensor(1.0255)
Shape of weights: torch.Size([32, 8, 8]) Mean of weights: tensor(-0.0302) Variance of weights: tensor(17.3386)
Min of weights: tensor(-16.3490) Max of weights: tensor(13.1295)
Now if you divide the dot product of q and k by the square root of the head_size, you can see that it brings the variance of weights back down to approximately 1, compared to approximately 17 prior to scaling.
import math

weights = (q @ k.transpose(-2, -1)) / math.sqrt(head_size)  #Output size is (B,T,T) (32,8,8)
print('Shape of weights:', weights.shape, 'Mean of weights:', weights.mean(), 'Variance of weights:', weights.var(),
      '\nMin of weights:', weights.min(), 'Max of weights:', weights.max())
Shape of weights: torch.Size([32, 8, 8]) Mean of weights: tensor(-0.0075) Variance of weights: tensor(1.0837)
Min of weights: tensor(-4.0872) Max of weights: tensor(3.2824)
We’ll create a very basic function to plot the tensors to help visualize the results.
import matplotlib.pyplot as plt

def plot_1d_tensor(x):
    print(x)
    plt.bar(range(len(x)), x)
Again, the reason that scaling weights is important is because of the subsequent softmax that is applied. When large values are input into softmax it causes the gradients to be small and the output of the softmax to converge toward one-hot vectors. First we’ll start out with one of the example weights that has already been divided by math.sqrt(head_size).
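The plotting cells aren’t reproduced here; a quick stand-in with illustrative values (not taken from the model) shows the same effect:

scaled_scores = torch.tensor([0.1, 0.5, -0.2, 0.3, -0.2])  #Pretend these are already-scaled attention scores
plot_1d_tensor(torch.softmax(scaled_scores, dim=-1))  #Diffuse: the probabilities are fairly evenly spread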
You can see the output of softmax here is diffuse. None of the output values are overly large or small. If you multiply these same values by math.sqrt(head_size), effectively undoing the scaling we applied, you will see that the results after softmax are less evenly distributed, or less diffuse.
If you push it even further you can see that the second item in the vector continues to grow even though the value of each element relative to the others has not changed.
As the input values to the softmax continue to grow the result of the softmax continues to converge to a one-hot encoded vector, which is where one of the values in the vector is 1 and all the rest are 0’s. In effect this would make it so 1 node will only draw information from one other node, which is generally not what we want. This is especially a problem during initialization of the network before training, as it can be difficult for the network to recover from this during training.
7.3 Continuing model definition
Now we’re going to create a Head module where we’ll implement a single self attention head which we’ll use in our transformer, replacing the bigram model. You can reference the video link here to follow along if you would like.
class Head(nn.Module):
    """ one self attention head """
    def __init__(self, head_size:int=head_size, embedding_dim:int=embedding_dim, context_length:int=context_length):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.context_length = context_length
        self.key_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.query_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.value_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones((self.context_length, self.context_length))))

    def forward(self, x):
        B,T,C = x.shape
        assert T <= self.context_length  #check that x.shape matches pre-defined dims
        assert C == self.embedding_dim
        q = self.query_layer(x)  #(B,T,head_size)
        k = self.key_layer(x)    #(B,T,head_size)
        v = self.value_layer(x)  #(B,T,head_size)
        #compute scores based on affinities
        weights = (q @ k.transpose(-2, -1)) * self.head_size**-0.5  # (B,T,head_size) @ (B,head_size,T) -> (B,T,T)
        weights = weights.masked_fill(self.tril[:T,:T] == 0, float('-inf'))  #(B,T,T)
        weights = F.softmax(input=weights, dim=-1)  #(B,T,T)
        #perform weighted aggregation of the values
        out = weights @ v  # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
        return out

# Head()(x)
The register_buffer method is utilized to incorporate the tril matrix as a part of the model’s state. This integration ensures that tril is consistently saved and loaded with the model, maintaining uniform behavior across various runs and settings. Crucially, being a buffer, tril is excluded from gradient calculations and is not included as a parameter during model optimization, thereby rendering it a non-trainable component of the model.
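A quick check along these lines (a sketch, assuming the Head class and the globals defined above) shows tril in the buffers and state dict but not in the trainable parameters:

h = Head()
print([name for name, _ in h.named_buffers()])      # ['tril'] - part of the model's state
print([name for name, _ in h.named_parameters()])   # only the key/query/value linear weights
print('tril' in h.state_dict())                     # True - saved and loaded with the model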
To make visualizing the training loss easier we’ll create a simple function to plot them.
def plot_losses(losses):
    train_losses = [o['train'] for o in losses if o.get('train') is not None]
    valid_losses = [o['valid'] for o in losses if o.get('valid') is not None]
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.ylabel('Loss')
    plt.title('Losses')
    plt.legend()
    plt.show()
Now we’ll add our new Head implementation to the TransformerLanguageModel class and train a model to ensure everything is working, as well as to get a baseline of the results. Note we are also adding a token_position_embedding_table to encode the token positions. This learned, looked-up position embedding will be added to the token_embeddings.
learning_rate = 1e-3  #decrease the learning rate because self attention cannot tolerate very high learning rates.
max_iters = 5000
torch.manual_seed(TORCH_SEED)

class TransformerLanguageModel(nn.Module):
    def __init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.head_size = head_size
        #This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size)...
        #...and each embedding will be a vector of dimension embedding_dim.
        self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)
        self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)
        self.self_attention_head_linear_layer = Head(head_size=head_size, embedding_dim=embedding_dim, context_length=context_length)
        self.language_model_head_linear_layer = nn.Linear(in_features=self.head_size, out_features=self.vocab_size)

    def forward(self, idx, targets=None):
        #Both idx and targets are (B,T) Batch x Time arrays of integers
        B,T = idx.shape
        token_embeddings = self.token_embedding_table(idx)  #(B,T,C) Batch, Time, Channel
        token_position_embeddings = self.token_position_embedding_table(torch.arange(T, device=device))  #(T,C)
        x = token_embeddings + token_position_embeddings
        x = self.self_attention_head_linear_layer(x)  #apply one head of self attention
        logits = self.language_model_head_linear_layer(x)  #(B,T,C) where C is now token logits of size vocab_size
        if targets is not None:
            B,T,C = logits.shape
            logits_reshaped = logits.view(B*T, C)
            targets_reshaped = targets.view(B*T)
            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        #idx is a (B,T) array of indices in the current context
        for _ in range(max_new_tokens):
            #Crop idx to the max size of our positional embeddings table
            idx_crop = idx[:, -self.context_length:]
            #Get predictions
            logits, loss = self(idx_crop)
            #Get the last time step from logits where the dimensions of the logits are (B,T,C)
            logits_last_timestep = logits[:, -1, :]  #Becomes (B,C)
            #Apply softmax to get probabilities
            probs = F.softmax(input=logits_last_timestep, dim=-1)  #(B,C)
            #Sample from the probs distribution
            idx_next = torch.multinomial(input=probs, num_samples=1)  #(B,1)
            #Append the sampled index idx_next to idx
            idx = torch.cat((idx, idx_next), dim=1)  #(B, T+1)
        return idx

model = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)
model = model.to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
tracked_losses = list()
for step in range(max_iters):
    if step % eval_iters == 0 or step == max_iters-1:
        losses = estimate_loss()
        tracked_losses.append(losses)
        print('Step:', step, 'Training Loss:', losses['train'], 'Validation Loss:', losses['valid'])
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
plot_losses(tracked_losses)
context = torch.zeros((1,1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.1743) Validation Loss: tensor(4.1712)
Step: 200 Training Loss: tensor(3.1199) Validation Loss: tensor(3.1343)
Step: 400 Training Loss: tensor(2.8712) Validation Loss: tensor(2.8892)
Step: 600 Training Loss: tensor(2.7071) Validation Loss: tensor(2.7260)
Step: 800 Training Loss: tensor(2.6324) Validation Loss: tensor(2.6392)
Step: 1000 Training Loss: tensor(2.5896) Validation Loss: tensor(2.5849)
Step: 1200 Training Loss: tensor(2.5460) Validation Loss: tensor(2.5497)
Step: 1400 Training Loss: tensor(2.5158) Validation Loss: tensor(2.5259)
Step: 1600 Training Loss: tensor(2.5000) Validation Loss: tensor(2.5051)
Step: 1800 Training Loss: tensor(2.4885) Validation Loss: tensor(2.4980)
Step: 2000 Training Loss: tensor(2.4632) Validation Loss: tensor(2.4858)
Step: 2200 Training Loss: tensor(2.4572) Validation Loss: tensor(2.4797)
Step: 2400 Training Loss: tensor(2.4632) Validation Loss: tensor(2.4467)
Step: 2600 Training Loss: tensor(2.4587) Validation Loss: tensor(2.4553)
Step: 2800 Training Loss: tensor(2.4338) Validation Loss: tensor(2.4533)
Step: 3000 Training Loss: tensor(2.4402) Validation Loss: tensor(2.4562)
Step: 3200 Training Loss: tensor(2.4409) Validation Loss: tensor(2.4492)
Step: 3400 Training Loss: tensor(2.4249) Validation Loss: tensor(2.4487)
Step: 3600 Training Loss: tensor(2.4376) Validation Loss: tensor(2.4395)
Step: 3800 Training Loss: tensor(2.4166) Validation Loss: tensor(2.4278)
Step: 4000 Training Loss: tensor(2.4102) Validation Loss: tensor(2.4275)
Step: 4200 Training Loss: tensor(2.4191) Validation Loss: tensor(2.4384)
Step: 4400 Training Loss: tensor(2.4178) Validation Loss: tensor(2.4217)
Step: 4600 Training Loss: tensor(2.4077) Validation Loss: tensor(2.4109)
Step: 4800 Training Loss: tensor(2.4062) Validation Loss: tensor(2.4189)
Step: 4999 Training Loss: tensor(2.4043) Validation Loss: tensor(2.4176)
And thef tridcowind tis n, ber
Hiset bobe toe.
S:
O-' my dalatanss:
Want he uw hathe.
War dthas ate awice my.
Haldaru zorou wabuts, tof is hy me mil ndill, aes iree sen cin lat Het drovets, and Win ng:
Wilerabous lplind peallllishe onchiry:
Augr aiss hawty.
'Thake norodpeeelaves
Momy.
Whod mothake onWindo whe Ceiiby, wout, fourive wees ired thoous
Ar-x's uhe kad nterthirf so;
Angis! m:
E nge male ont ffaf Pre?
WISo myat houre!
Widby ak
Sadsal thes ghe thidin cour ay aney Iry ts chan th voul
Next we’ll add multi-head attention which is just computing multiple attention heads together in parallel and then concatenating the results.
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads:int, head_size:int=head_size, embedding_dim:int=embedding_dim, context_length:int=context_length):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.heads = nn.ModuleList([
            Head(head_size=self.head_size, embedding_dim=self.embedding_dim, context_length=self.context_length)
            for _ in range(self.num_heads)
        ])

    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)  #Note the concat is in the last 'C' dimension => (B,T,head_size*num_heads)
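A quick shape check (a sketch, assuming the batch_size, context_length, and embedding_dim hyperparameters defined earlier) makes the concatenation explicit: four heads each produce (B, T, head_size), and their concatenation gives back (B, T, 4 * head_size).

# Sketch: with head_size = embedding_dim // 4, the concatenated output matches the input width.
mha = MultiHeadAttention(num_heads=4, head_size=embedding_dim // 4)
x = torch.randn(batch_size, context_length, embedding_dim)
print(mha(x).shape)  # (batch_size, context_length, embedding_dim)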
Now let’s add our newly created multi-head attention back into our Model.
torch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_size#This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size) and each embedding will... #...be a vector of dimension embedding_dim.self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)# self.self_attention_head_linear_layer = Head(head_size=head_size, embedding_dim=embedding_dim, context_length=context_length)#4 heads of 8 dimensional self attention.self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=4, head_size=self.embedding_dim//4) #NEWself.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings# x = self.self_attention_head_linear_layer(x) #apply one head of self attention x =self.multi_self_attention_heads_layer(x) logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)# print('Shape of logits_last_timestep:',logits_last_timestep.shape) #confirming shape#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)# print('Shape of probs:', probs.shape) #confirming shape#Sample from the probs distribution. 
idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.# print('Shape of idx_next:',idx_next.shape,'and contents:',idx_next) #look at the shape and contents of idx_next#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',losses['train'],'Validation Loss:',losses['valid']) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.2248) Validation Loss: tensor(4.2250)
Step: 200 Training Loss: tensor(3.0112) Validation Loss: tensor(3.0132)
Step: 400 Training Loss: tensor(2.7330) Validation Loss: tensor(2.7487)
Step: 600 Training Loss: tensor(2.6190) Validation Loss: tensor(2.6244)
Step: 800 Training Loss: tensor(2.5537) Validation Loss: tensor(2.5700)
Step: 1000 Training Loss: tensor(2.5222) Validation Loss: tensor(2.5220)
Step: 1200 Training Loss: tensor(2.4785) Validation Loss: tensor(2.4870)
Step: 1400 Training Loss: tensor(2.4509) Validation Loss: tensor(2.4563)
Step: 1600 Training Loss: tensor(2.4205) Validation Loss: tensor(2.4278)
Step: 1800 Training Loss: tensor(2.3966) Validation Loss: tensor(2.4144)
Step: 2000 Training Loss: tensor(2.3658) Validation Loss: tensor(2.3828)
Step: 2200 Training Loss: tensor(2.3729) Validation Loss: tensor(2.3910)
Step: 2400 Training Loss: tensor(2.3579) Validation Loss: tensor(2.3466)
Step: 2600 Training Loss: tensor(2.3544) Validation Loss: tensor(2.3499)
Step: 2800 Training Loss: tensor(2.3267) Validation Loss: tensor(2.3427)
Step: 3000 Training Loss: tensor(2.3259) Validation Loss: tensor(2.3410)
Step: 3200 Training Loss: tensor(2.3180) Validation Loss: tensor(2.3313)
Step: 3400 Training Loss: tensor(2.3070) Validation Loss: tensor(2.3142)
Step: 3600 Training Loss: tensor(2.3024) Validation Loss: tensor(2.3078)
Step: 3800 Training Loss: tensor(2.2728) Validation Loss: tensor(2.3038)
Step: 4000 Training Loss: tensor(2.2630) Validation Loss: tensor(2.2855)
Step: 4200 Training Loss: tensor(2.2825) Validation Loss: tensor(2.2850)
Step: 4400 Training Loss: tensor(2.2734) Validation Loss: tensor(2.2868)
Step: 4600 Training Loss: tensor(2.2629) Validation Loss: tensor(2.2753)
Step: 4800 Training Loss: tensor(2.2425) Validation Loss: tensor(2.2706)
Step: 4999 Training Loss: tensor(2.2440) Validation Loss: tensor(2.2609)
And they tridcowd,
This so be madises bube to tavegr-'t theall ands:
Want he us hat tot?
Wedtlas anes wice my.
HDER:
At onoth
Youts, tof is hy me mil nowlit,
Wheirwe sen cin lat Het drov the and the nown iserans!
lolind teall thus, cocrivy prugh aiss hewty.
Hllings kne
To thig I whom.
Whoul to ake onWinso whre piiby we atit,
Crive winghience poo mo the thu the danterupt fis are;
De! muf thre male of,
To fis.
Fe I So myakny, be!
Whied is:
Sadsal the E'd st huin couk ay andy Iry to cof my carey
As you can see there is quite an improvement in the loss, going from Validation Loss: tensor(2.4176) with a single attention head to Validation Loss: tensor(2.2609) with multi-head attention using 4 heads. Note that these losses may vary somewhat between training runs. The results are still nonsense, but they are looking closer to the training text than previous attempts. The reason multi-head attention works better than a single self attention block is that it gives tokens multiple communication channels, so each channel can look for something different. As an example, one channel may be looking back at vowels or consonants while another might be looking for the previous space.
If you look at this transformer block diagram, you can see that we’ve implemented quite a bit of it so far.
We’ve implemented the output embeddings, the positional embeddings, (the lower) masked multi-head attention, and the final linear and softmax layers. We are going to skip the multi-head attention block in the middle, as that is only needed if your model has an encoder block, which ours does not. This leaves the feed forward network to implement, which is just a simple multilayer perceptron. In addition, the entire block between the positional encodings and the final linear layer can be stacked on top of itself multiple times, as signified by the Nx.
Here is the equation for the feed forward network, which is a simple multi layer perceptron:
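From the Attention Is All You Need paper:

$$\mathrm{FFN}(x) = \max(0,\ xW_1 + b_1)\,W_2 + b_2$$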
class FeedForwardNetwork(nn.Module):
    """A simple linear network followed by a non-linearity"""
    def __init__(self, embedding_dim:int=embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn = nn.Sequential(
            nn.Linear(in_features=self.embedding_dim, out_features=self.embedding_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.ffn(x)
Note: the equation defines a linear layer, a ReLU, and a second linear layer. We’ll add that second linear layer later. Now let’s add our FFN to our transformer model.
%%time#| output: truetorch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_size#This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size) and each embedding will... #...be a vector of dimension embedding_dim.self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)#4 heads of 8 dimensional self attention.self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=4, head_size=self.embedding_dim//4)self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim) #NEWself.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings x =self.multi_self_attention_heads_layer(x) # (B,T,C) x =self.feed_forward_network(x) # (B,T,C) NEW logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)# print('Shape of logits_last_timestep:',logits_last_timestep.shape) #confirming shape#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)# print('Shape of probs:', probs.shape) #confirming shape#Sample from the probs distribution. 
idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.# print('Shape of idx_next:',idx_next.shape,'and contents:',idx_next) #look at the shape and contents of idx_next#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',losses['train'],'Validation Loss:',losses['valid']) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.2022) Validation Loss: tensor(4.2019)
Step: 200 Training Loss: tensor(2.9494) Validation Loss: tensor(2.9685)
Step: 400 Training Loss: tensor(2.6759) Validation Loss: tensor(2.6864)
Step: 600 Training Loss: tensor(2.5779) Validation Loss: tensor(2.5799)
Step: 800 Training Loss: tensor(2.5171) Validation Loss: tensor(2.5197)
Step: 1000 Training Loss: tensor(2.4739) Validation Loss: tensor(2.4704)
Step: 1200 Training Loss: tensor(2.4210) Validation Loss: tensor(2.4257)
Step: 1400 Training Loss: tensor(2.4079) Validation Loss: tensor(2.4105)
Step: 1600 Training Loss: tensor(2.3843) Validation Loss: tensor(2.3845)
Step: 1800 Training Loss: tensor(2.3682) Validation Loss: tensor(2.3731)
Step: 2000 Training Loss: tensor(2.3387) Validation Loss: tensor(2.3475)
Step: 2200 Training Loss: tensor(2.3342) Validation Loss: tensor(2.3500)
Step: 2400 Training Loss: tensor(2.3180) Validation Loss: tensor(2.3127)
Step: 2600 Training Loss: tensor(2.3176) Validation Loss: tensor(2.3160)
Step: 2800 Training Loss: tensor(2.2881) Validation Loss: tensor(2.3087)
Step: 3000 Training Loss: tensor(2.2834) Validation Loss: tensor(2.3059)
Step: 3200 Training Loss: tensor(2.2796) Validation Loss: tensor(2.2901)
Step: 3400 Training Loss: tensor(2.2719) Validation Loss: tensor(2.2743)
Step: 3600 Training Loss: tensor(2.2675) Validation Loss: tensor(2.2681)
Step: 3800 Training Loss: tensor(2.2428) Validation Loss: tensor(2.2751)
Step: 4000 Training Loss: tensor(2.2294) Validation Loss: tensor(2.2524)
Step: 4200 Training Loss: tensor(2.2468) Validation Loss: tensor(2.2545)
Step: 4400 Training Loss: tensor(2.2373) Validation Loss: tensor(2.2437)
Step: 4600 Training Loss: tensor(2.2310) Validation Loss: tensor(2.2448)
Step: 4800 Training Loss: tensor(2.2182) Validation Loss: tensor(2.2522)
Step: 4999 Training Loss: tensor(2.2135) Validation Loss: tensor(2.2291)
Wher bef bridcowf,
The lay ble
bairet bube to tave O-' my dalllauss:
Want he us he hertbar dilth anes with my thand a wizorm he offs, to fit her! Varl nowlit,
Wheiree sen cin lat Heacliov the and the nown!
Ferablesel lind teall thull cechir speave aiss hewty.
HETBHUSIRCBETI:
Alave whom
Ill, demet aklecal-'so wher piichs withe dour warce hidend thoouse the the the danderthirf son; igis! muf thre ifled at tise Pried my of.
HKINGLER:
Widby and adsal ther grest hoin cour ay aney Iry thel fronf veay
CPU times: user 1min 15s, sys: 445 ms, total: 1min 15s
Wall time: 1min 15s
Our loss has improved again, from Validation Loss: tensor(2.2609) to Validation Loss: tensor(2.2291), now that we’ve added the feed forward network.
Next we need to create a Block module that incorporates everything within the block (grey box) on the transformer architecture diagram, which will then allow us to stack blocks.
class TransformerBlock(nn.Module):
    """Transformer Block: communication followed by computation."""
    def __init__(self, embedding_dim:int=embedding_dim, context_length:int=context_length, num_heads:int=4):
        #embedding_dim: embedding dimension, num_heads: the number of heads that we want
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.head_size = embedding_dim // num_heads
        self.num_heads = num_heads
        self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=self.num_heads, head_size=self.head_size, embedding_dim=embedding_dim, context_length=context_length)
        self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim)

    def forward(self, x):
        return self.feed_forward_network(self.multi_self_attention_heads_layer(x))
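A quick sanity check (a sketch, assuming the earlier hyperparameters) shows that a block maps (B, T, C) to (B, T, C), which is exactly what allows blocks to be stacked back to back:

block = TransformerBlock(num_heads=4)
x = torch.randn(batch_size, context_length, embedding_dim)
print(block(x).shape)  # same shape out as in: (batch_size, context_length, embedding_dim)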
Now we can add our new Transformer Block to our model and start stacking it.
%%time#| output: truetorch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_size#This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size) and each embedding will... #...be a vector of dimension embedding_dim.self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)# self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=4, head_size=self.embedding_dim//4)# self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim)self.transformer_blocks = nn.Sequential( TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), ) #NEWself.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings# x = self.multi_self_attention_heads_layer(x) # (B,T,C)# x = self.feed_forward_network(x) # (B,T,C) x =self.transformer_blocks(x) #NEW logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)# print('Shape of logits_last_timestep:',logits_last_timestep.shape) #confirming shape#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)# print('Shape of probs:', probs.shape) #confirming shape#Sample from the probs distribution. 
idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.# print('Shape of idx_next:',idx_next.shape,'and contents:',idx_next) #look at the shape and contents of idx_next#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',losses['train'],'Validation Loss:',losses['valid']) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.2116) Validation Loss: tensor(4.2078)
Step: 200 Training Loss: tensor(3.2643) Validation Loss: tensor(3.2907)
Step: 400 Training Loss: tensor(3.1541) Validation Loss: tensor(3.1676)
Step: 600 Training Loss: tensor(3.0360) Validation Loss: tensor(3.0239)
Step: 800 Training Loss: tensor(2.8569) Validation Loss: tensor(2.8526)
Step: 1000 Training Loss: tensor(2.7738) Validation Loss: tensor(2.7607)
Step: 1200 Training Loss: tensor(2.6645) Validation Loss: tensor(2.6827)
Step: 1400 Training Loss: tensor(2.6202) Validation Loss: tensor(2.6159)
Step: 1600 Training Loss: tensor(2.5581) Validation Loss: tensor(2.5613)
Step: 1800 Training Loss: tensor(2.5231) Validation Loss: tensor(2.5388)
Step: 2000 Training Loss: tensor(2.5020) Validation Loss: tensor(2.5028)
Step: 2200 Training Loss: tensor(2.4899) Validation Loss: tensor(2.4974)
Step: 2400 Training Loss: tensor(2.4812) Validation Loss: tensor(2.4668)
Step: 2600 Training Loss: tensor(2.4656) Validation Loss: tensor(2.4641)
Step: 2800 Training Loss: tensor(2.4574) Validation Loss: tensor(2.4511)
Step: 3000 Training Loss: tensor(2.4306) Validation Loss: tensor(2.4413)
Step: 3200 Training Loss: tensor(2.4016) Validation Loss: tensor(2.4273)
Step: 3400 Training Loss: tensor(2.3946) Validation Loss: tensor(2.3999)
Step: 3600 Training Loss: tensor(2.3738) Validation Loss: tensor(2.3823)
Step: 3800 Training Loss: tensor(2.3844) Validation Loss: tensor(2.3700)
Step: 4000 Training Loss: tensor(2.3571) Validation Loss: tensor(2.3570)
Step: 4200 Training Loss: tensor(2.3426) Validation Loss: tensor(2.3729)
Step: 4400 Training Loss: tensor(2.3394) Validation Loss: tensor(2.3696)
Step: 4600 Training Loss: tensor(2.3300) Validation Loss: tensor(2.3343)
Step: 4800 Training Loss: tensor(2.3263) Validation Loss: tensor(2.3400)
Step: 4999 Training Loss: tensor(2.3301) Validation Loss: tensor(2.3403)
And thik bry cowd,
This bor thibe sou bobe to:
ave rud my thichanss:
Warth fou qor, ve bar dilth afe aw cramy.
Hhy ar mereou waow somtof is he ce mil nowlincaes ireees, hein latiser lilv the and the non ond wans!
Aplind pealltliser cechiry: tur hais's, why hou to u nor
To thigh sond:
Il wo to thake o Windo wher eiibk we ati dourive we hidend thoo mowr-x'd und kad nonrtf he sor; iris! mef thin inled,
The af Pre?
KIS
INUSH:
Nube!
Giyd is:
ards beace Eghes bidin cou afar tey ir-ltome fronf ve y
CPU times: user 2min 56s, sys: 421 ms, total: 2min 57s
Wall time: 2min 56s
As you can see, the loss actually got worse. Given our new, much more powerful model, this is not what we want. As models get deeper they can become harder to train. Fortunately, there are a few things we can do about that.
First we can implement skip connections, also known as residual connections, which are depicted on the transformer architecture diagram as the black lines that bypass the masked multi-head attention block and feed into the add and norm block. You can also see one bypassing the FFN. The idea originally came from the deep residual networks paper. In our case we add the input back to the output of the block being skipped. Because addition distributes gradients equally, the gradient signal flows through both the skip branch and the block branch. An alternative that is sometimes used is a simple concatenation of the input and output of the skipped block.
When we initialize the network before training we typically want to start off with very low weights for the branches that go through the blocks so the blocks contribute very little to the overall loss. This way the gradient signal makes its way through the entire network. Then during training the network will slowly increase the weights and participation of the blocks.
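One common way to do this, which we won’t add to our toy model, is to shrink the initial weights of the last linear layer in each residual branch, GPT-2 style (GPT-2 scales residual projections by 1/sqrt(N), where N is the number of residual layers). A rough sketch, assuming the residual-branch output layers are named projection_layer like the one we add to MultiHeadAttention shortly:

import math

def shrink_residual_projections(model, num_residual_layers):
    # Scale down the output projections that feed the residual additions so that,
    # at initialization, x + block(x) is approximately just x.
    for name, param in model.named_parameters():
        if name.endswith('projection_layer.weight'):
            torch.nn.init.normal_(param, mean=0.0, std=0.02 / math.sqrt(num_residual_layers))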
Now let’s implement the skip connections in our TransformerBlock module.
class TransformerBlock(nn.Module):
    """Transformer Block: communication followed by computation."""
    def __init__(self, embedding_dim:int=embedding_dim, context_length:int=context_length, num_heads:int=4):
        #embedding_dim: embedding dimension, num_heads: the number of heads that we want
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.head_size = embedding_dim // num_heads
        self.num_heads = num_heads
        self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=self.num_heads, head_size=self.head_size, embedding_dim=embedding_dim, context_length=context_length)
        self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim)

    def forward(self, x):
        # return self.feed_forward_network(self.multi_self_attention_heads_layer(x))
        x = x + self.multi_self_attention_heads_layer(x)  #add the input back to the block output for the skip connection. NEW
        x = x + self.feed_forward_network(x)              #add the input back to the block output for the skip connection. NEW
        return x
We also need to add a projection layer to our MultiHeadAttention module as well as the feed forward network. This is a simple linear layer.
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads:int, head_size:int=head_size, embedding_dim:int=embedding_dim, context_length:int=context_length):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.heads = nn.ModuleList([
            Head(head_size=self.head_size, embedding_dim=self.embedding_dim, context_length=self.context_length)
            for _ in range(self.num_heads)
        ])
        self.projection_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.embedding_dim, bias=True)  #NEW

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.projection_layer(out)  #NEW
        return out
In the FFN, rather than adding a separate projection layer attribute, we’ll simply add an additional linear layer to the existing sequential module. We’re also going to fan out by a factor of 4 and then back in between the two linear layers to add additional computation.
class FeedForwardNetwork(nn.Module):
    """A simple linear network followed by a non-linearity"""
    def __init__(self, embedding_dim:int=embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.ffn = nn.Sequential(
            nn.Linear(in_features=self.embedding_dim, out_features=self.embedding_dim*4),  #Updated
            nn.ReLU(),
            nn.Linear(in_features=self.embedding_dim*4, out_features=self.embedding_dim)  #NEW
        )

    def forward(self, x):
        return self.ffn(x)
Now let’s train the network again to see how we end up.
%%time#| output: truetorch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_size#This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size) and each embedding will... #...be a vector of dimension embedding_dim.self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)self.transformer_blocks = nn.Sequential( TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), )self.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings x =self.transformer_blocks(x) logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)# print('Shape of logits_last_timestep:',logits_last_timestep.shape) #confirming shape#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)# print('Shape of probs:', probs.shape) #confirming shape#Sample from the probs distribution. 
idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.# print('Shape of idx_next:',idx_next.shape,'and contents:',idx_next) #look at the shape and contents of idx_next#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',losses['train'],'Validation Loss:',losses['valid']) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.6328) Validation Loss: tensor(4.6313)
Step: 200 Training Loss: tensor(2.5782) Validation Loss: tensor(2.5969)
Step: 400 Training Loss: tensor(2.4491) Validation Loss: tensor(2.4365)
Step: 600 Training Loss: tensor(2.3560) Validation Loss: tensor(2.3455)
Step: 800 Training Loss: tensor(2.2816) Validation Loss: tensor(2.2922)
Step: 1000 Training Loss: tensor(2.2414) Validation Loss: tensor(2.2609)
Step: 1200 Training Loss: tensor(2.2245) Validation Loss: tensor(2.2473)
Step: 1400 Training Loss: tensor(2.1878) Validation Loss: tensor(2.2126)
Step: 1600 Training Loss: tensor(2.1557) Validation Loss: tensor(2.1949)
Step: 1800 Training Loss: tensor(2.1444) Validation Loss: tensor(2.1952)
Step: 2000 Training Loss: tensor(2.1448) Validation Loss: tensor(2.1569)
Step: 2200 Training Loss: tensor(2.1297) Validation Loss: tensor(2.1741)
Step: 2400 Training Loss: tensor(2.0952) Validation Loss: tensor(2.1558)
Step: 2600 Training Loss: tensor(2.0832) Validation Loss: tensor(2.1392)
Step: 2800 Training Loss: tensor(2.0740) Validation Loss: tensor(2.1216)
Step: 3000 Training Loss: tensor(2.0602) Validation Loss: tensor(2.1131)
Step: 3200 Training Loss: tensor(2.0669) Validation Loss: tensor(2.1428)
Step: 3400 Training Loss: tensor(2.0427) Validation Loss: tensor(2.0881)
Step: 3600 Training Loss: tensor(2.0371) Validation Loss: tensor(2.1069)
Step: 3800 Training Loss: tensor(2.0253) Validation Loss: tensor(2.1075)
Step: 4000 Training Loss: tensor(2.0300) Validation Loss: tensor(2.1037)
Step: 4200 Training Loss: tensor(2.0191) Validation Loss: tensor(2.0958)
Step: 4400 Training Loss: tensor(2.0207) Validation Loss: tensor(2.0896)
Step: 4600 Training Loss: tensor(1.9983) Validation Loss: tensor(2.0888)
Step: 4800 Training Loss: tensor(1.9998) Validation Loss: tensor(2.0826)
Step: 4999 Training Loss: tensor(1.9828) Validation Loss: tensor(2.0681)
KING RIVAR:
I will to lay ble
HAPOMENBELA:
And thruans that hands:
Waither us his vet?
MEXENDEL:
Warch, my feans' to zokn he oursertef it her than welll butes is eesen cin latistlivilv the do kine nown is wace!
lill dise littius, on him speage aissell, yet lord.
I mame, this down'st you, thee killo Wicho dhat evings to thed suis Then, it he poorter,-; the day danter firf sorre;
I therf threy fleront than Pried by of.
HENNG ERLANCE:
YO:
Ard all his a for huin cour ay and your to-chan the!
J
CPU times: user 3min 17s, sys: 490 ms, total: 3min 18s
Wall time: 3min 17s
This looks much better than our last run without the residual connections, which had a loss of Validation Loss: tensor(2.3403), and it also beats the run before that, which had a loss of Validation Loss: tensor(2.2291); our final loss here is Validation Loss: tensor(2.0681). Also, as you can see, the generated text, while still gibberish, is better than in all previous runs.
The second trick that helps with training deep neural nets, in addition to residual connections, is the Norm depicted in the add and norm blocks on the diagram, which in our case is layer norm. Let’s implement and add that.
class LayerNorm:
    def __init__(self, dim:int, eps:float=1e-5):
        self.dim = dim
        self.eps = eps
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):
        x_mean = x.mean(dim=1, keepdim=True)      # layer mean
        x_variance = x.var(dim=1, keepdim=True)   # layer variance
        x_hat = (x - x_mean) / torch.sqrt(x_variance + self.eps)  # normalize to the unit variance
        self.out = self.gamma * x_hat + self.beta
        return self.out

    def parameters(self):
        return [self.gamma, self.beta]
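A quick check (a sketch, assuming embedding_dim is defined as earlier) shows what this does: each row comes out with roughly zero mean and unit standard deviation, and PyTorch's built-in version behaves the same up to minor numerical details.

x = torch.randn(4, embedding_dim)

ln = LayerNorm(dim=embedding_dim)
out = ln(x)
print(out.mean(dim=1), out.std(dim=1))   # per-row means ~0, standard deviations ~1

torch_ln = nn.LayerNorm(normalized_shape=embedding_dim)
print(torch_ln(x).mean(dim=1), torch_ln(x).std(dim=1))  # the built-in version does the same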
Since the original Attention Is All You Need paper came out, it has become more common to apply the norm before each sub-block (pre-norm) rather than after the add as depicted on the transformer architecture diagram (post-norm). We will follow today’s common practice. Also, instead of using the layer norm we just developed, we will use the PyTorch version.
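As a small sketch of the two orderings for a generic sublayer (attention or the FFN):

def post_norm_step(x, sublayer, norm):
    # Original Attention Is All You Need ordering: sublayer, then add, then norm ("Add & Norm").
    return norm(x + sublayer(x))

def pre_norm_step(x, sublayer, norm):
    # The ordering we use below: norm first, then the sublayer, then the residual add.
    return x + sublayer(norm(x))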
class TransformerBlock(nn.Module):
    """Transformer Block: communication followed by computation."""
    def __init__(self, embedding_dim:int=embedding_dim, context_length:int=context_length, num_heads:int=4):
        #embedding_dim: embedding dimension, num_heads: the number of heads that we want
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.head_size = embedding_dim // num_heads
        self.num_heads = num_heads
        self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=self.num_heads, head_size=self.head_size, embedding_dim=embedding_dim, context_length=context_length)
        self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=self.embedding_dim)  #NEW
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=self.embedding_dim)  #NEW

    def forward(self, x):
        x = x + self.multi_self_attention_heads_layer(self.layer_norm_1(x))  #added layer norm. UPDATED
        x = x + self.feed_forward_network(self.layer_norm_2(x))              #added layer norm. UPDATED
        return x
These layer norms are applied to each token embedding so that, at initialization, each one starts off roughly unit gaussian (zero mean and unit variance); because of the trainable gamma and beta parameters, this may change during training.
We also need to add a layer norm after the last transformer block and before the last linear layer. Now let’s train the model again and see how it does.
%%time#| output: truetorch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, head_size:int=head_size):super().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_size#This will be our lookup table for embeddings. We'll have an entry for each token (aka vocab size) and each embedding will... #...be a vector of dimension embedding_dim.self.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)self.transformer_blocks = nn.Sequential( TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), TransformerBlock(embedding_dim=embedding_dim, num_heads=4, context_length=context_length), nn.LayerNorm(embedding_dim), #NEW )self.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings x =self.transformer_blocks(x) logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)# print('Shape of logits_last_timestep:',logits_last_timestep.shape) #confirming shape#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)# print('Shape of probs:', probs.shape) #confirming shape#Sample from the probs distribution. 
idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.# print('Shape of idx_next:',idx_next.shape,'and contents:',idx_next) #look at the shape and contents of idx_next#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',losses['train'],'Validation Loss:',losses['valid']) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: tensor(4.3103) Validation Loss: tensor(4.3100)
Step: 200 Training Loss: tensor(2.6644) Validation Loss: tensor(2.6888)
Step: 400 Training Loss: tensor(2.4590) Validation Loss: tensor(2.4470)
Step: 600 Training Loss: tensor(2.3602) Validation Loss: tensor(2.3479)
Step: 800 Training Loss: tensor(2.2801) Validation Loss: tensor(2.2854)
Step: 1000 Training Loss: tensor(2.2313) Validation Loss: tensor(2.2563)
Step: 1200 Training Loss: tensor(2.2185) Validation Loss: tensor(2.2377)
Step: 1400 Training Loss: tensor(2.1741) Validation Loss: tensor(2.2103)
Step: 1600 Training Loss: tensor(2.1425) Validation Loss: tensor(2.1853)
Step: 1800 Training Loss: tensor(2.1290) Validation Loss: tensor(2.1792)
Step: 2000 Training Loss: tensor(2.1295) Validation Loss: tensor(2.1381)
Step: 2200 Training Loss: tensor(2.1140) Validation Loss: tensor(2.1594)
Step: 2400 Training Loss: tensor(2.0825) Validation Loss: tensor(2.1407)
Step: 2600 Training Loss: tensor(2.0727) Validation Loss: tensor(2.1325)
Step: 2800 Training Loss: tensor(2.0618) Validation Loss: tensor(2.1148)
Step: 3000 Training Loss: tensor(2.0459) Validation Loss: tensor(2.1033)
Step: 3200 Training Loss: tensor(2.0515) Validation Loss: tensor(2.1216)
Step: 3400 Training Loss: tensor(2.0321) Validation Loss: tensor(2.0743)
Step: 3600 Training Loss: tensor(2.0179) Validation Loss: tensor(2.0913)
Step: 3800 Training Loss: tensor(2.0171) Validation Loss: tensor(2.0952)
Step: 4000 Training Loss: tensor(2.0151) Validation Loss: tensor(2.0876)
Step: 4200 Training Loss: tensor(1.9998) Validation Loss: tensor(2.0803)
Step: 4400 Training Loss: tensor(2.0134) Validation Loss: tensor(2.0872)
Step: 4600 Training Loss: tensor(1.9862) Validation Loss: tensor(2.0807)
Step: 4800 Training Loss: tensor(1.9923) Validation Loss: tensor(2.0776)
Step: 4999 Training Loss: tensor(1.9644) Validation Loss: tensor(2.0590)
Will be Roridce.
STAOLOLIO:
KI a set bube to takegry.
MBROKING
My LANGANGENV KINCE:
that dight ane away, my feans' to zormuse off Lroof is here vail; dight,
Whiiree,
You, will is therev the do;
Whe now oir wans!
Al lind teal.
-huch courly speap; airse, why.
Herents norfore elguls;
Protle, demees kneoul-wou what eiich o' maits, rive ceessience poor gier; thume known,
refter so;
Angatt must wity ale of whith Pried by of.
HKING ESTEL:
Prisar adaid the Edwart hiin courchard ny ity to chan the whi
CPU times: user 3min 38s, sys: 526 ms, total: 3min 38s
Wall time: 3min 38s
The loss is now down to Validation Loss: tensor(2.0590) from Validation Loss: tensor(2.0681) during the last run.
%reset -f
#| output: true
8 Scaling Up
Now that we have a fully functioning transformer network, to achieve better performance, we need to scale up. We’ll be doing a bit of code cleanup and refactoring as we scale up the architecture. To make things easier to follow, I’ve reset the kernel so we’ll be re-declaring everything again from scratch.
import torch
import torch.nn as nn
from torch.nn import functional as F
import matplotlib.pyplot as plt

#Hyperparameters
batch_size = 64        #Number of token chunks per batch #UPDATED
context_length = 256   #Length of the token chunks. Andrej called this block size #UPDATED
embedding_dim = 384    #The vector size of the token embeddings. Andrej used n_embed as the variable name. #UPDATED
head_size1 = 16        #Self attention head size
num_layers = 6         #Number of transformer block layers # NEW
num_heads = 6          # NEW
learning_rate = 3e-4   #UPDATED
dropout = 0.2          #NEW
max_iters = 5000       #Number of training iterations or steps.
eval_interval = 500    #Number of steps between evaluating the validation set to see how our validation loss is doing. #UPDATED
eval_iters = 200       #Number of steps to do on the validation set per each interval. We do more than 1 to get a more accurate overall valid loss
device = 'cuda' if torch.cuda.is_available() else 'cpu'  #Instead of using the cpu, we'll use the GPU if it's available.
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)

#Dataset
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()
vocab = sorted(list(set(text)))  #Called chars in the video, but vocab is a more generic term. Both are correct.
vocab_size = len(vocab)
char2idx = {char: idx for idx, char in enumerate(vocab)}
idx2char = {idx: char for char, idx in char2idx.items()}
encode = lambda x: [char2idx[char] for char in x]
decode = lambda idxs: ''.join([idx2char[idx] for idx in idxs])
tokenized_text = torch.tensor(encode(text), dtype=torch.long)

#Train / Valid split.
train_test_split_idx = int(len(tokenized_text) * 0.9)
train_data = tokenized_text[:train_test_split_idx]
valid_data = tokenized_text[train_test_split_idx:]
def plot_losses(losses):
    train_losses = [o['train'] for o in losses if o.get('train') is not None]
    valid_losses = [o['valid'] for o in losses if o.get('valid') is not None]
    plt.plot(train_losses, label='Training Loss')
    plt.plot(valid_losses, label='Validation Loss')
    plt.ylabel('Loss')
    plt.title('Losses')
    plt.legend()
    plt.show()
def get_batch(split:str, batch_size:int=batch_size, context_length:int=context_length):
    #Function to get a batch of data from the train or valid dataset
    data = train_data if split == 'train' else valid_data
    idxs = torch.randint(low=0, high=len(data)-context_length, size=(batch_size,))
    x = torch.stack([data[idx:idx+context_length] for idx in idxs])
    y = torch.stack([data[idx+1:idx+context_length+1] for idx in idxs])
    x, y = x.to(device), y.to(device)  #Send data to the GPU if available
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'valid']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            x_batch, y_batch = get_batch(split)
            logits, loss = model(x_batch, y_batch)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
Adding dropout to the feed forward network:
class FeedForwardNetwork(nn.Module):
    """A simple linear network followed by a non-linearity"""
    def __init__(self, embedding_dim:int=embedding_dim, dropout:float=dropout):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.ffn = nn.Sequential(
            nn.Linear(in_features=self.embedding_dim, out_features=self.embedding_dim*4),
            nn.ReLU(),
            nn.Linear(in_features=self.embedding_dim*4, out_features=self.embedding_dim),
            nn.Dropout(dropout),  #NEW
        )

    def forward(self, x):
        return self.ffn(x)
Adding dropout to the attention head, applied to the attention weights after the softmax:
class Head(nn.Module):
    """ one head of self attention """
    def __init__(self, head_size:int, embedding_dim:int=embedding_dim, context_length:int=context_length, dropout:float=dropout):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.head_size = head_size
        self.context_length = context_length
        self.dropout = dropout
        self.key_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.query_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.value_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones((self.context_length, self.context_length))))
        self.dropout_layer = nn.Dropout(dropout)  # NEW

    def forward(self, x):
        # input of size (batch, time-step, channels)
        # output of size (batch, time-step, head size)
        B,T,C = x.shape
        assert T <= self.context_length
        assert C == self.embedding_dim
        q = self.query_layer(x)  #(B,T,head_size) (batch size, context length, head_size)
        k = self.key_layer(x)    #(B,T,head_size)
        v = self.value_layer(x)  #(B,T,head_size)
        #compute scores based on affinities
        weights = (q @ k.transpose(-2,-1)) * self.head_size**-0.5  # (B,T,head_size) @ (B,head_size,T) -> (B,T,T) #FIXED: **-0.5 is 1/sqrt(), so we need to multiply, not divide
        weights = weights.masked_fill(self.tril[:T,:T] == 0, float('-inf'))  #(B,T,T)
        weights = F.softmax(input=weights, dim=-1)  #(B,T,T)
        weights = self.dropout_layer(weights)  # NEW
        #perform weighted aggregation of the values
        out = weights @ v  # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
        return out
Adding dropout to the multi-head attention module, applied after the projection layer:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads:int, head_size:int, embedding_dim:int=embedding_dim, context_length:int=context_length, dropout:float=dropout):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.dropout = dropout
        self.heads = nn.ModuleList([
            Head(head_size=self.head_size, embedding_dim=self.embedding_dim, context_length=self.context_length)
            for _ in range(self.num_heads)
        ])
        self.projection_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.embedding_dim, bias=True)
        self.dropout_layer = nn.Dropout(dropout)  # NEW

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.projection_layer(out)
        out = self.dropout_layer(out)  # NEW
        return out
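Dropout randomly zeroes a fraction of activations during training (rescaling the survivors by 1/(1-p)) and is a no-op at evaluation time, which is why estimate_loss switches the model between eval() and train(). A small sketch of that behavior:

drop = nn.Dropout(p=0.2)
x = torch.ones(8)

drop.train()
print(drop(x))   # roughly 20% of the values zeroed, the rest scaled up to 1.25

drop.eval()
print(drop(x))   # identity: all ones pass through unchanged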
class TransformerBlock(nn.Module):
    """Transformer Block: communication followed by computation."""
    def __init__(self, num_heads:int, embedding_dim:int=embedding_dim, context_length:int=context_length, dropout:float=dropout):  #UPDATED
        #embedding_dim: embedding dimension, num_heads: the number of heads that we want
        super().__init__()
        self.embedding_dim = embedding_dim
        self.context_length = context_length
        self.head_size = embedding_dim // num_heads
        self.num_heads = num_heads
        self.dropout = dropout  # NEW
        self.multi_self_attention_heads_layer = MultiHeadAttention(num_heads=num_heads, head_size=self.head_size, embedding_dim=embedding_dim, context_length=context_length, dropout=dropout)  #UPDATED
        self.feed_forward_network = FeedForwardNetwork(embedding_dim=self.embedding_dim, dropout=dropout)
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=self.embedding_dim)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=self.embedding_dim)

    def forward(self, x):
        x = x + self.multi_self_attention_heads_layer(self.layer_norm_1(x))
        x = x + self.feed_forward_network(self.layer_norm_2(x))
        return x
%%time#| output: truetorch.manual_seed(TORCH_SEED)class TransformerLanguageModel(nn.Module):def__init__(self, head_size:int, vocab_size:int=vocab_size, embedding_dim:int=embedding_dim, context_length:int=context_length, num_layers:int=num_layers, dropout:float=dropout, num_heads:int=num_heads): #UPDATEDsuper().__init__()self.vocab_size = vocab_sizeself.embedding_dim = embedding_dimself.context_length = context_lengthself.head_size = head_sizeself.num_layers = num_layers #NEWself.dropout = dropoutself.token_embedding_table = nn.Embedding(num_embeddings=self.vocab_size, embedding_dim=self.embedding_dim)self.token_position_embedding_table = nn.Embedding(num_embeddings=self.context_length, embedding_dim=self.embedding_dim)self.transformer_blocks = nn.Sequential(*([ TransformerBlock(embedding_dim=embedding_dim, num_heads=num_heads, context_length=context_length, dropout=self.dropout) for _ inrange(self.num_layers)]+[ nn.LayerNorm(embedding_dim) ])) #UPDATEDself.language_model_head_linear_layer = nn.Linear(in_features=self.embedding_dim, out_features=self.vocab_size)def forward(self, idx, targets=None):#Both idx and targets are (B,T) Batch x Time array of integers B,T = idx.shape token_embeddings =self.token_embedding_table(idx) #(B,T,C) Batch, Time, Channel token_position_embeddings =self.token_position_embedding_table(torch.arange(T, device=device)) #(T,C) x = token_embeddings + token_position_embeddings x =self.transformer_blocks(x) logits =self.language_model_head_linear_layer(x) #(B,T,C) Where C is now token logits of size vocab_sizeif targets isnotNone: B,T,C = logits.shape logits_reshaped = logits.view(B*T,C) targets_reshaped = targets.view(B*T) loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)else: loss=Nonereturn logits, lossdef generate(self, idx, max_new_tokens):#idx is (B,T) array of indicies in the current contextfor _ inrange(max_new_tokens):#Crop idx to the max size of our positional embeddings table idx_crop = idx[:,-self.context_length:]#Get predictions logits, loss =self(idx_crop)#Get the last time step from logits where the dimensions of the logits are (B,T,C) logits_last_timestep = logits[:,-1,:] #Becomes (B,C)#Apply softmax to get probabilities probs = F.softmax(input=logits_last_timestep, dim=-1) #(B,C)#Sample from the probs distribution. idx_next = torch.multinomial(input=probs, num_samples=1) #(B,1) Returns (B,idxs) where idxs are random integer indicies.#Append the sampled indexes idx_next to idx idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)return idxmodel = TransformerLanguageModel(head_size=head_size1, vocab_size=vocab_size, embedding_dim=embedding_dim, context_length=context_length)model = model.to(device)optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)tracked_losses =list()for step inrange(max_iters):if step % eval_iters ==0or step == max_iters-1: losses = estimate_loss() tracked_losses.append(losses)print('Step:',step,'Training Loss:',round(losses['train'].item(),3),'Validation Loss:',round(losses['valid'].item(),3)) xb,yb = get_batch('train') logits, loss = model(xb,yb) optimizer.zero_grad(set_to_none=True) loss.backward() optimizer.step()plot_losses(tracked_losses)context = torch.zeros((1,1), dtype=torch.long, device=device)print(decode(model.generate(context,max_new_tokens=500)[0].tolist()))
Step: 0 Training Loss: 4.285 Validation Loss: 4.282
Step: 200 Training Loss: 2.393 Validation Loss: 2.414
Step: 400 Training Loss: 2.021 Validation Loss: 2.097
Step: 600 Training Loss: 1.774 Validation Loss: 1.912
Step: 800 Training Loss: 1.632 Validation Loss: 1.808
Step: 1000 Training Loss: 1.533 Validation Loss: 1.719
Step: 1200 Training Loss: 1.465 Validation Loss: 1.671
Step: 1400 Training Loss: 1.409 Validation Loss: 1.611
Step: 1600 Training Loss: 1.368 Validation Loss: 1.596
Step: 1800 Training Loss: 1.337 Validation Loss: 1.567
Step: 2000 Training Loss: 1.309 Validation Loss: 1.55
Step: 2200 Training Loss: 1.281 Validation Loss: 1.523
Step: 2400 Training Loss: 1.26 Validation Loss: 1.514
Step: 2600 Training Loss: 1.239 Validation Loss: 1.503
Step: 2800 Training Loss: 1.227 Validation Loss: 1.509
Step: 3000 Training Loss: 1.203 Validation Loss: 1.497
Step: 3200 Training Loss: 1.185 Validation Loss: 1.492
Step: 3400 Training Loss: 1.169 Validation Loss: 1.487
Step: 3600 Training Loss: 1.151 Validation Loss: 1.484
Step: 3800 Training Loss: 1.138 Validation Loss: 1.485
Step: 4000 Training Loss: 1.122 Validation Loss: 1.477
Step: 4200 Training Loss: 1.105 Validation Loss: 1.479
Step: 4400 Training Loss: 1.09 Validation Loss: 1.487
Step: 4600 Training Loss: 1.077 Validation Loss: 1.487
Step: 4800 Training Loss: 1.066 Validation Loss: 1.492
Step: 4999 Training Loss: 1.048 Validation Loss: 1.494
Thy women divorcuse and me, whereof, if you live,
Here overthrives be gentle climber, thy ball;
My cripation, Tybalt of face and my hand,
That is our hately made for requends:
Your conquers, my suffice shive, to my service doubt
To whom I am life and tafe malice thus
Ere you all not with this householy persive of true:
Your mirrous arm'd, when, they to say it at
The world's love not takes me drid. Coriolanus: all
they not to the remedier's small flap
To liese as drively in answer'd: any whom
Hav
CPU times: user 13min 56s, sys: 2min 50s, total: 16min 47s
Wall time: 16min 45s
The results are starting to look pretty decent. The loss has dramatically improved. Scaling up the network has made a big difference.
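If you are curious just how much bigger the scaled-up network is, a quick optional check (my own addition, not part of the original training cell) is to count the model's trainable parameters:
#Count the trainable parameters of the model trained above (my addition).
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable parameters:', f'{num_params:,}')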
Debugging Models Aside: While trying to train the model I realized I had made a mistake in the code. I tried to train it several times, but the loss would not drop below about 2.4. I went back through my code and nothing obvious stood out as wrong. After an hour of scouring the code I finally found the issue. It was a subtle change, but it made all the difference. This was my code before the fix:
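Reconstructed from the description below, the buggy line looked roughly like the first line here, with the fix underneath; the variable names weights, q and k are my best guess from the earlier attention cells, so treat this as a sketch rather than the literal code:
#What the bug looked like (reconstructed; weights, q, k are assumed names):
weights = q @ k.transpose(-2, -1) / self.head_size**-0.5 #BUG: dividing by 1/sqrt(head_size) multiplies the scores by sqrt(head_size)
#The fix:
weights = q @ k.transpose(-2, -1) * self.head_size**-0.5 #FIX: multiplying by 1/sqrt(head_size) scales the scores down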
When we first implemented dividing the weights by the square root of the head size to fix the issue with large numbers passing through softmax, I had been using .../math.sqrt(head_size). To make things more consistent with Andrej’s code, and to remove the requirement to import math, I switched the implementation over to the way he was doing it: raising head_size to a power of 0.5, which is equivalent to the square root. What I had missed was that his exponent was -0.5, and head_size**-0.5 is equivalent to 1/sqrt(head_size), so instead of dividing by self.head_size**-0.5 I should have been multiplying by it. This can be one of the big challenges in deep learning. Often when you make a mistake no error is thrown, it just doesn’t work. Sometimes a mistake affects the results by a little and other times by a lot, which is what happened in my case. Thankfully I had benchmarks to compare my results against, so it was clear that I had an implementation issue, not that the model I was using was incapable of getting better results.
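A few lines in the Python REPL make the difference concrete (head_size = 16 is just an illustrative value, not necessarily the one used in this notebook):
import math
head_size = 16 #illustrative value only
print(head_size**-0.5) #0.25, i.e. 1/sqrt(16)
print(8.0 * head_size**-0.5) #2.0: an example score scaled down (correct)
print(8.0 / head_size**-0.5) #32.0: the same score scaled up (the bug)
print(math.isclose(head_size**-0.5, 1 / math.sqrt(head_size))) #True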
9 Conclusion
In this notebook we have built a transformer model based on the Attention Is All You Need paper, following along with Andrej Karpathy’s fantastic YouTube video: Let’s build GPT: from scratch, in code, spelled out. While building out the transformer we tried to develop an intuition for what makes it work. I hope you found the format of this notebook useful, adding and modifying the code as we went, so that you could follow along, run the code, and see the output at each step.