add an RNN and a GRU language model
parent 6694b67d37
commit b697f434bc

makemore.py
@@ -111,8 +111,8 @@ class Block(nn.Module):
         x = x + self.mlpf(self.ln_2(x))
         return x

-class GPT(nn.Module):
-    """ GPT Language Model """
+class Transformer(nn.Module):
+    """ Transformer Language Model, exactly as seen in GPT-2 """

     def __init__(self, config):
         super().__init__()

@@ -155,6 +155,99 @@ class GPT(nn.Module):

         return logits, loss

+# -----------------------------------------------------------------------------
+"""
+Recurrent Neural Net language model: either a vanilla RNN recurrence or a GRU.
+Did not implement an LSTM because its API is a bit more annoying as it has
+both a hidden state and a cell state, but it's very similar to GRU and in
+practice works just as well.
+"""
+
+class RNNCell(nn.Module):
+    """
+    the job of a 'Cell' is to:
+    take input at current time step x_{t} and the hidden state at the
+    previous time step h_{t-1} and return the resulting hidden state
+    h_{t} at the current timestep
+    """
+    def __init__(self, config):
+        super().__init__()
+        self.xh_to_h = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
+
+    def forward(self, xt, hprev):
+        xh = torch.cat([xt, hprev], dim=1)
+        ht = F.tanh(self.xh_to_h(xh))
+        return ht
+
+class GRUCell(nn.Module):
+    """
+    same job as RNN cell, but a bit more complicated recurrence formula
+    that makes the GRU more expressive and easier to optimize.
+    """
+    def __init__(self, config):
+        super().__init__()
+        # update gate (z), reset gate (r), and candidate hidden state (hbar)
+        self.xh_to_z = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
+        self.xh_to_r = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
+        self.xh_to_hbar = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
+
+    def forward(self, xt, hprev):
+        # first use the reset gate to wipe some channels of the hidden state to zero
+        xh = torch.cat([xt, hprev], dim=1)
+        r = F.sigmoid(self.xh_to_r(xh))
+        hprev_reset = r * hprev
+        # calculate the candidate new hidden state hbar
+        xhr = torch.cat([xt, hprev_reset], dim=1)
+        hbar = F.tanh(self.xh_to_hbar(xhr))
+        # calculate the switch gate that determines if each channel should be updated at all
+        z = F.sigmoid(self.xh_to_z(xh))
+        # blend the previous hidden state and the new candidate hidden state
+        ht = (1 - z) * hprev + z * hbar
+        return ht
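In equation form, with [a, b] denoting concatenation along the feature dimension, \odot the elementwise product, and the nn.Linear biases left implicit, the two cells above compute:

    RNNCell:  h_t = \tanh(W_{xh} [x_t, h_{t-1}])
    GRUCell:  r_t = \sigma(W_r [x_t, h_{t-1}])
              \bar{h}_t = \tanh(W_{\bar{h}} [x_t, r_t \odot h_{t-1}])
              z_t = \sigma(W_z [x_t, h_{t-1}])
              h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \bar{h}_t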
+
+class RNN(nn.Module):
+
+    def __init__(self, config, cell_type):
+        super().__init__()
+        self.block_size = config.block_size
+        self.vocab_size = config.vocab_size
+        self.start = nn.Parameter(torch.zeros(1, config.n_embd2)) # the starting hidden state
+        self.wte = nn.Embedding(config.vocab_size, config.n_embd) # token embeddings table
+        if cell_type == 'rnn':
+            self.cell = RNNCell(config)
+        elif cell_type == 'gru':
+            self.cell = GRUCell(config)
+        self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)
+
+    def get_block_size(self):
+        return self.block_size
+
+    def forward(self, idx, targets=None):
+        device = idx.device
+        b, t = idx.size()
+
+        # embed all the integers up front and all at once for efficiency
+        emb = self.wte(idx) # (b, t, n_embd)
+
+        # sequentially iterate over the inputs and update the RNN state each tick
+        hprev = self.start.expand((b, -1)) # expand out the batch dimension
+        hiddens = []
+        for i in range(t):
+            xt = emb[:, i, :] # (b, n_embd)
+            ht = self.cell(xt, hprev) # (b, n_embd2)
+            hprev = ht # carry the new hidden state forward to the next time step
+            hiddens.append(ht)
+
+        # decode the outputs
+        hidden = torch.stack(hiddens, 1) # (b, t, n_embd2)
+        logits = self.lm_head(hidden)
+
+        # if we are given some desired targets also calculate the loss
+        loss = None
+        if targets is not None:
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
+
+        return logits, loss
+
 # -----------------------------------------------------------------------------
 # MLP language model

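To make the new model's tensor shapes concrete, a minimal smoke test (illustrative only; it assumes makemore.py is importable as a module and uses made-up sizes):

    import torch
    from makemore import ModelConfig, RNN

    config = ModelConfig(vocab_size=27, block_size=16, n_layer=4, n_head=4,
                         n_embd=64, n_embd2=64)
    model = RNN(config, cell_type='gru')    # or cell_type='rnn'
    idx = torch.randint(0, 27, (8, 16))     # batch of 8 sequences, 16 token indices each
    logits, loss = model(idx, targets=idx)  # logits: (8, 16, 27), loss: scalar cross-entropy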
@@ -418,7 +511,7 @@ if __name__ == '__main__':
     # sampling
     parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
     # model
-    parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
+    parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|rnn|gru|transformer")
     parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
     parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
     parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")

@@ -446,11 +539,18 @@ if __name__ == '__main__':
     config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
                          n_layer=args.n_layer, n_head=args.n_head,
                          n_embd=args.n_embd, n_embd2=args.n_embd2)
-    model = {
-        'transformer': GPT,
-        'bigram': Bigram,
-        'mlp': MLP,
-    }[args.type](config)
+    if args.type == 'transformer':
+        model = Transformer(config)
+    elif args.type == 'bigram':
+        model = Bigram(config)
+    elif args.type == 'mlp':
+        model = MLP(config)
+    elif args.type == 'rnn':
+        model = RNN(config, cell_type='rnn')
+    elif args.type == 'gru':
+        model = RNN(config, cell_type='gru')
+    else:
+        raise ValueError(f'model type {args.type} is not recognized')
     model.to(args.device)
     print(f"model #params: {sum(p.numel() for p in model.parameters())}")
     if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
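With the changes above, the new models are selected from the command line via --type, e.g. (assuming the input text file the script reads, names.txt in the usual makemore setup, is available):

    python makemore.py --type rnn
    python makemore.py --type gru

The --type default remains 'bigram', per the argparse change above.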