fix comment
parent
d4ede45208
commit
50617fa75d
|
@ -389,10 +389,12 @@ if __name__ == '__main__':
|
|||
# feed into the model
|
||||
logits, loss = model(X, Y)
|
||||
|
||||
# calculate the gradient, clip it, update the weights
|
||||
# calculate the gradient, update the weights
|
||||
model.zero_grad(set_to_none=True)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
# wait for all CUDA work on the GPU to finish then calculate iteration time taken
|
||||
if args.device.startswith('cuda'):
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
|
Loading…
Reference in New Issue