
Commit d09042e

Fix crash when hotflip gets OOV input (#3277)
* Fix crash when hotflip gets OOV input
* add comment
1 parent 2a95022 commit d09042e

1 file changed: +30 -22 lines

allennlp/interpret/attackers/hotflip.py

@@ -1,6 +1,6 @@
 # pylint: disable=protected-access
 from copy import deepcopy
-from typing import List
+from typing import Dict, List

 import numpy
 import torch
@@ -66,6 +66,7 @@ def __init__(self,
             if not self.vocab._index_to_token[self.namespace][i].isalnum():
                 self.invalid_replacement_indices.append(i)
         self.embedding_matrix: torch.Tensor = None
+        self.embedding_layer: torch.nn.Module = None

     def initialize(self):
         """
@@ -74,7 +75,7 @@ def initialize(self):
         being done when __init__() is called.
         """
         if self.embedding_matrix is None:
-            self.embedding_matrix = self._construct_embedding_matrix()
+            self.embedding_matrix = self._construct_embedding_matrix().cpu()

     def _construct_embedding_matrix(self) -> Embedding:
         """
@@ -87,6 +88,7 @@ def _construct_embedding_matrix(self) -> Embedding:
         matrix".
         """
         embedding_layer = util.find_embedding_layer(self.predictor._model)
+        self.embedding_layer = embedding_layer
         if isinstance(embedding_layer, (Embedding, torch.nn.modules.sparse.Embedding)):
             # If we're using something that already has an only embedding matrix, we can just use
             # that and bypass this method.
@@ -99,36 +101,40 @@ def _construct_embedding_matrix(self) -> Embedding:
         max_index = self.vocab.get_token_index(all_tokens[-1], self.namespace)
         self.invalid_replacement_indices = [i for i in self.invalid_replacement_indices if i < max_index]

-        all_inputs = {}
+        inputs = self._make_embedder_input(all_tokens)
+
+        # pass all tokens through the fake matrix and create an embedding out of it.
+        embedding_matrix = embedding_layer(inputs).squeeze()
+
+        return embedding_matrix
+
+    def _make_embedder_input(self, all_tokens: List[str]) -> Dict[str, torch.Tensor]:
+        inputs = {}
         # A bit of a hack; this will only work with some dataset readers, but it'll do for now.
         indexers = self.predictor._dataset_reader._token_indexers  # type: ignore
         for indexer_name, token_indexer in indexers.items():
             if isinstance(token_indexer, SingleIdTokenIndexer):
                 all_indices = [self.vocab._token_to_index[self.namespace][token] for token in all_tokens]
-                all_inputs[indexer_name] = torch.LongTensor(all_indices).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(all_indices).unsqueeze(0)
             elif isinstance(token_indexer, TokenCharactersIndexer):
                 tokens = [Token(x) for x in all_tokens]
                 max_token_length = max(len(x) for x in all_tokens)
                 indexed_tokens = token_indexer.tokens_to_indices(tokens, self.vocab, "token_characters")
                 padded_tokens = token_indexer.as_padded_tensor(indexed_tokens,
                                                                {"token_characters": len(tokens)},
                                                                {"num_token_characters": max_token_length})
-                all_inputs[indexer_name] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
             elif isinstance(token_indexer, ELMoTokenCharactersIndexer):
                 elmo_tokens = []
                 for token in all_tokens:
                     elmo_indexed_token = token_indexer.tokens_to_indices([Token(text=token)],
                                                                          self.vocab,
                                                                          "sentence")["sentence"]
                     elmo_tokens.append(elmo_indexed_token[0])
-                all_inputs[indexer_name] = torch.LongTensor(elmo_tokens).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(elmo_tokens).unsqueeze(0)
             else:
                 raise RuntimeError('Unsupported token indexer:', token_indexer)
-
-        # pass all tokens through the fake matrix and create an embedding out of it.
-        embedding_matrix = embedding_layer(all_inputs).squeeze()
-
-        return embedding_matrix
+        return inputs

     def attack_from_json(self,
                          inputs: JsonDict,
@@ -254,7 +260,6 @@ def attack_from_json(self,

             # Get new token using taylor approximation.
             new_id = self._first_order_taylor(grad[index_of_token_to_flip],
-                                              self.embedding_matrix,
                                               original_id_of_token_to_flip,
                                               sign)

@@ -292,10 +297,7 @@ def attack_from_json(self,
                                 "original": original_tokens,
                                 "outputs": outputs})

-    def _first_order_taylor(self, grad: numpy.ndarray,
-                            embedding_matrix: torch.Tensor,
-                            token_idx: int,
-                            sign: int) -> int:
+    def _first_order_taylor(self, grad: numpy.ndarray, token_idx: int, sign: int) -> int:
         """
         The below code is based on
         https://github.com/pmichel31415/translate/blob/paul/pytorch_translate/
@@ -306,14 +308,20 @@ def _first_order_taylor(self, grad: numpy.ndarray,
         first-order taylor approximation of the loss.
         """
         grad = torch.from_numpy(grad)
-        embedding_matrix = embedding_matrix.cpu()
-        word_embeds = torch.nn.functional.embedding(torch.LongTensor([token_idx]),
-                                                    embedding_matrix)
-        word_embeds = word_embeds.detach().unsqueeze(0)
+        if token_idx >= self.embedding_matrix.size(0):
+            # This happens when we've truncated our fake embedding matrix. We need to do a dot
+            # product with the word vector of the current token; if that token is out of
+            # vocabulary for our truncated matrix, we need to run it through the embedding layer.
+            inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx)])
+            word_embedding = self.embedding_layer(inputs)[0]
+        else:
+            word_embedding = torch.nn.functional.embedding(torch.LongTensor([token_idx]),
+                                                           self.embedding_matrix)
+        word_embedding = word_embedding.detach().unsqueeze(0)
         grad = grad.unsqueeze(0).unsqueeze(0)
         # solves equation (3) here https://arxiv.org/abs/1903.06620
-        new_embed_dot_grad = torch.einsum("bij,kj->bik", (grad, embedding_matrix))
-        prev_embed_dot_grad = torch.einsum("bij,bij->bi", (grad, word_embeds)).unsqueeze(-1)
+        new_embed_dot_grad = torch.einsum("bij,kj->bik", (grad, self.embedding_matrix))
+        prev_embed_dot_grad = torch.einsum("bij,bij->bi", (grad, word_embedding)).unsqueeze(-1)
         neg_dir_dot_grad = sign * (prev_embed_dot_grad - new_embed_dot_grad)
         neg_dir_dot_grad = neg_dir_dot_grad.detach().cpu().numpy()
         # Do not replace with non-alphanumeric tokens
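The crash fixed here came from indexing the truncated ("fake") embedding matrix with a token id that falls outside it. Below is a minimal, self-contained sketch of the fallback pattern the patch introduces, written in plain PyTorch with hypothetical names (lookup_word_vector, truncated_matrix, embed_token) rather than the real AllenNLP objects: if the id fits inside the cached matrix, look the vector up there; otherwise embed the token on the fly through the model's live embedding layer.

    import torch

    # Sketch only: `truncated_matrix` plays the role of Hotflip.embedding_matrix
    # (built from the most frequent tokens), and `embed_token` stands in for
    # pushing a single token through the model's real embedding layer
    # (Hotflip.embedding_layer).
    def lookup_word_vector(token_idx: int,
                           truncated_matrix: torch.Tensor,
                           embed_token) -> torch.Tensor:
        if token_idx >= truncated_matrix.size(0):
            # Token id is outside the truncated matrix: embed it on the fly
            # instead of indexing the matrix (which previously crashed).
            word_embedding = embed_token(token_idx)
        else:
            word_embedding = torch.nn.functional.embedding(
                torch.LongTensor([token_idx]), truncated_matrix)
        return word_embedding.detach().unsqueeze(0)

The resulting vector is what _first_order_taylor dots against the gradient, so both the in-vocabulary and OOV paths feed the same downstream computation.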

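For the other half of the change, here is a rough sketch of how a truncated embedding matrix can be built and cached on CPU, mirroring what initialize() now does via .cpu(). It assumes a plain torch.nn.Embedding and a hypothetical max_tokens parameter; the real code instead routes tokens through the dataset reader's token indexers (_make_embedder_input) so character-level and ELMo embedders work as well.

    import torch

    # Rough sketch (hypothetical, not the AllenNLP code path): embed the first
    # `max_tokens` vocabulary ids once and cache the result on CPU.
    def build_truncated_matrix(embedding_layer: torch.nn.Embedding,
                               max_tokens: int) -> torch.Tensor:
        token_ids = torch.arange(max_tokens).unsqueeze(0)    # shape: (1, max_tokens)
        with torch.no_grad():
            matrix = embedding_layer(token_ids).squeeze(0)   # shape: (max_tokens, dim)
        return matrix.cpu()

    # Toy usage: a 10k-word vocabulary truncated to its 5k most frequent ids.
    layer = torch.nn.Embedding(10000, 50)
    matrix = build_truncated_matrix(layer, 5000)
    print(matrix.shape)  # torch.Size([5000, 50])

Any token whose id is at or beyond the truncation point then has to take the OOV fallback shown above, which is exactly the case this commit handles.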