remove weight init, not needed at this scale

morelm
Andrej Karpathy 2022-08-20 00:30:09 +00:00
parent 0a19a59564
commit d26d9750ee
1 changed file with 0 additions and 15 deletions


@@ -132,25 +132,10 @@ class GPT(nn.Module):
         ))
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        # init all weights
-        self.apply(self._init_weights)
         # report number of parameters (note we don't count the decoder parameters in lm_head)
         n_params = sum(p.numel() for p in self.transformer.parameters())
         print("number of parameters: %.2fM" % (n_params/1e6,))
-    def _init_weights(self, module):
-        # TODO is this function needed?
-        if isinstance(module, nn.Linear):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-            if module.bias is not None:
-                torch.nn.init.zeros_(module.bias)
-        elif isinstance(module, nn.Embedding):
-            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
-        elif isinstance(module, nn.LayerNorm):
-            torch.nn.init.zeros_(module.bias)
-            torch.nn.init.ones_(module.weight)
     def configure_optimizers(self, train_config):
         optimizer = torch.optim.AdamW(self.parameters(), lr=train_config.learning_rate,
                                       betas=train_config.betas, weight_decay=train_config.weight_decay,
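
Note (not part of the commit): with _init_weights removed, the model falls back on PyTorch's built-in per-module defaults, which is presumably what "not needed at this scale" refers to. The sketch below only inspects freshly constructed nn modules to show those defaults; the dimensions (768, vocab 50257) are illustrative GPT-style values and are not taken from this repo.

# Illustration only: default initialization that applies once the custom
# _init_weights hook is gone.
#   nn.Linear    -> kaiming_uniform_ weights, small uniform bias
#   nn.Embedding -> N(0, 1) weights
#   nn.LayerNorm -> weight = 1, bias = 0
import torch
import torch.nn as nn

torch.manual_seed(0)

lin = nn.Linear(768, 768)        # hypothetical n_embd
emb = nn.Embedding(50257, 768)   # hypothetical vocab_size, n_embd
ln  = nn.LayerNorm(768)

print("linear weight std:    %.4f" % lin.weight.std().item())
print("embedding weight std: %.4f" % emb.weight.std().item())
print("layernorm weight mean: %.1f, bias mean: %.1f" % (ln.weight.mean().item(), ln.bias.mean().item()))

The embedding and linear defaults differ from the removed N(0, 0.02) scheme, so this is a deliberate trade of the GPT-2-style init for PyTorch defaults at small model sizes.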