generalize makemore into other types of language models, and add bigram LM and an MLP LM
parent
50617fa75d
commit
6694b67d37
174
makemore.py
174
makemore.py
|
@ -28,16 +28,19 @@ from torch.utils.data.dataloader import DataLoader
|
|||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GPT model definition
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
# size of the model
|
||||
class ModelConfig:
|
||||
block_size: int = None # length of the input sequences of integers
|
||||
vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
|
||||
# parameters below control the sizes of each model slightly differently
|
||||
n_layer: int = 4
|
||||
n_head: int = 4
|
||||
n_embd: int = 64
|
||||
vocab_size: int = None
|
||||
block_size: int = None
|
||||
n_embd2: int = 64
|
||||
n_head: int = 4
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Transformer Language Model (*exactly* as used in GPT-2)
|
||||
|
||||
class NewGELU(nn.Module):
|
||||
"""
|
||||
|
@ -127,6 +130,9 @@ class GPT(nn.Module):
|
|||
n_params = sum(p.numel() for p in self.transformer.parameters())
|
||||
print("number of parameters: %.2fM" % (n_params/1e6,))
|
||||
|
||||
def get_block_size(self):
|
||||
return self.block_size
|
||||
|
||||
def forward(self, idx, targets=None):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
|
@ -149,45 +155,124 @@ class GPT(nn.Module):
|
|||
|
||||
return logits, loss
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
|
||||
"""
|
||||
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
|
||||
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
||||
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
||||
"""
|
||||
for _ in range(max_new_tokens):
|
||||
# if the sequence context is growing too long we must crop it at block_size
|
||||
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
|
||||
# forward the model to get the logits for the index in the sequence
|
||||
logits, _ = self(idx_cond)
|
||||
# pluck the logits at the final step and scale by desired temperature
|
||||
logits = logits[:, -1, :] / temperature
|
||||
# optionally crop the logits to only the top k options
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, top_k)
|
||||
logits[logits < v[:, [-1]]] = -float('Inf')
|
||||
# apply softmax to convert logits to (normalized) probabilities
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
# either sample from the distribution or take the most likely element
|
||||
if do_sample:
|
||||
idx_next = torch.multinomial(probs, num_samples=1)
|
||||
else:
|
||||
_, idx_next = torch.topk(probs, k=1, dim=-1)
|
||||
# append sampled index to the running sequence and continue
|
||||
idx = torch.cat((idx, idx_next), dim=1)
|
||||
# -----------------------------------------------------------------------------
|
||||
# MLP language model
|
||||
|
||||
return idx
|
||||
class MLP(nn.Module):
    """
    Feed-forward language model over a fixed window of preceding tokens.

    Every position looks up the embeddings of itself and the tokens before it
    (block_size tokens in total), joins those vectors end to end, and feeds the
    result through a two-layer tanh network that emits next-token logits.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        # The embedding table has one extra row (index == vocab_size) reserved
        # for a special <BLANK> token, used wherever the context window would
        # reach before the beginning of the input sequence.
        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd)
        hidden_in = self.block_size * config.n_embd
        self.mlp = nn.Sequential(
            nn.Linear(hidden_in, config.n_embd2),  # TODO: option to vary this
            nn.Tanh(),
            nn.Linear(config.n_embd2, self.vocab_size),
        )

    def get_block_size(self):
        """Number of context tokens this model consumes per prediction."""
        return self.block_size

    def forward(self, idx, targets=None):
        """
        Compute next-token logits for idx and, when targets are supplied, the
        cross-entropy loss (target entries equal to -1 are ignored).

        Returns a (logits, loss) pair; loss is None when targets is None.
        """
        # Collect one embedding tensor per context offset: offset 0 is the
        # current token, offset 1 the token before it, and so on.
        window = []
        shifted = idx
        for _ in range(self.block_size):
            window.append(self.wte(shifted))  # (b, t, n_embd)
            # torch.roll returns a new tensor, so the caller's idx is untouched;
            # the slot that wrapped around is overwritten with <BLANK>.
            shifted = torch.roll(shifted, shifts=1, dims=1)
            shifted[:, 0] = self.vocab_size

        # Join the per-offset embeddings along the feature axis and run the MLP.
        features = torch.cat(window, -1)  # (b, t, n_embd * block_size)
        logits = self.mlp(features)

        if targets is None:
            return logits, None

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Bigram language model
|
||||
|
||||
class Bigram(nn.Module):
    """
    Bigram language model: a single learnable table that maps the previous
    token directly to the logits of the next token. There is no other
    computation; the 'network' is just a row lookup.
    """

    def __init__(self, config):
        super().__init__()
        vocab = config.vocab_size
        # row i holds the next-token logits given previous token i; starts at zero
        self.logits = nn.Parameter(torch.zeros((vocab, vocab)))

    def get_block_size(self):
        """Only a single preceding token is needed to predict the next one."""
        return 1

    def forward(self, idx, targets=None):
        """
        Look up next-token logits for every index in idx and, when targets are
        supplied, the cross-entropy loss (target entries equal to -1 are ignored).

        Returns a (logits, loss) pair; loss is None when targets is None.
        """
        # the whole forward pass is an indexing operation into the table
        logits = self.logits[idx]

        if targets is None:
            return logits, None

        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# helper functions for evaluating and sampling from the model
|
||||
|
||||
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.

    Args:
        model: callable returning a (logits, loss) pair for a batch of indices,
            and exposing get_block_size() for its maximum context length.
        idx: conditioning indices, shape (b, t).
        max_new_tokens: number of tokens to append to every row.
        temperature: divisor applied to the final-step logits before softmax.
        do_sample: if True draw from the distribution, otherwise take the argmax.
        top_k: if not None, restrict the choice to the k highest-scoring tokens.
            Values larger than the vocabulary are clamped to the vocabulary size.

    Returns:
        Tensor of shape (b, t + max_new_tokens): idx with the new tokens appended.
    """
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            # torch.topk raises if k exceeds the dimension size, so clamp it;
            # asking for more options than exist simply keeps all of them
            k = min(top_k, logits.size(-1))
            kth_best, _ = torch.topk(logits, k)
            logits = logits.masked_fill(logits < kth_best[:, [-1]], float('-inf'))
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
|
||||
|
||||
def print_samples(num=10):
|
||||
""" samples from the model and pretty prints the decoded samples """
|
||||
X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
|
||||
top_k = args.top_k if args.top_k != -1 else None
|
||||
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
|
||||
X_samp = model.generate(X_init, steps, top_k=top_k, do_sample=True).to('cpu')
|
||||
X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu')
|
||||
train_samples, test_samples, new_samples = [], [], []
|
||||
for i in range(X_samp.size(0)):
|
||||
# get the i'th row of sampled integers, as python list
|
||||
|
@ -333,9 +418,11 @@ if __name__ == '__main__':
|
|||
# sampling
|
||||
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
|
||||
# model
|
||||
parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
|
||||
parser.add_argument('--n-head', type=int, default=4, help="number of heads in the transformer")
|
||||
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the transformer")
|
||||
parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
|
||||
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
|
||||
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
|
||||
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
|
||||
parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model")
|
||||
# optimization
|
||||
parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
|
||||
parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
|
||||
|
@ -356,9 +443,14 @@ if __name__ == '__main__':
|
|||
print(f"dataset determined that: {vocab_size=}, {block_size=}")
|
||||
|
||||
# init model
|
||||
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
|
||||
n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)
|
||||
model = GPT(config)
|
||||
config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
|
||||
n_layer=args.n_layer, n_head=args.n_head,
|
||||
n_embd=args.n_embd, n_embd2=args.n_embd2)
|
||||
model = {
|
||||
'transformer': GPT,
|
||||
'bigram': Bigram,
|
||||
'mlp': MLP,
|
||||
}[args.type](config)
|
||||
model.to(args.device)
|
||||
print(f"model #params: {sum(p.numel() for p in model.parameters())}")
|
||||
if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
|
||||
|
|
Loading…
Reference in New Issue