big refactor to make easier and api agree with mingpt more

morelm
Andrej Karpathy 2022-08-19 22:29:58 +00:00
parent 054568ec24
commit 013af92770
1 changed files with 117 additions and 151 deletions

View File

@ -2,6 +2,14 @@
you give this script some words (one per line) and it will generate more things like it.
uses super state of the art Transformer AI tech
this code is intended to be super hackable. tune it to your needs.
Changes from minGPT:
- I removed the from_pretrained function where we init with GPT2 weights
- I removed dropout layers because the models we train here are small,
it's not necessary to understand at this stage and at this scale.
- I removed weight decay and all of the complexity around what parameters are
and are not weight decayed. I don't believe this should make a massive
difference at the scale that we operate on here.
"""
import os
@ -20,7 +28,7 @@ from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
# -----------------------------------------------------------------------------
# GPT (PyTorch) model definition
# GPT model definition
@dataclass
class GPTConfig:
@ -30,17 +38,21 @@ class GPTConfig:
n_embd: int = 64
vocab_size: int = None
block_size: int = None
# regularization
embd_pdrop: float = 0.1
resid_pdrop:float = 0.1
attn_pdrop:float = 0.1
@dataclass
class TrainConfig:
# optimization parameters
learning_rate: float = 5e-4
weight_decay: float = 0.1 # only applied on matmul weights
betas: List[float] = (0.9, 0.99)
weight_decay: float = 0.01
class NewGELU(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, x):
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class CausalSelfAttention(nn.Module):
"""
@ -52,38 +64,34 @@ class CausalSelfAttention(nn.Module):
def __init__(self, config):
super().__init__()
assert config.n_embd % config.n_head == 0
# key, query, value projections for all heads
self.key = nn.Linear(config.n_embd, config.n_embd)
self.query = nn.Linear(config.n_embd, config.n_embd)
self.value = nn.Linear(config.n_embd, config.n_embd)
# regularization
self.attn_drop = nn.Dropout(config.attn_pdrop)
self.resid_drop = nn.Dropout(config.resid_pdrop)
# key, query, value projections for all heads, but in a batch
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
# output projection
self.proj = nn.Linear(config.n_embd, config.n_embd)
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
# causal mask to ensure that attention is only applied to the left in the input sequence
self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size))
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, 1, config.block_size, config.block_size))
self.n_head = config.n_head
self.n_embd = config.n_embd
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q, k ,v = self.c_attn(x).split(self.n_embd, dim=2)
k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
att = self.attn_drop(att)
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
# output projection
y = self.resid_drop(self.proj(y))
y = self.c_proj(y)
return y
class Block(nn.Module):
@ -91,113 +99,78 @@ class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.n_embd)
self.ln2 = nn.LayerNorm(config.n_embd)
self.ln_1 = nn.LayerNorm(config.n_embd)
self.attn = CausalSelfAttention(config)
self.mlp = nn.Sequential(
nn.Linear(config.n_embd, 4 * config.n_embd),
nn.GELU(),
nn.Linear(4 * config.n_embd, config.n_embd),
nn.Dropout(config.resid_pdrop),
)
self.ln_2 = nn.LayerNorm(config.n_embd)
self.mlp = nn.ModuleDict(dict(
c_fc = nn.Linear(config.n_embd, 4 * config.n_embd),
c_proj = nn.Linear(4 * config.n_embd, config.n_embd),
act = NewGELU(),
))
m = self.mlp
self.mlpf = lambda x: m.c_proj(m.act(m.c_fc(x))) # MLP forward
def forward(self, x):
x = x + self.attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
x = x + self.attn(self.ln_1(x))
x = x + self.mlpf(self.ln_2(x))
return x
class GPT(nn.Module):
""" the full GPT language model, with a context size of block_size """
""" GPT Language Model """
def __init__(self, config):
super().__init__()
# input embedding stem
self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
self.drop = nn.Dropout(config.embd_pdrop)
# transformer
self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
# decoder head
self.ln_f = nn.LayerNorm(config.n_embd)
self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
assert config.vocab_size is not None
assert config.block_size is not None
self.block_size = config.block_size
self.transformer = nn.ModuleDict(dict(
wte = nn.Embedding(config.vocab_size, config.n_embd),
wpe = nn.Embedding(config.block_size, config.n_embd),
h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
ln_f = nn.LayerNorm(config.n_embd),
))
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# init all weights
self.apply(self._init_weights)
print("number of parameters: %d" % sum(p.numel() for p in self.parameters()))
def get_block_size(self):
return self.block_size
# report number of parameters (note we don't count the decoder parameters in lm_head)
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# TODO is this function needed?
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if isinstance(module, nn.Linear) and module.bias is not None:
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
elif isinstance(module, nn.LayerNorm):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
elif isinstance(module, GPT):
torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)
def configure_optimizers(self, train_config):
"""
This long function is unfortunately doing something very simple and is being very defensive:
We are separating out all parameters of the model into two buckets: those that will experience
weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
We are then returning the PyTorch optimizer object.
"""
# separate out all parameters to those that will and won't experience regularizing weight decay
decay = set()
no_decay = set()
whitelist_weight_modules = (torch.nn.Linear, )
blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
for mn, m in self.named_modules():
for pn, p in m.named_parameters():
fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
if pn.endswith('bias'):
# all biases will not be decayed
no_decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
# weights of whitelist modules will be weight decayed
decay.add(fpn)
elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
# weights of blacklist modules will NOT be weight decayed
no_decay.add(fpn)
# special case the position embedding parameter in the root GPT module as not decayed
no_decay.add('pos_emb')
# validate that we considered every parameter
param_dict = {pn: p for pn, p in self.named_parameters()}
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
% (str(param_dict.keys() - union_params), )
# create the pytorch optimizer object
optim_groups = [
{"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
{"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
]
optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
optimizer = torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate,
betas=train_config.betas, weight_decay=train_config.weight_decay,
eps=1e-8)
return optimizer
def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
assert t <= self.block_size, "Cannot forward, input tensor sequence is longer than model block_size."
assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the GPT model
token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
x = self.drop(token_embeddings + position_embeddings)
x = self.blocks(x)
x = self.ln_f(x)
logits = self.head(x)
# forward the GPT model itself
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
x = tok_emb + pos_emb
for block in self.transformer.h:
x = block(x)
x = self.transformer.ln_f(x)
logits = self.lm_head(x)
# if we are given some desired targets also calculate the loss
loss = None
@ -206,52 +179,45 @@ class GPT(nn.Module):
return logits, loss
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = self(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)
return idx
# -----------------------------------------------------------------------------
# helper functions for evaluating and sampling from the model
def top_k_logits(logits, k):
"""
takes logits (N,D) and returns logits (N,D), but in each row only
the top-k are kept and the rest of the entries are set to -inf.
"""
v, ix = torch.topk(logits, k, dim=-1)
out = logits.clone()
out[out < v[:, [-1]]] = -float('Inf')
return out
@torch.inference_mode()
def sample(model, x, steps, temperature=1.0, top_k=None):
"""
take a conditioning sequence of indices in x (b,t) and predict next tokens
in the sequence, feeding the predictions back into the model each step.
"""
block_size = model.get_block_size()
model.eval()
for k in range(steps):
# crop the context, if necessary
x_cond = x if x.size(1) <= block_size else x[:, -block_size:]
# feed the context into the model to get logits at each step
logits, _ = model(x_cond)
# pluck the logits at the final step and scale by temperature
logits = logits[:, -1, :] / temperature
# optionally crop probabilities to only the top k options
if top_k is not None:
logits = top_k_logits(logits, top_k)
# apply softmax to convert to probabilities
probs = F.softmax(logits, dim=-1)
# sample from the distribution or take the most likely
ix = torch.multinomial(probs, num_samples=1)
# append to the sequence and continue
x = torch.cat((x, ix), dim=1)
model.train()
return x
def print_samples(num=10):
""" samples from the model and pretty prints the decoded samples """
X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
top_k = args.top_k if args.top_k != -1 else None
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
X_samp = model.generate(X_init, steps, top_k=top_k, do_sample=True).to('cpu')
train_samples, test_samples, new_samples = [], [], []
for i in range(X_samp.size(0)):
# get the i'th row of sampled integers, as python list
@ -367,14 +333,14 @@ if __name__ == '__main__':
# parse command line args
parser = argparse.ArgumentParser(description="Make More")
# system/input/output
parser.add_argument('--input-file', '-i', type=str, default='input.txt', help="input file with things one per line")
parser.add_argument('--input-file', '-i', type=str, default='names.txt', help="input file with things one per line")
parser.add_argument('--work-dir', '-o', type=str, default='out', help="output working directory")
parser.add_argument('--resume', action='store_true', help="when this flag is used, we will resume optimization from existing model in the workdir")
parser.add_argument('--num-workers', '-n', type=int, default=1, help="number of data workers for both train/test")
parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, e.g. cpu|cuda|mps")
parser.add_argument('--seed', type=int, default=1337, help="seed")
# sampling
parser.add_argument('--sample-only', action='store_true', help="just sample from the model and quit, don't train")
parser.add_argument('--num-workers', '-n', type=int, default=0, help="number of data workers for both train/test")
parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, examples: cpu|cuda|cuda:2|mps")
parser.add_argument('--seed', type=int, default=3407, help="seed")
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
@ -383,25 +349,25 @@ if __name__ == '__main__':
# optimization
parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
parser.add_argument('--dropout', '-d', type=float, default=0.1, help="dropout rate")
parser.add_argument('--weight-decay', '-w', type=float, default=0.1, help="weight decay")
parser.add_argument('--weight-decay', '-w', type=float, default=0.01, help="weight decay")
args = parser.parse_args()
print(vars(args))
# system inits
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
torch.use_deterministic_algorithms(True)
os.makedirs(args.work_dir, exist_ok=True)
writer = SummaryWriter(log_dir=args.work_dir)
# init datasets
train_dataset, test_dataset = create_datasets(args.input_file)
vocab_size = train_dataset.get_vocab_size()
block_size = train_dataset.get_output_length()
print(f"dataset determined that: {vocab_size=}, {block_size=}")
# init model
config = GPTConfig(vocab_size=train_dataset.get_vocab_size(), block_size=train_dataset.get_output_length(),
n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
embd_pdrop=args.dropout, attn_pdrop=args.dropout, resid_pdrop=args.dropout)
config = GPTConfig(vocab_size=vocab_size, block_size=block_size,
n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd)
model = GPT(config)
model.to(args.device)
print(f"model #params: {sum(p.numel() for p in model.parameters())}")