implementation of InfiniteDataLoader sad

Andrej Karpathy 2022-08-20 01:24:44 +00:00
parent a7c52cd4d0
commit d4ede45208
1 changed file with 25 additions and 13 deletions

@@ -297,6 +297,25 @@ def create_datasets(input_file):
     return train_dataset, test_dataset
 
+class InfiniteDataLoader:
+    """
+    this is really hacky and I'm not proud of it, but there doesn't seem to be
+    a better way in PyTorch to just create an infinite dataloader?
+    """
+
+    def __init__(self, dataset, **kwargs):
+        train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10))
+        self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs)
+        self.data_iter = iter(self.train_loader)
+
+    def next(self):
+        try:
+            batch = next(self.data_iter)
+        except StopIteration: # this will technically only happen after 1e10 samples... (i.e. basically never)
+            self.data_iter = iter(self.train_loader)
+            batch = next(self.data_iter)
+        return batch
+
 # -----------------------------------------------------------------------------
 if __name__ == '__main__':
@@ -342,7 +361,7 @@ if __name__ == '__main__':
     model = GPT(config)
     model.to(args.device)
     print(f"model #params: {sum(p.numel() for p in model.parameters())}")
-    if args.resume or args.sample_only: # note: if we sample-only then we also assume we're resuming
+    if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
         print("resuming from existing model in the workdir")
         model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt')))
     if args.sample_only:
@@ -350,14 +369,10 @@ if __name__ == '__main__':
         sys.exit()
 
     # init optimizer
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay,
-                                  betas=(0.9, 0.99), eps=1e-8)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(0.9, 0.99), eps=1e-8)
 
     # init dataloader
-    train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10))
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
-                              num_workers=args.num_workers, sampler=train_sampler)
-    data_iter = iter(train_loader)
+    batch_loader = InfiniteDataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers)
 
     # training loop
     best_loss = None
@@ -365,12 +380,9 @@ if __name__ == '__main__':
     while True:
 
         t0 = time.time()
 
-        # fetch the next batch and reset dataloader every epoch as necessary
-        try:
-            batch = next(data_iter)
-        except StopIteration:
-            data_iter = iter(train_loader)
-            batch = next(data_iter)
+        # get the next batch, ship to device, and unpack it to input and target
+        batch = batch_loader.next()
         batch = [t.to(args.device) for t in batch]
         X, Y = batch
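
For anyone who wants to try the new class outside of the full training script, below is a minimal, self-contained sketch of the same pattern on a toy TensorDataset. The dataset contents, batch_size, and num_workers values are illustrative placeholders, not values taken from this commit.

import torch
from torch.utils.data import DataLoader, TensorDataset

class InfiniteDataLoader:
    """
    Same idea as the class added in this commit: a with-replacement RandomSampler
    asked for ~1e10 samples makes the wrapped DataLoader effectively never-ending.
    """
    def __init__(self, dataset, **kwargs):
        train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10))
        self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs)
        self.data_iter = iter(self.train_loader)

    def next(self):
        try:
            batch = next(self.data_iter)
        except StopIteration: # only reachable after ~1e10 samples, i.e. basically never
            self.data_iter = iter(self.train_loader)
            batch = next(self.data_iter)
        return batch

if __name__ == '__main__':
    # toy dataset: 100 random (x, y) pairs (placeholder data, for illustration only)
    dataset = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
    batch_loader = InfiniteDataLoader(dataset, batch_size=16, num_workers=0)
    for step in range(10): # keeps yielding batches well past a single pass over the data
        X, Y = batch_loader.next()
        print(step, X.shape, Y.shape)

Because the sampler draws with replacement, there are no epoch boundaries to reset over, which is why the StopIteration branch in the diff exists only as a safety net.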