Using transforms' new API.

pull/88/head
kajyuuen 2021-02-01 12:26:52 +09:00
parent c451cdd632
commit 8d0e12e023
2 changed files with 2 additions and 2 deletions

View File

@ -80,9 +80,9 @@ class BertSentenceVectorizer(BaseFeaturizer):
self.model.eval()
with torch.no_grad():
all_encoder_layers, _ = self.model(tokens_tensor)
outputs = self.model(tokens_tensor)
embedding = all_encoder_layers.cpu().numpy()[0]
embedding = outputs.last_hidden_state.cpu().numpy()[0]
if self.pooling_strategy == 'reduce_mean':
return np.mean(embedding, axis=0)
elif self.pooling_strategy == 'reduce_max':

View File