Fix bug in RNN where hprev always referred to start. Change so that hprev refers to output of previous cell
parent
2f5e8d746e
commit
bf38625014
|
@ -330,6 +330,7 @@ class RNN(nn.Module):
|
|||
for i in range(t):
|
||||
xt = emb[:, i, :] # (b, n_embd)
|
||||
ht = self.cell(xt, hprev) # (b, n_embd2)
|
||||
hprev = ht
|
||||
hiddens.append(ht)
|
||||
|
||||
# decode the outputs
|
||||
|
|
Loading…
Reference in New Issue