split out train,test,new separately when reporting on sampling word identity
parent
e0a08f234c
commit
c3aaadcb16
23
makemore.py
23
makemore.py
|
@ -252,8 +252,7 @@ def print_samples(num=10):
|
|||
top_k = args.top_k if args.top_k != -1 else None
|
||||
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
|
||||
X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
|
||||
unique_samples = []
|
||||
had_samples = []
|
||||
train_samples, test_samples, new_samples = [], [], []
|
||||
for i in range(X_samp.size(0)):
|
||||
# get the i'th row of sampled integers, as python list
|
||||
row = X_samp[i, 1:].tolist() # note: we need to crop out the first <START> token
|
||||
|
@ -262,17 +261,17 @@ def print_samples(num=10):
|
|||
row = row[:crop_index]
|
||||
word_samp = train_dataset.decode(row)
|
||||
# separately track samples that we have and have not seen before
|
||||
word_have = train_dataset.contains(word_samp) or test_dataset.contains(word_samp)
|
||||
sample_list = had_samples if word_have else unique_samples
|
||||
sample_list.append(word_samp)
|
||||
|
||||
if train_dataset.contains(word_samp):
|
||||
train_samples.append(word_samp)
|
||||
elif test_dataset.contains(word_samp):
|
||||
test_samples.append(word_samp)
|
||||
else:
|
||||
new_samples.append(word_samp)
|
||||
print('-'*80)
|
||||
print(f'{len(had_samples)} Samples that were found in input dataset:')
|
||||
for word in had_samples:
|
||||
print(word)
|
||||
print(f'{len(unique_samples)} Samples that were NOT found in input dataset:')
|
||||
for word in unique_samples:
|
||||
print(word)
|
||||
for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]:
|
||||
print(f"{len(lst)} samples that are {desc}:")
|
||||
for word in lst:
|
||||
print(word)
|
||||
print('-'*80)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
|
Loading…
Reference in New Issue