 # pylint: disable=protected-access
 from copy import deepcopy
-from typing import List
+from typing import Dict, List
 
 import numpy
 import torch
@@ -66,6 +66,7 @@ def __init__(self,
             if not self.vocab._index_to_token[self.namespace][i].isalnum():
                 self.invalid_replacement_indices.append(i)
         self.embedding_matrix: torch.Tensor = None
+        self.embedding_layer: torch.nn.Module = None
 
     def initialize(self):
         """
@@ -74,7 +75,7 @@ def initialize(self):
         being done when __init__() is called.
         """
         if self.embedding_matrix is None:
-            self.embedding_matrix = self._construct_embedding_matrix()
+            self.embedding_matrix = self._construct_embedding_matrix().cpu()
 
     def _construct_embedding_matrix(self) -> Embedding:
         """
@@ -87,6 +88,7 @@ def _construct_embedding_matrix(self) -> Embedding:
         matrix".
         """
         embedding_layer = util.find_embedding_layer(self.predictor._model)
+        self.embedding_layer = embedding_layer
         if isinstance(embedding_layer, (Embedding, torch.nn.modules.sparse.Embedding)):
             # If we're using something that already has an only embedding matrix, we can just use
             # that and bypass this method.
@@ -99,36 +101,40 @@ def _construct_embedding_matrix(self) -> Embedding:
         max_index = self.vocab.get_token_index(all_tokens[-1], self.namespace)
         self.invalid_replacement_indices = [i for i in self.invalid_replacement_indices if i < max_index]
 
-        all_inputs = {}
+        inputs = self._make_embedder_input(all_tokens)
+
+        # pass all tokens through the fake matrix and create an embedding out of it.
+        embedding_matrix = embedding_layer(inputs).squeeze()
+
+        return embedding_matrix
+
+    def _make_embedder_input(self, all_tokens: List[str]) -> Dict[str, torch.Tensor]:
+        inputs = {}
         # A bit of a hack; this will only work with some dataset readers, but it'll do for now.
         indexers = self.predictor._dataset_reader._token_indexers  # type: ignore
         for indexer_name, token_indexer in indexers.items():
             if isinstance(token_indexer, SingleIdTokenIndexer):
                 all_indices = [self.vocab._token_to_index[self.namespace][token] for token in all_tokens]
-                all_inputs[indexer_name] = torch.LongTensor(all_indices).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(all_indices).unsqueeze(0)
             elif isinstance(token_indexer, TokenCharactersIndexer):
                 tokens = [Token(x) for x in all_tokens]
                 max_token_length = max(len(x) for x in all_tokens)
                 indexed_tokens = token_indexer.tokens_to_indices(tokens, self.vocab, "token_characters")
                 padded_tokens = token_indexer.as_padded_tensor(indexed_tokens,
                                                                {"token_characters": len(tokens)},
                                                                {"num_token_characters": max_token_length})
-                all_inputs[indexer_name] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(padded_tokens['token_characters']).unsqueeze(0)
             elif isinstance(token_indexer, ELMoTokenCharactersIndexer):
                 elmo_tokens = []
                 for token in all_tokens:
                     elmo_indexed_token = token_indexer.tokens_to_indices([Token(text=token)],
                                                                          self.vocab,
                                                                          "sentence")["sentence"]
                     elmo_tokens.append(elmo_indexed_token[0])
-                all_inputs[indexer_name] = torch.LongTensor(elmo_tokens).unsqueeze(0)
+                inputs[indexer_name] = torch.LongTensor(elmo_tokens).unsqueeze(0)
             else:
                 raise RuntimeError('Unsupported token indexer:', token_indexer)
-
-        # pass all tokens through the fake matrix and create an embedding out of it.
-        embedding_matrix = embedding_layer(all_inputs).squeeze()
-
-        return embedding_matrix
+        return inputs
 
     def attack_from_json(self,
                          inputs: JsonDict,
@@ -254,7 +260,6 @@ def attack_from_json(self,
 
             # Get new token using taylor approximation.
             new_id = self._first_order_taylor(grad[index_of_token_to_flip],
-                                              self.embedding_matrix,
                                               original_id_of_token_to_flip,
                                               sign)
 
@@ -292,10 +297,7 @@ def attack_from_json(self,
                          "original": original_tokens,
                          "outputs": outputs})
 
-    def _first_order_taylor(self, grad: numpy.ndarray,
-                            embedding_matrix: torch.Tensor,
-                            token_idx: int,
-                            sign: int) -> int:
+    def _first_order_taylor(self, grad: numpy.ndarray, token_idx: int, sign: int) -> int:
         """
         The below code is based on
         https://github.com/pmichel31415/translate/blob/paul/pytorch_translate/
@@ -306,14 +308,20 @@ def _first_order_taylor(self, grad: numpy.ndarray,
         first-order taylor approximation of the loss.
         """
         grad = torch.from_numpy(grad)
-        embedding_matrix = embedding_matrix.cpu()
-        word_embeds = torch.nn.functional.embedding(torch.LongTensor([token_idx]),
-                                                    embedding_matrix)
-        word_embeds = word_embeds.detach().unsqueeze(0)
+        if token_idx >= self.embedding_matrix.size(0):
+            # This happens when we've truncated our fake embedding matrix. We need to do a dot
+            # product with the word vector of the current token; if that token is out of
+            # vocabulary for our truncated matrix, we need to run it through the embedding layer.
+            inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx)])
+            word_embedding = self.embedding_layer(inputs)[0]
+        else:
+            word_embedding = torch.nn.functional.embedding(torch.LongTensor([token_idx]),
+                                                           self.embedding_matrix)
+        word_embedding = word_embedding.detach().unsqueeze(0)
         grad = grad.unsqueeze(0).unsqueeze(0)
         # solves equation (3) here https://arxiv.org/abs/1903.06620
-        new_embed_dot_grad = torch.einsum("bij,kj->bik", (grad, embedding_matrix))
-        prev_embed_dot_grad = torch.einsum("bij,bij->bi", (grad, word_embeds)).unsqueeze(-1)
+        new_embed_dot_grad = torch.einsum("bij,kj->bik", (grad, self.embedding_matrix))
+        prev_embed_dot_grad = torch.einsum("bij,bij->bi", (grad, word_embedding)).unsqueeze(-1)
         neg_dir_dot_grad = sign * (prev_embed_dot_grad - new_embed_dot_grad)
         neg_dir_dot_grad = neg_dir_dot_grad.detach().cpu().numpy()
         # Do not replace with non-alphanumeric tokens
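
For reference, the `_first_order_taylor` hunk above scores candidate replacement tokens with the first-order Taylor approximation from equation (3) of https://arxiv.org/abs/1903.06620. Below is a minimal standalone sketch of just that scoring step in plain PyTorch; the function name, tensor shapes, and the `sign` default are illustrative assumptions, not part of this commit.

import torch

def pick_replacement_token(grad: torch.Tensor,                # (embedding_dim,) gradient w.r.t. the flipped token's embedding
                           original_embedding: torch.Tensor,  # (embedding_dim,) embedding of the token being replaced
                           embedding_matrix: torch.Tensor,    # (vocab_size, embedding_dim)
                           invalid_indices: list,             # indices that may not be used as replacements
                           sign: int = -1) -> int:
    # Linear (first-order Taylor) estimate of the loss change for swapping the
    # original token with each candidate k: sign * (e_orig - e_k) . grad.
    candidate_dot_grad = embedding_matrix @ grad               # (vocab_size,)
    original_dot_grad = original_embedding @ grad              # scalar
    scores = sign * (original_dot_grad - candidate_dot_grad)   # (vocab_size,)
    scores[invalid_indices] = -float("inf")                    # mirrors invalid_replacement_indices
    return int(scores.argmax())

# Illustrative usage with random tensors:
vocab_size, dim = 100, 16
grads = torch.randn(dim)
matrix = torch.randn(vocab_size, dim)
best_id = pick_replacement_token(grads, matrix[7], matrix, invalid_indices=[0, 1])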