add a bag of words model that looks suspiciously similar to a transformer ;)
parent b697f434bc
commit c079e1ce76

makemore.py (101 lines changed)
@@ -155,6 +155,101 @@ class Transformer(nn.Module):
 
         return logits, loss
 
+# -----------------------------------------------------------------------------
+# Bag of Words (BoW) language model
+
+class CausalBoW(nn.Module):
+    """
+    Causal bag of words. Averages the preceding elements and looks suspiciously like
+    a CausalAttention module you'd find in a transformer, for no apparent reason at all ;)
+    """
+    def __init__(self, config):
+        super().__init__()
+
+        # used to mask out vectors and preserve autoregressive property
+        self.block_size = config.block_size
+        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
+                                     .view(1, config.block_size, config.block_size))
+
+    def forward(self, x):
+        B, T, C = x.size() # batch size, sequence length, n_embd
+
+        # do the weighted average of all preceding token features
+        att = torch.zeros((B, T, T), device=x.device)
+        att = att.masked_fill(self.bias[:,:T,:T] == 0, float('-inf'))
+        att = F.softmax(att, dim=-1)
+        y = att @ x # (B, T, T) x (B, T, C) -> (B, T, C)
+
+        return y
+
+class BoWBlock(nn.Module):
+    """ collects BoW features and adds an MLP """
+
+    def __init__(self, config):
+        super().__init__()
+
+        # Causal BoW module
+        self.cbow = CausalBoW(config)
+        # MLP assembler
+        self.mlp = nn.ModuleDict(dict(
+            c_fc   = nn.Linear(config.n_embd, config.n_embd2),
+            c_proj = nn.Linear(config.n_embd2, config.n_embd),
+        ))
+        m = self.mlp
+        self.mlpf = lambda x: m.c_proj(F.tanh(m.c_fc(x))) # MLP forward
+
+    def forward(self, x):
+        x = x + self.cbow(x)
+        x = x + self.mlpf(x)
+        return x
+
+class BoW(nn.Module):
+    """
+    takes the previous block_size tokens, encodes them with a lookup table,
+    also encodes their positions with a lookup table, then averages all of those
+    embeddings up and uses that to predict the next token.
+    """
+
+    def __init__(self, config):
+        super().__init__()
+        self.block_size = config.block_size
+        self.vocab_size = config.vocab_size
+        # token embedding
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
+        # position embedding
+        self.wpe = nn.Embedding(config.block_size, config.n_embd)
+        # context block
+        self.context_block = BoWBlock(config)
+        # language model head decoder layer
+        self.lm_head = nn.Linear(config.n_embd, self.vocab_size)
+
+    def get_block_size(self):
+        return self.block_size
+
+    def forward(self, idx, targets=None):
+
+        device = idx.device
+        b, t = idx.size()
+        assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}"
+        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
+
+        # forward the token and position embedding layers
+        tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
+        pos_emb = self.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        # add and run through the decoder MLP
+        x = tok_emb + pos_emb
+        # run the bag of words context module
+        x = self.context_block(x)
+        # decode to next token probability
+        logits = self.lm_head(x)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+
+        return logits, loss
+
 # -----------------------------------------------------------------------------
 """
 Recurrent Neural Net language model: either a vanilla RNN recurrence or a GRU.
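For intuition (not part of the commit): because the attention matrix is all zeros before masking, the softmax assigns equal weight to every visible position, so CausalBoW is just a running mean of the current and preceding token features. A minimal sketch, assuming only torch, that checks this equivalence:

import torch
import torch.nn.functional as F

B, T, C = 2, 5, 3
x = torch.randn(B, T, C)

# CausalBoW-style average: uniform softmax over the causal mask
bias = torch.tril(torch.ones(T, T)).view(1, T, T)
att = torch.zeros((B, T, T))
att = att.masked_fill(bias[:, :T, :T] == 0, float('-inf'))
att = F.softmax(att, dim=-1)
y_bow = att @ x

# explicit running mean over the current and all preceding positions
counts = torch.arange(1, T + 1, dtype=x.dtype).view(1, T, 1)
y_mean = torch.cumsum(x, dim=1) / counts

print(torch.allclose(y_bow, y_mean, atol=1e-6))  # expected: True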
@@ -268,7 +363,7 @@ class MLP(nn.Module):
         # +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
         # before the beginning of the input sequence
         self.mlp = nn.Sequential(
-            nn.Linear(self.block_size * config.n_embd, config.n_embd2), # TODO: option to vary this
+            nn.Linear(self.block_size * config.n_embd, config.n_embd2),
             nn.Tanh(),
             nn.Linear(config.n_embd2, self.vocab_size)
         )
@@ -511,7 +606,7 @@ if __name__ == '__main__':
     # sampling
     parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
     # model
-    parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|rnn|gru|transformer")
+    parser.add_argument('--type', type=str, default='transformer', help="model class type to use, bigram|mlp|rnn|gru|bow|transformer")
     parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
     parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
     parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
@@ -549,6 +644,8 @@ if __name__ == '__main__':
         model = RNN(config, cell_type='rnn')
     elif args.type == 'gru':
         model = RNN(config, cell_type='gru')
+    elif args.type == 'bow':
+        model = BoW(config)
     else:
         raise ValueError(f'model type {args.type} is not recognized')
     model.to(args.device)
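As a quick sanity check of the wiring above (a sketch, not from the commit): BoW only reads vocab_size, block_size, n_embd and n_embd2 from its config, so once the new classes are defined a SimpleNamespace stand-in is enough to run a forward pass; in the repo itself the model is selected with the --type bow flag added above and the config is supplied by the script's own config object.

from types import SimpleNamespace
import torch

# hypothetical minimal config; the repo's config dataclass would normally supply these fields
config = SimpleNamespace(vocab_size=27, block_size=16, n_embd=64, n_embd2=64)
model = BoW(config)

idx = torch.randint(0, config.vocab_size, (4, config.block_size))      # (b, t) token indices
targets = torch.randint(0, config.vocab_size, (4, config.block_size))  # (b, t) next-token targets
logits, loss = model(idx, targets)
print(logits.shape, loss.item())  # expected: torch.Size([4, 16, 27]) and a scalar loss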