Merge pull request #3 from normanyu/fix-rnn-prev-state

Fix bug in RNN where hprev always referred to start.
pull/5/head
Andrej 2022-09-15 08:26:41 -07:00 committed by GitHub
commit f61811b994
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 0 deletions

View File

@ -330,6 +330,7 @@ class RNN(nn.Module):
for i in range(t):
xt = emb[:, i, :] # (b, n_embd)
ht = self.cell(xt, hprev) # (b, n_embd2)
hprev = ht
hiddens.append(ht)
# decode the outputs