Fix bug in RNN where hprev always referred to start. Change so that hprev refers to output of previous cell

pull/3/head
Norman Yu 2022-09-15 18:10:19 +08:00
parent 2f5e8d746e
commit bf38625014
1 changed files with 1 additions and 0 deletions

View File

@ -330,6 +330,7 @@ class RNN(nn.Module):
for i in range(t):
xt = emb[:, i, :] # (b, n_embd)
ht = self.cell(xt, hprev) # (b, n_embd2)
hprev = ht
hiddens.append(ht)
# decode the outputs