Merge pull request #274 from apivovarov/gelu

Use nn.GELU - 1.27x faster training
pull/301/head
Andrej 2023-06-14 16:25:15 -07:00 committed by GitHub
commit f08abb45bd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 9 deletions

View File

@ -15,14 +15,6 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
# @torch.jit.script # good to enable when not using torch.compile, disable when using (our default)
def new_gelu(x):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
"""
return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@ -88,12 +80,13 @@ class MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
self.gelu = nn.GELU()
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
self.dropout = nn.Dropout(config.dropout)
def forward(self, x):
x = self.c_fc(x)
x = new_gelu(x)
x = self.gelu(x)
x = self.c_proj(x)
x = self.dropout(x)
return x