first commit
parent
180c4f7260
commit
8f79bd0126
2
LICENSE
2
LICENSE
|
@ -1,6 +1,6 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2022 Andrej
|
||||
Copyright (c) 2022 Andrej Karpathy
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
|
|
|
@ -0,0 +1,42 @@
|
|||
|
||||
# makemore
|
||||
|
||||
makemore is the most accessible way of tinkering with a GPT.
|
||||
|
||||
The one-file script `makemore.py` takes one text file as input, where each line is assumed to be one training thing, and generates more things like it. For example, we can feed it a database of names, and then use it to generate new cool baby name ideas that sound name-like, but are not already existing names. Or if we feed it a database of company names then we can generate new ideas for a name of a company. Or we can just feed it valid scrabble words and generate english-like babble.
|
||||
|
||||
Under the hood, the script trains a (character-level) Transformer, identical to the one that powers GPT and friends.
|
||||
|
||||
This is not meant to be a heavyweight library with switches and knobs. It's one hackable file of ~500 lines of code. [PyTorch](https://pytorch.org) is the only requirement. Go nuts.
|
||||
|
||||
### Usage
|
||||
|
||||
The included `names.txt` dataset, as an example, has the most common 32K names taken from [ssa.gov](https://www.ssa.gov/oact/babynames/) for the year 2018. It looks like:
|
||||
|
||||
```
|
||||
emma
|
||||
olivia
|
||||
ava
|
||||
isabella
|
||||
sophia
|
||||
charlotte
|
||||
...
|
||||
```
|
||||
|
||||
Let's point the script at it:
|
||||
|
||||
```bash
|
||||
$ python makemore.py -i names.txt -o names
|
||||
```
|
||||
|
||||
Training progress and logs and model will all be saved to the working directory `names`. The default model is a super tiny 200K param transformer; many more training configurations are available - see the argparse and read the code. Training does not require any special hardware, it runs on my Macbook Air and will run on anything else, but if you have a GPU then training will fly. As training progresses the script will print some samples throughout. However, if you'd like to sample manually, you can use the `--sample-only` flag, e.g. in a separate terminal do:
|
||||
|
||||
```bash
|
||||
$ python makemore.py -i names.txt -o names --sample-only
|
||||
```
|
||||
|
||||
This will load the best model so far and print more samples on demand. Have fun.
|
||||
|
||||
### License
|
||||
|
||||
MIT
|
|
@ -0,0 +1,476 @@
|
|||
"""
|
||||
you give this script some words (one per line) and it will generate more things like it.
|
||||
uses super state of the art Transformer AI tech
|
||||
this code is intended to be super hackable. tune it to your needs.
|
||||
"""
|
||||
|
||||
import argparse
import math
import os
import sys
import time
from dataclasses import dataclass
from typing import List, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# GPT (PyTorch) model definition
|
||||
|
||||
@dataclass
class GPTConfig:
    """ hyperparameters defining the GPT model architecture """
    # size of the model
    n_layer: int = 4   # number of transformer blocks
    n_head: int = 4    # number of attention heads per block
    n_embd: int = 64   # embedding / feature channel dimension
    vocab_size: int = None   # filled in later from the dataset's vocabulary size
    block_size: int = None   # filled in later from the dataset's output length
    # regularization
    embd_pdrop: float = 0.1   # dropout on the summed token + position embeddings
    resid_pdrop: float = 0.1  # dropout on the residual-branch outputs
    attn_pdrop: float = 0.1   # dropout on the attention weights
|
||||
|
||||
@dataclass
class TrainConfig:
    """ hyperparameters for the AdamW optimizer (consumed by GPT.configure_optimizers) """
    # optimization parameters
    learning_rate: float = 5e-4
    weight_decay: float = 0.1 # only applied on matmul weights
    # fix: annotated as a 2-tuple rather than List[float] — the default is a
    # tuple (a list default would be rejected by @dataclass as mutable), and
    # torch.optim.AdamW expects exactly two betas
    betas: Tuple[float, float] = (0.9, 0.99)
|
||||
|
||||
class CausalSelfAttention(nn.Module):
    """
    Multi-head masked self-attention followed by an output projection.
    Equivalent to torch.nn.MultiheadAttention, but spelled out explicitly so
    that every step of the computation is visible and hackable.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # linear maps producing keys, queries and values for all heads at once
        self.key = nn.Linear(config.n_embd, config.n_embd)
        self.query = nn.Linear(config.n_embd, config.n_embd)
        self.value = nn.Linear(config.n_embd, config.n_embd)
        # dropout on the attention weights and on the projected output
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # final projection back to the embedding dimension
        self.proj = nn.Linear(config.n_embd, config.n_embd)
        # lower-triangular mask so position t can only attend to positions <= t
        causal = torch.tril(torch.ones(config.block_size, config.block_size))
        self.register_buffer("mask", causal.view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head

    def forward(self, x):
        B, T, C = x.size()  # batch size, sequence length, embedding dim
        nh, hs = self.n_head, C // self.n_head

        # project, split the channel dim into (heads, head_size), and move the
        # head dim next to batch: each is (B, nh, T, hs)
        q = self.query(x).view(B, T, nh, hs).transpose(1, 2)
        k = self.key(x).view(B, T, nh, hs).transpose(1, 2)
        v = self.value(x).view(B, T, nh, hs).transpose(1, 2)

        # scaled dot-product attention scores: (B, nh, T, T)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale
        # forbid attention to the future, then normalize to probabilities
        scores = scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        weights = self.attn_drop(F.softmax(scores, dim=-1))

        # weighted sum of values, then merge the heads back: (B, T, C)
        y = (weights @ v).transpose(1, 2).contiguous().view(B, T, C)

        # output projection with residual dropout
        return self.resid_drop(self.proj(y))
|
||||
|
||||
class Block(nn.Module):
    """ an unassuming Transformer block: pre-norm attention + pre-norm MLP """

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.n_embd)
        self.ln2 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        # position-wise feed-forward: expand 4x, GELU, project back, dropout
        hidden = 4 * config.n_embd
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, config.n_embd),
            nn.Dropout(config.resid_pdrop),
        )

    def forward(self, x):
        # residual connection around each sub-layer, LayerNorm applied first
        x = x + self.attn(self.ln1(x))
        return x + self.mlp(self.ln2(x))
|
||||
|
||||
class GPT(nn.Module):
    """ the full GPT language model, with a context size of block_size """

    def __init__(self, config):
        """ build the model from a GPTConfig; prints the total parameter count """
        super().__init__()

        # input embedding stem
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)  # token embeddings
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))  # learned position embeddings
        self.drop = nn.Dropout(config.embd_pdrop)
        # transformer
        self.blocks = nn.Sequential(*[Block(config) for _ in range(config.n_layer)])
        # decoder head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.block_size = config.block_size
        self.apply(self._init_weights)  # recursively initialize all submodules

        print("number of parameters: %d" % sum(p.numel() for p in self.parameters()))

    def get_block_size(self):
        """ maximum sequence length this model can be forwarded with """
        return self.block_size

    def _init_weights(self, module):
        """ GPT-style init: normal(0, 0.02) weights, zero biases, unit LayerNorm gain """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
        elif isinstance(module, GPT):
            # pos_emb is a raw nn.Parameter, not covered by the module cases
            # above, so the root GPT module initializes it here explicitly
            torch.nn.init.normal_(module.pos_emb, mean=0.0, std=0.02)

    def configure_optimizers(self, train_config):
        """
        This long function is unfortunately doing something very simple and is being very defensive:
        We are separating out all parameters of the model into two buckets: those that will experience
        weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
        We are then returning the PyTorch optimizer object.
        """

        # separate out all parameters to those that will and won't experience regularizing weight decay
        decay = set()
        no_decay = set()
        whitelist_weight_modules = (torch.nn.Linear, )
        blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
        for mn, m in self.named_modules():
            for pn, p in m.named_parameters():
                fpn = '%s.%s' % (mn, pn) if mn else pn # full param name

                if pn.endswith('bias'):
                    # all biases will not be decayed
                    no_decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
                    # weights of whitelist modules will be weight decayed
                    decay.add(fpn)
                elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
                    # weights of blacklist modules will NOT be weight decayed
                    no_decay.add(fpn)

        # special case the position embedding parameter in the root GPT module as not decayed
        no_decay.add('pos_emb')

        # validate that we considered every parameter
        param_dict = {pn: p for pn, p in self.named_parameters()}
        inter_params = decay & no_decay
        union_params = decay | no_decay
        assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
        assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
            % (str(param_dict.keys() - union_params), )

        # create the pytorch optimizer object
        optim_groups = [
            {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay},
            {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0},
        ]
        optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas)
        return optimizer

    def forward(self, idx, targets=None):
        """
        idx: (b, t) integer token indices, t <= block_size.
        targets: optional (b, t) indices used to compute cross-entropy loss;
        positions equal to -1 are ignored (used to mask padding).
        Returns (logits, loss); loss is None when targets is not given.
        """
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, input tensor sequence is longer than model block_size."

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.head(x)

        # if we are given some desired targets also calculate the loss
        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

        return logits, loss
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# helper functions for evaluating and sampling from the model
|
||||
|
||||
def top_k_logits(logits, k):
    """
    takes logits (N,D) and returns logits (N,D), but in each row only
    the top-k are kept and the rest of the entries are set to -inf.
    """
    # the cutoff for each row is its k-th largest value, kept as (N, 1)
    kth_vals = torch.topk(logits, k, dim=-1).values[:, [-1]]
    # work on a copy so the caller's tensor is left untouched
    result = logits.clone()
    result[result < kth_vals] = -float('Inf')
    return result
|
||||
|
||||
@torch.inference_mode()
def sample(model, x, steps, temperature=1.0, top_k=None):
    """
    take a conditioning sequence of indices in x (b,t) and predict next tokens
    in the sequence, feeding the predictions back into the model each step.
    """
    block_size = model.get_block_size()
    model.eval()
    for _ in range(steps):
        # the model can only attend over block_size tokens, so keep the tail
        context = x[:, -block_size:] if x.size(1) > block_size else x
        logits, _ = model(context)
        # only the logits at the final time step matter for the next token;
        # temperature flattens (>1) or sharpens (<1) the distribution
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            # restrict sampling to the k most likely tokens
            logits = top_k_logits(logits, top_k)
        # draw one token per sequence from the softmax distribution
        nxt = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
        # grow the context with the sampled token and continue
        x = torch.cat((x, nxt), dim=1)
    model.train()
    return x
|
||||
|
||||
def print_samples(num=10):
    """ samples from the model and pretty prints the decoded samples """
    # every sample is seeded with the <START> token (index 0)
    X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
    top_k = args.top_k if args.top_k != -1 else None
    # -1 because the <START> token already occupies the first position
    steps = train_dataset.get_output_length() - 1
    X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
    new_words = []
    known_words = []
    for row_t in X_samp:
        # drop the leading <START> token, then cut at the first <STOP> (index 0)
        row = row_t[1:].tolist()
        stop = row.index(0) if 0 in row else len(row)
        word_samp = train_dataset.decode(row[:stop])
        # bucket by whether the word already exists in the input data
        seen = train_dataset.contains(word_samp) or test_dataset.contains(word_samp)
        (known_words if seen else new_words).append(word_samp)

    print('-'*80)
    print(f'{len(known_words)} Samples that were found in input dataset:')
    for word in known_words:
        print(word)
    print(f'{len(new_words)} Samples that were NOT found in input dataset:')
    for word in new_words:
        print(word)
    print('-'*80)
|
||||
|
||||
@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None):
    """
    Return the mean loss of the model over (up to max_batches batches of) the
    dataset. The model is put in eval mode for the duration and restored to
    train mode before returning. Uses the global `args.device`.
    """
    model.eval()
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
    losses = []
    for i, batch in enumerate(loader):
        # fix: stop BEFORE processing batch index i == max_batches; previously
        # the check ran after appending, so max_batches + 1 batches were used
        if max_batches is not None and i >= max_batches:
            break
        batch = [t.to(args.device) for t in batch]
        X, Y = batch
        logits, loss = model(X, Y)
        losses.append(loss.item())
    mean_loss = torch.tensor(losses).mean().item()
    model.train() # reset model back to training mode
    return mean_loss
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# helper functions for creating the training and test Datasets that emit words
|
||||
|
||||
class CharDataset(Dataset):
    """
    Dataset of words rendered as fixed-length index tensors. Index 0 is a
    special token serving as <START> on the input side and <STOP> on the
    target side; real characters map to indices 1..len(chars).
    """

    def __init__(self, words, chars, max_word_length):
        self.words = words
        self.chars = chars
        self.max_word_length = max_word_length
        # character -> index, offset by 1 to reserve 0 for the special token
        self.stoi = {ch: i + 1 for i, ch in enumerate(chars)}
        # index -> character, the inverse mapping
        self.itos = {i: s for s, i in self.stoi.items()}

    def __len__(self):
        return len(self.words)

    def contains(self, word):
        return word in self.words

    def get_vocab_size(self):
        # all the possible characters plus the special 0 token
        return len(self.chars) + 1

    def get_output_length(self):
        # <START> token followed by up to max_word_length characters
        return self.max_word_length + 1

    def encode(self, word):
        # word -> 1D LongTensor of character indices
        return torch.tensor([self.stoi[ch] for ch in word], dtype=torch.long)

    def decode(self, ix):
        # sequence of indices -> word
        return ''.join(self.itos[i] for i in ix)

    def __getitem__(self, idx):
        word = self.words[idx]
        ix = self.encode(word)
        x = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        y = torch.zeros(self.max_word_length + 1, dtype=torch.long)
        # input is <START> + word, target is word + <STOP>; positions after
        # <STOP> are set to -1 so the loss is masked there (ignore_index=-1)
        x[1:1 + len(ix)] = ix
        y[:len(ix)] = ix
        y[len(ix) + 1:] = -1
        return x, y
|
||||
|
||||
def create_datasets(input_file):
    """
    Read one word per line from input_file, build the character vocabulary,
    and return (train_dataset, test_dataset) CharDataset objects split
    randomly ~90/10, with the test set capped at 1000 examples.
    """

    # preprocessing of the input text file
    with open(input_file, 'r') as f:
        data = f.read()
    words = data.splitlines()
    words = [w.strip() for w in words] # get rid of any leading or trailing white space
    words = [w for w in words if w] # get rid of any empty strings
    chars = sorted(list(set(''.join(words)))) # all the possible characters
    max_word_length = max(len(w) for w in words)
    print(f"number of examples in the dataset: {len(words)}")
    print(f"max word length: {max_word_length}")
    print(f"number of unique characters in the vocabulary: {len(chars)}")
    print("vocabulary:")
    print(''.join(chars))

    # partition the input data into a training and the test set
    test_set_size = min(1000, int(len(words) * 0.1)) # 10% of the training set, or up to 1000 examples
    rp = torch.randperm(len(words)).tolist()
    # fix: split with an explicit index instead of rp[:-test_set_size] —
    # when test_set_size == 0 (fewer than 10 words), rp[:-0] is the EMPTY
    # list, which silently produced an empty training set
    num_train = len(words) - test_set_size
    train_words = [words[i] for i in rp[:num_train]]
    test_words = [words[i] for i in rp[num_train:]]
    print(f"split up the dataset into {len(train_words)} training examples and {len(test_words)} test examples")

    # wrap in dataset objects
    train_dataset = CharDataset(train_words, chars, max_word_length)
    test_dataset = CharDataset(test_words, chars, max_word_length)

    return train_dataset, test_dataset
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
if __name__ == '__main__':

    # parse command line args
    parser = argparse.ArgumentParser(description="Make More")
    # system/input/output
    parser.add_argument('--input-file', '-i', type=str, default='input.txt', help="input file with things one per line")
    parser.add_argument('--work-dir', '-o', type=str, default='out', help="output working directory")
    parser.add_argument('--resume', action='store_true', help="when this flag is used, we will resume optimization from existing model in the workdir")
    parser.add_argument('--num-workers', '-n', type=int, default=1, help="number of data workers for both train/test")
    parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, e.g. cpu|cuda|m1")
    parser.add_argument('--seed', type=int, default=1337, help="seed")
    # sampling
    parser.add_argument('--sample-only', action='store_true', help="just sample from the model and quit, don't train")
    parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
    # model
    parser.add_argument('--n-layer', type=int, default=4, help="number of layers in the transformer")
    parser.add_argument('--n-head', type=int, default=4, help="number of heads in the transformer")
    parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the transformer")
    # optimization
    parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
    parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
    parser.add_argument('--dropout', '-d', type=float, default=0.1, help="dropout rate")
    parser.add_argument('--weight-decay', '-w', type=float, default=0.1, help="weight decay")
    args = parser.parse_args()
    print(vars(args))

    # system inits: seed everything for reproducibility, set up the output
    # directory and tensorboard logging
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    torch.use_deterministic_algorithms(True)
    os.makedirs(args.work_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=args.work_dir)

    # init datasets
    train_dataset, test_dataset = create_datasets(args.input_file)

    # init model; vocab and block size come from the dataset, the rest from args
    config = GPTConfig(vocab_size=train_dataset.get_vocab_size(), block_size=train_dataset.get_output_length(),
                       n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
                       embd_pdrop=args.dropout, attn_pdrop=args.dropout, resid_pdrop=args.dropout)
    model = GPT(config)
    model.to(args.device)
    print(f"model #params: {sum(p.numel() for p in model.parameters())}")
    if args.resume or args.sample_only: # note: if we sample-only then we also assume we're resuming
        print("resuming from existing model in the workdir")
        model.load_state_dict(torch.load(os.path.join(args.work_dir, 'model.pt')))
    if args.sample_only:
        print_samples(num=50)
        sys.exit()

    # init optimizer
    train_config = TrainConfig(learning_rate=args.learning_rate, weight_decay=args.weight_decay)
    optimizer = model.configure_optimizers(train_config)

    # init dataloader; sampling with replacement makes the loader effectively infinite
    train_sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10))
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                              num_workers=args.num_workers, sampler=train_sampler)
    data_iter = iter(train_loader)

    # training loop; runs until interrupted (the best model so far is saved along the way)
    best_loss = None
    step = 0
    while True:

        t0 = time.time()
        # fetch the next batch and reset dataloader every epoch as necessary
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(train_loader)
            batch = next(data_iter)
        batch = [t.to(args.device) for t in batch]
        X, Y = batch

        # feed into the model
        logits, loss = model(X, Y)

        # calculate the gradient, clip it, update the weights
        model.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if args.device != 'cpu':
            # wait for the device to finish so the step timing below is accurate
            torch.cuda.synchronize()
        t1 = time.time()

        # logging
        if step % 10 == 0:
            print(f"step {step} | loss {loss.item():.4f} | step time {(t1-t0)*1000:.2f}ms")

        # evaluate the model
        if step > 0 and step % 500 == 0:
            train_loss = evaluate(model, train_dataset, batch_size=100, max_batches=10)
            test_loss = evaluate(model, test_dataset, batch_size=100, max_batches=10)
            writer.add_scalar("Loss/train", train_loss, step)
            writer.add_scalar("Loss/test", test_loss, step)
            writer.flush()
            print(f"step {step} train loss: {train_loss} test loss: {test_loss}")
            # save the model to disk if it has improved
            if best_loss is None or test_loss < best_loss:
                out_path = os.path.join(args.work_dir, "model.pt")
                print(f"test loss {test_loss} is the best so far, saving model to {out_path}")
                torch.save(model.state_dict(), out_path)
                best_loss = test_loss

        # sample from the model
        if step > 0 and step % 200 == 0:
            print_samples(num=10)

        step += 1
|
Loading…
Reference in New Issue