respect multigpu envs, e.g. cuda:2 designation should work

morelm
Andrej Karpathy 2022-08-19 22:31:36 +00:00
parent 013af92770
commit 055e7ee48a
1 changed files with 1 additions and 1 deletions

View File

@ -411,7 +411,7 @@ if __name__ == '__main__':
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
if args.device == 'cuda':
if args.device.startswith('cuda'):
torch.cuda.synchronize()
t1 = time.time()