simplify optimizer init and delete code

morelm
Andrej Karpathy 2022-08-20 00:32:40 +00:00
parent d26d9750ee
commit 35435ec087
1 changed file with 2 additions and 15 deletions

@@ -39,13 +39,6 @@ class GPTConfig:
     vocab_size: int = None
     block_size: int = None
-@dataclass
-class TrainConfig:
-    # optimization parameters
-    learning_rate: float = 5e-4
-    betas: List[float] = (0.9, 0.99)
-    weight_decay: float = 0.01
 class NewGELU(nn.Module):
     """
     Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
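With the TrainConfig dataclass deleted, its three optimization fields (learning_rate, betas, weight_decay) no longer live in the model file; the hunk at line 365 below reads the learning rate and weight decay from the argparse args object and hard-codes the betas. A minimal sketch of how those flags might be declared, assuming option names that produce the args.learning_rate and args.weight_decay attributes used below (the actual argparse setup is outside this diff):

    import argparse

    parser = argparse.ArgumentParser()
    # defaults mirror the deleted TrainConfig fields
    parser.add_argument('--learning-rate', type=float, default=5e-4, help="AdamW learning rate")
    parser.add_argument('--weight-decay', type=float, default=0.01, help="AdamW weight decay")
    args = parser.parse_args()
    # betas are no longer configurable; the new call site fixes them at (0.9, 0.99)
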
@@ -136,12 +129,6 @@ class GPT(nn.Module):
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
-    def configure_optimizers(self, train_config):
-        optimizer = torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate,
-                                      betas=train_config.betas, weight_decay=train_config.weight_decay,
-                                      eps=1e-8)
-        return optimizer
     def forward(self, idx, targets=None):
         device = idx.device
         b, t = idx.size()
@@ -365,8 +352,8 @@ if __name__ == '__main__':
         sys.exit()
     # init optimizer
-    train_config = TrainConfig(learning_rate=args.learning_rate, weight_decay=args.weight_decay)
-    optimizer = model.configure_optimizers(train_config)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay,
+                                  betas=(0.9, 0.99), eps=1e-8)
     # init dataloader
     train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10))
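
Net effect: the optimizer is now constructed directly at the call site instead of going through TrainConfig and GPT.configure_optimizers. Since the deleted method applied the same learning rate, betas, weight decay, and eps uniformly to self.parameters(), the inlined call is behaviorally identical. A self-contained sketch of the resulting pattern, with a stand-in nn.Linear module in place of the GPT model and the hyperparameter defaults written out as literals:

    import torch
    import torch.nn as nn

    # stand-in for the GPT model built earlier in the training script
    model = nn.Linear(16, 16)

    # direct AdamW construction replaces TrainConfig + model.configure_optimizers()
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.01,
                                  betas=(0.9, 0.99), eps=1e-8)

    # one optimization step to show the optimizer is usable as-is
    loss = model(torch.randn(4, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()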