Fix bug in RNN where hprev always referred to start. Change so that hprev refers to output of previous cell
parent
2f5e8d746e
commit
bf38625014
|
@ -330,6 +330,7 @@ class RNN(nn.Module):
|
||||||
for i in range(t):
|
for i in range(t):
|
||||||
xt = emb[:, i, :] # (b, n_embd)
|
xt = emb[:, i, :] # (b, n_embd)
|
||||||
ht = self.cell(xt, hprev) # (b, n_embd2)
|
ht = self.cell(xt, hprev) # (b, n_embd2)
|
||||||
|
hprev = ht
|
||||||
hiddens.append(ht)
|
hiddens.append(ht)
|
||||||
|
|
||||||
# decode the outputs
|
# decode the outputs
|
||||||
|
|
Loading…
Reference in New Issue