Merge pull request #88 from kajyuuen/fix-bert-feature

Fix bert features
pull/89/head
nyanp 2021-02-03 07:51:28 +09:00 committed by GitHub
commit 32bffe86bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -80,9 +80,9 @@ class BertSentenceVectorizer(BaseFeaturizer):
self.model.eval()
with torch.no_grad():
all_encoder_layers, _ = self.model(tokens_tensor)
outputs = self.model(tokens_tensor)
embedding = all_encoder_layers.cpu().numpy()[0]
embedding = outputs.last_hidden_state.cpu().numpy()[0]
if self.pooling_strategy == 'reduce_mean':
return np.mean(embedding, axis=0)
elif self.pooling_strategy == 'reduce_max':

View File