respect multigpu envs, e.g. cuda:2 designation should work
parent
013af92770
commit
055e7ee48a
|
@ -411,7 +411,7 @@ if __name__ == '__main__':
|
|||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
||||
optimizer.step()
|
||||
if args.device == 'cuda':
|
||||
if args.device.startswith('cuda'):
|
||||
torch.cuda.synchronize()
|
||||
t1 = time.time()
|
||||
|
||||
|
|
Loading…
Reference in New Issue