split out train,test,new separately when reporting on sampling word identity

morelm
Andrej 2022-06-09 20:55:27 +00:00 committed by GitHub
parent e0a08f234c
commit c3aaadcb16
1 changed files with 11 additions and 12 deletions

View File

@ -252,8 +252,7 @@ def print_samples(num=10):
top_k = args.top_k if args.top_k != -1 else None
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
X_samp = sample(model, X_init, steps, top_k=top_k).to('cpu')
unique_samples = []
had_samples = []
train_samples, test_samples, new_samples = [], [], []
for i in range(X_samp.size(0)):
# get the i'th row of sampled integers, as python list
row = X_samp[i, 1:].tolist() # note: we need to crop out the first <START> token
@ -262,17 +261,17 @@ def print_samples(num=10):
row = row[:crop_index]
word_samp = train_dataset.decode(row)
# separately track samples that we have and have not seen before
word_have = train_dataset.contains(word_samp) or test_dataset.contains(word_samp)
sample_list = had_samples if word_have else unique_samples
sample_list.append(word_samp)
if train_dataset.contains(word_samp):
train_samples.append(word_samp)
elif test_dataset.contains(word_samp):
test_samples.append(word_samp)
else:
new_samples.append(word_samp)
print('-'*80)
print(f'{len(had_samples)} Samples that were found in input dataset:')
for word in had_samples:
print(word)
print(f'{len(unique_samples)} Samples that were NOT found in input dataset:')
for word in unique_samples:
print(word)
for lst, desc in [(train_samples, 'in train'), (test_samples, 'in test'), (new_samples, 'new')]:
print(f"{len(lst)} samples that are {desc}:")
for word in lst:
print(word)
print('-'*80)
@torch.inference_mode()