generalize makemore into other types of language models, and add bigram LM and an MLP LM

morelm
Andrej Karpathy 2022-08-21 17:53:52 -07:00
parent 50617fa75d
commit 6694b67d37
1 changed files with 133 additions and 41 deletions

View File

@ -28,16 +28,19 @@ from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
# -----------------------------------------------------------------------------
# GPT model definition
@dataclass
class GPTConfig:
# size of the model
class ModelConfig:
block_size: int = None # length of the input sequences of integers
vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
# parameters below control the sizes of each model slightly differently
n_layer: int = 4
n_head: int = 4
n_embd: int = 64
vocab_size: int = None
block_size: int = None
n_embd2: int = 64
n_head: int = 4
# -----------------------------------------------------------------------------
# Transformer Language Model (*exactly* as used in GPT-2)
class NewGELU(nn.Module):
"""
@ -127,6 +130,9 @@ class GPT(nn.Module):
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))
def get_block_size(self):
return self.block_size
def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
@ -149,45 +155,124 @@ class GPT(nn.Module):
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
# -----------------------------------------------------------------------------
# MLP language model
return idx
class MLP(nn.Module):
"""
takes the previous block_size tokens, encodes them with a lookup table,
concatenates the vectors and predicts the next token with an MLP.
Reference:
Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
"""
def __init__(self, config):
super().__init__()
self.block_size = config.block_size
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd) # token embeddings table
# +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
# before the beginning of the input sequence
self.mlp = nn.Sequential(
nn.Linear(self.block_size * config.n_embd, config.n_embd2), # TODO: option to vary this
nn.Tanh(),
nn.Linear(config.n_embd2, self.vocab_size)
)
def get_block_size(self):
return self.block_size
def forward(self, idx, targets=None):
# gather the word embeddings of the previous 3 words
embs = []
for k in range(self.block_size):
tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
idx = torch.roll(idx, 1, 1)
idx[:, 0] = self.vocab_size # special <BLANK> token
embs.append(tok_emb)
# concat all of the embeddings together and pass through an MLP
x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
logits = self.mlp(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return logits, loss
# -----------------------------------------------------------------------------
# Bigram language model
class Bigram(nn.Module):
"""
Bigram Language Model 'neural net', simply a lookup table of logits for the
next character given a previous character.
"""
def __init__(self, config):
super().__init__()
n = config.vocab_size
self.logits = nn.Parameter(torch.zeros((n, n)))
def get_block_size(self):
return 1 # this model only needs one previous character to predict the next
def forward(self, idx, targets=None):
# 'forward pass', lol
logits = self.logits[idx]
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return logits, loss
# -----------------------------------------------------------------------------
# helper functions for evaluating and sampling from the model
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
block_size = model.get_block_size()
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = model(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
def print_samples(num=10):
""" samples from the model and pretty prints the decoded samples """
X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
top_k = args.top_k if args.top_k != -1 else None
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
X_samp = model.generate(X_init, steps, top_k=top_k, do_sample=True).to('cpu')
X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu')
train_samples, test_samples, new_samples = [], [], []
for i in range(X_samp.size(0)):
# get the i'th row of sampled integers, as python list
@ -333,9 +418,11 @@ if __name__ == '__main__':
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
parser.add_argument('--n-head', type=int, default=4, help="number of heads in the transformer")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the transformer")
parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model")
# optimization
parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
@ -356,9 +443,14 @@ if __name__ == '__main__':
print(f"dataset determined that: {vocab_size=}, {block_size=}")
# init model
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)
model = GPT(config)
config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
n_layer=args.n_layer, n_head=args.n_head,
n_embd=args.n_embd, n_embd2=args.n_embd2)
model = {
'transformer': GPT,
'bigram': Bigram,
'mlp': MLP,
}[args.type](config)
model.to(args.device)
print(f"model #params: {sum(p.numel() for p in model.parameters())}")
if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming