"""
|
|
you give this script some words (one per line) and it will generate more things like it.
|
|
uses super state of the art Transformer AI tech
|
|
this code is intended to be super hackable. tune it to your needs.
|
|
"""

import os
import sys
import time
import math
import argparse
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter

# -----------------------------------------------------------------------------
# GPT (PyTorch) model definition

@dataclass
class GPTConfig:
    # size of the model
    n_layer: int = 4
    n_head: int = 4
    n_embd: int = 64
    vocab_size: int = None
    block_size: int = None
    # regularization
    embd_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    attn_pdrop: float = 0.1

@dataclass
class TrainConfig:
    # optimization parameters
    learning_rate: float = 5e-4
    weight_decay: float = 0.1 # only applied on matmul weights
    betas: Tuple[float, float] = (0.9, 0.99)

class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    It is possible to use torch.nn.MultiheadAttention here but I am including an
    explicit implementation here to show that there is nothing too scary here.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # causal mask to ensure that attention is only applied to the left in the input sequence
        self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        return y

class Block(nn.Module):
    """ an unassuming Transformer block """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)

        print("number of parameters: %d" % sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        return self.block_size

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, GPT):
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
                                                           % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, input tensor sequence is longer than model block_size."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss

# -----------------------------------------------------------------------------
# helper functions for evaluating and sampling from the model

def top_k_logits(logits, k):
    """
    takes logits (N,D) and returns logits (N,D), but in each row only
    the top-k are kept and the rest of the entries are set to -inf.
    """
    v, ix = torch.topk(logits, k, dim=-1)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.inference_mode()
def sample(model, x, steps, temperature=1.0, top_k=None):
    """
    take a conditioning sequence of indices in x (b,t) and predict next tokens
    in the sequence, feeding the predictions back into the model each step.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        # crop the context, if necessary
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
        # feed the context into the model to get logits at each step
        logits, _ = model(x_cond)
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample the next index from the distribution
        ix = torch.multinomial(probs, num_samples=1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)
    model.train()
    return x

def print_samples(num=10):
    """ samples from the model and pretty prints the decoded samples """
    X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
    top_k = args.top_k if args.top_k != -1 else None
    steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
    X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
    unique_samples = []
    had_samples = []
    for i in range(X_samp.size(0)):
        # get the i'th row of sampled integers, as python list
        row = X_samp[i, 1:].tolist() # note: we need to crop out the first <START> token
        # token 0 is the <STOP> token, so we crop the output sequence at that point
        crop_index = row.index(0) if 0 in row else len(row)
        row = row[:crop_index]
        word_samp = train_dataset.decode(row)
        # separately track samples that we have and have not seen before
        word_have = train_dataset.contains(word_samp) or test_dataset.contains(word_samp)
        sample_list = had_samples if word_have else unique_samples
        sample_list.append(word_samp)

    print('-'*80)
    print(f'{len(had_samples)} samples that were found in the input dataset:')
    for word in had_samples:
        print(word)
    print(f'{len(unique_samples)} samples that were NOT found in the input dataset:')
    for word in unique_samples:
        print(word)
    print('-'*80)

@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None):
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    for i, batch in enumerate(loader):
        batch = [t.to(args.device) for t in batch]
        X, Y = batch
        logits, loss = model(X, Y)
        losses.append(loss.item())
        if max_batches is not None and i >= max_batches:
            break
    mean_loss = torch.tensor(losses).mean().item()
    model.train() # reset model back to training mode
    return mean_loss

# -----------------------------------------------------------------------------
# helper functions for creating the training and test Datasets that emit words

class CharDataset(Dataset):

    def __init__(self, words, chars, max_word_length):
        self.words = words
        self.chars = chars
        self.max_word_length = max_word_length
        self.stoi = {ch:i+1 for i,ch in enumerate(chars)}
        self.itos = {i:s for s,i in self.stoi.items()} # inverse mapping

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        return word in self.words

    def get_vocab_size(self):
        return len(self.chars) + 1 # all the possible characters and special 0 token

    def get_output_length(self):
        return self.max_word_length + 1 # <START> token followed by words

    def encode(self, word):
        ix = torch.tensor([self.stoi[w] for w in word], dtype=torch.long)
        return ix

    def decode(self, ix):
        word = ''.join(self.itos[i] for i in ix)
        return word

    def __getitem__(self, idx):
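        # worked example (assuming 'a' -> 1 and 'b' -> 2 in self.stoi and max_word_length = 3):
        # the word "ab" encodes to ix = [1, 2], which yields
        #   x = [0, 1, 2, 0]   <START> token 0, then the word, then 0-padding
        #   y = [1, 2, 0, -1]  the word, the <STOP> token 0, then -1 to mask the padded positions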
        word = self.words[idx]
        ix = self.encode(word)
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        x[1:1+len(ix)] = ix
        y[:len(ix)] = ix
        y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations
        return x, y

def create_datasets(input_file):

    # preprocessing of the input text file
    with open(input_file, 'r') as f:
        data = f.read()
    words = data.splitlines()
    words = [w.strip() for w in words] # get rid of any leading or trailing white space
    words = [w for w in words if w] # get rid of any empty strings
    chars = sorted(list(set(''.join(words)))) # all the possible characters
    max_word_length = max(len(w) for w in words)
    print(f"number of examples in the dataset: {len(words)}")
    print(f"max word length: {max_word_length}")
    print(f"number of unique characters in the vocabulary: {len(chars)}")
    print("vocabulary:")
    print(''.join(chars))

    # partition the input data into training and test sets
    test_set_size = min(1000, int(len(words) * 0.1)) # 10% of the dataset, capped at 1000 examples
    rp = torch.randperm(len(words)).tolist()
    train_words = [words[i] for i in rp[:-test_set_size]]
    test_words = [words[i] for i in rp[-test_set_size:]]
    print(f"split up the dataset into {len(train_words)} training examples and {len(test_words)} test examples")

    # wrap in dataset objects
    train_dataset = CharDataset(train_words, chars, max_word_length)
    test_dataset = CharDataset(test_words, chars, max_word_length)

    return train_dataset, test_dataset

# -----------------------------------------------------------------------------
if __name__ == '__main__':

    # parse command line args
    parser = argparse.ArgumentParser(description="Make More")
    # system/input/output
    parser.add_argument('--input-file', '-i', type=str, default='input.txt', help="input file with things one per line")
    parser.add_argument('--work-dir', '-o', type=str, default='out', help="output working directory")
    parser.add_argument('--resume', action='store_true', help="when this flag is used, we will resume optimization from existing model in the workdir")
    parser.add_argument('--num-workers', '-n', type=int, default=1, help="number of data workers for both train/test")
    parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, e.g. cpu|cuda|mps")
    parser.add_argument('--seed', type=int, default=1337, help="seed")
    # sampling
    parser.add_argument('--sample-only', action='store_true', help="just sample from the model and quit, don't train")
    parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
    # model
    parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
    parser.add_argument('--n-head', type=int, default=4, help="number of heads in the transformer")
    parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the transformer")
    # optimization
    parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
    parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
    parser.add_argument('--dropout', '-d', type=float, default=0.1, help="dropout rate")
    parser.add_argument('--weight-decay', '-w', type=float, default=0.1, help="weight decay")
    args = parser.parse_args()
    print(vars(args))

    # system inits
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.use_deterministic_algorithms(True)
    os.makedirs(args.work_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.work_dir)

    # init datasets
    train_dataset, test_dataset = create_datasets(args.input_file)

    # init model
    config = GPTConfig(vocab_size=train_dataset.get_vocab_size(), block_size=train_dataset.get_output_length(),
                       n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
                       embd_pdrop=args.dropout, attn_pdrop=args.dropout, resid_pdrop=args.dropout)
    model = GPT(config)
    model.to(args.device)
    print(f"model #params: {sum(p.numel() for p in model.parameters())}")
    if args.resume or args.sample_only: # note: if we sample-only then we also assume we're resuming
        print("resuming from existing model in the workdir")
        model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt')))
    if args.sample_only:
        print_samples(num=50)
        sys.exit()

    # init optimizer
    train_config = TrainConfig(learning_rate=args.learning_rate, weight_decay=args.weight_decay)
    optimizer = model.configure_optimizers(train_config)

    # init dataloader
    train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.num_workers, sampler=train_sampler)
    data_iter = iter(train_loader)

    # training loop
    best_loss = None
    step = 0
    while True:

        t0 = time.time()
        # fetch the next batch and reset the dataloader every epoch as necessary
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)
        batch = [t.to(args.device) for t in batch]
        X, Y = batch

        # feed into the model
        logits, loss = model(X, Y)

        # calculate the gradient, clip it, update the weights
        model.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if args.device.startswith('cuda'):
            torch.cuda.synchronize() # wait for all kernels to finish so the step time below is accurate
        t1 = time.time()

        # logging
        if step % 10 == 0:
            print(f"step {step} | loss {loss.item():.4f} | step time {(t1-t0)*1000:.2f}ms")

        # evaluate the model
        if step > 0 and step % 500 == 0:
            train_loss = evaluate(model, train_dataset, batch_size=100, max_batches=10)
            test_loss = evaluate(model, test_dataset, batch_size=100, max_batches=10)
            writer.add_scalar("Loss/train", train_loss, step)
            writer.add_scalar("Loss/test", test_loss, step)
            writer.flush()
            print(f"step {step} train loss: {train_loss} test loss: {test_loss}")
            # save the model to disk if it has improved
            if best_loss is None or test_loss < best_loss:
                out_path = os.path.join(args.work_dir, "model.pt")
                print(f"test loss {test_loss} is the best so far, saving model to {out_path}")
                torch.save(model.state_dict(), out_path)
                best_loss = test_loss

        # sample from the model
        if step > 0 and step % 200 == 0:
            print_samples(num=10)

        step += 1
|