add a bag of words model that looks suspiciously similar to a transformer ;)

morelm
Andrej Karpathy 2022-08-21 20:18:20 -07:00
parent b697f434bc
commit c079e1ce76
1 changed files with 99 additions and 2 deletions

View File

@ -155,6 +155,101 @@ class Transformer(nn.Module):
return logits, loss
# -----------------------------------------------------------------------------
# Bag of Words (BoW) language model
class CausalBoW(nn.Module):
"""
Causal bag of words. Averages the preceding elements and looks suspiciously like
a CausalAttention module you'd find in a transformer, for no apparent reason at all ;)
"""
def __init__(self, config):
super().__init__()
# used to mask out vectors and preserve autoregressive property
self.block_size = config.block_size
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
.view(1, config.block_size, config.block_size))
def forward(self, x):
B, T, C = x.size() # batch size, sequence length, n_embd
# do the weighted average of all preceeding token features
att = torch.zeros((B, T, T), device=x.device)
att = att.masked_fill(self.bias[:,:T,:T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y = att @ x # (B, T, T) x (B, T, C) -> (B, T, C)
return y
class BoWBlock(nn.Module):
""" collects BoW features and adds an MLP """
def __init__(self, config):
super().__init__()
# Causal BoW module
self.cbow = CausalBoW(config)
# MLP assembler
self.mlp = nn.ModuleDict(dict(
c_fc = nn.Linear(config.n_embd, config.n_embd2),
c_proj = nn.Linear(config.n_embd2, config.n_embd),
))
m = self.mlp
self.mlpf = lambda x: m.c_proj(F.tanh(m.c_fc(x))) # MLP forward
def forward(self, x):
x = x + self.cbow(x)
x = x + self.mlpf(x)
return x
class BoW(nn.Module):
"""
takes the previous block_size tokens, encodes them with a lookup table,
also encodes their positions with lookup table, then averages all of those
embeddings up and uses that to predict the next token.
"""
def __init__(self, config):
super().__init__()
self.block_size = config.block_size
self.vocab_size = config.vocab_size
# token embedding
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
# position embedding
self.wpe = nn.Embedding(config.block_size, config.n_embd)
# context block
self.context_block = BoWBlock(config)
# language model head decoder layer
self.lm_head = nn.Linear(config.n_embd, self.vocab_size)
def get_block_size(self):
return self.block_size
def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
# forward the token and position embedding layers
tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
pos_emb = self.wpe(pos) # position embeddings of shape (1, t, n_embd)
# add and run through the decoder MLP
x = tok_emb + pos_emb
# run the bag of words context module
x = self.context_block(x)
# decode to next token probability
logits = self.lm_head(x)
# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
return logits, loss
# -----------------------------------------------------------------------------
"""
Recurrent Neural Net language model: either a vanilla RNN recurrence or a GRU.
@ -268,7 +363,7 @@ class MLP(nn.Module):
# +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
# before the beginning of the input sequence
self.mlp = nn.Sequential(
nn.Linear(self.block_size * config.n_embd, config.n_embd2), # TODO: option to vary this
nn.Linear(self.block_size * config.n_embd, config.n_embd2),
nn.Tanh(),
nn.Linear(config.n_embd2, self.vocab_size)
)
@ -511,7 +606,7 @@ if __name__ == '__main__':
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|rnn|gru|transformer")
parser.add_argument('--type', type=str, default='transformer', help="model class type to use, bigram|mlp|rnn|gru|bow|transformer")
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
@ -549,6 +644,8 @@ if __name__ == '__main__':
model = RNN(config, cell_type='rnn')
elif args.type == 'gru':
model = RNN(config, cell_type='gru')
elif args.type == 'bow':
model = BoW(config)
else:
raise ValueError(f'model type {args.type} is not recognized')
model.to(args.device)