fix comment

morelm
Andrej Karpathy 2022-08-20 01:29:27 +00:00
parent d4ede45208
commit 50617fa75d
1 changed files with 3 additions and 1 deletions

View File

@ -389,10 +389,12 @@ if __name__ == '__main__':
# feed into the model
logits, loss = model(X, Y)
# calculate the gradient, clip it, update the weights
# calculate the gradient, update the weights
model.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
# wait for all CUDA work on the GPU to finish then calculate iteration time taken
if args.device.startswith('cuda'):
torch.cuda.synchronize()
t1 = time.time()