implementation of InfiniteDataLoader sad
parent a7c52cd4d0
commit d4ede45208
makemore.py (38 changed lines)
@@ -297,6 +297,25 @@ def create_datasets(input_file):
 
     return train_dataset, test_dataset
 
+class InfiniteDataLoader:
+    """
+    this is really hacky and I'm not proud of it, but there doesn't seem to be
+    a better way in PyTorch to just create an infinite dataloader?
+    """
+
+    def __init__(self, dataset, **kwargs):
+        train_sampler = torch.utils.data.RandomSampler(dataset, replacement=True, num_samples=int(1e10))
+        self.train_loader = DataLoader(dataset, sampler=train_sampler, **kwargs)
+        self.data_iter = iter(self.train_loader)
+
+    def next(self):
+        try:
+            batch = next(self.data_iter)
+        except StopIteration: # this will technically only happen after 1e10 samples... (i.e. basically never)
+            self.data_iter = iter(self.train_loader)
+            batch = next(self.data_iter)
+        return batch
+
 # -----------------------------------------------------------------------------
 if __name__ == '__main__':
 
@@ -342,7 +361,7 @@ if __name__ == '__main__':
     model = GPT(config)
     model.to(args.device)
     print(f"model #params: {sum(p.numel() for p in model.parameters())}")
-    if args.resume or args.sample_only: # note: if we sample-only then we also assume we're resuming
+    if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
         print("resuming from existing model in the workdir")
         model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt')))
     if args.sample_only:
@@ -350,14 +369,10 @@ if __name__ == '__main__':
         sys.exit()
 
     # init optimizer
-    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay,
-                                  betas=(0.9, 0.99), eps=1e-8)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay, betas=(0.9, 0.99), eps=1e-8)
 
     # init dataloader
-    train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10))
-    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
-                              num_workers=args.num_workers, sampler=train_sampler)
-    data_iter = iter(train_loader)
+    batch_loader = InfiniteDataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True, num_workers=args.num_workers)
 
     # training loop
     best_loss = None
@@ -365,12 +380,9 @@ if __name__ == '__main__':
     while True:
 
         t0 = time.time()
 
-        # fetch the next batch and reset dataloader every epoch as necessary
-        try:
-            batch = next(data_iter)
-        except StopIteration:
-            data_iter = iter(train_loader)
-            batch = next(data_iter)
-
+        # get the next batch, ship to device, and unpack it to input and target
+        batch = batch_loader.next()
         batch = [t.to(args.device) for t in batch]
         X, Y = batch
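
A rough usage sketch of the new class, for context. It assumes makemore.py with this commit applied is importable from the working directory and that 'names.txt' is the input file; both of those details are assumptions here, not part of the diff:

# sketch only: drive InfiniteDataLoader outside the full training script
from makemore import create_datasets, InfiniteDataLoader  # assumes makemore.py is on the path

train_dataset, test_dataset = create_datasets('names.txt')  # input file name assumed for this sketch
batch_loader = InfiniteDataLoader(train_dataset, batch_size=32, num_workers=0)

for step in range(10):            # the real script loops forever with `while True`
    X, Y = batch_loader.next()    # the sampler draws with replacement, so this never runs out in practice
    print(step, X.shape, Y.shape)

The design choice is simply to make RandomSampler draw with replacement for num_samples=int(1e10), so the wrapped DataLoader behaves as if it never ends; the try/except on StopIteration is only a safety net for the practically unreachable case where those samples are exhausted.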