@@ -80,31 +80,12 @@ def get_length(input_instance: Instance):
                 return len(input_text_field.tokens)
             candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))
 
-            # predictor.get_gradients is where the most expensive computation happens, so we're
-            # going to do it in a batch, up front, before iterating over the results.
-            copied_candidates = deepcopy(candidates)
-            all_grads, all_outputs = self.predictor.get_gradients([x[0] for x in copied_candidates])
-
-            # The output in `all_grads` and `all_outputs` is batched in a dictionary (e.g.,
-            # {'grad_output_1': batched_tensor}). We need to split this into a list of non-batched
-            # dictionaries that we can iterate over.
-            split_grads = []
-            for i in range(len(copied_candidates)):
-                split_grads.append({key: value[i] for key, value in all_grads.items()})
-            split_outputs = []
-            for i in range(len(copied_candidates)):
-                instance_outputs = {}
-                for key, value in all_outputs.items():
-                    if key == 'loss':
-                        continue
-                    instance_outputs[key] = value[i]
-                split_outputs.append(instance_outputs)
-            beam_candidates = [(x[0], x[1], x[2], split_grads[i], split_outputs[i])
-                               for i, x in enumerate(copied_candidates)]
-
+            beam_candidates = deepcopy(candidates)
 
             candidates = []
-            for beam_instance, smallest_idx, tag_mask, grads, outputs in beam_candidates:
+            for beam_instance, smallest_idx, tag_mask in beam_candidates:
+                # get gradients and predictions
                 beam_tag_mask = deepcopy(tag_mask)
+                grads, outputs = self.predictor.get_gradients([beam_instance])
 
                 for output in outputs:
                     if isinstance(outputs[output], torch.Tensor):
@@ -133,7 +114,7 @@ def get_length(input_instance: Instance):
                 current_tokens = deepcopy(text_field.tokens)
                 reduced_instances_and_smallest = _remove_one_token(beam_instance,
                                                                    input_field_to_attack,
-                                                                   grads[grad_input_field],
+                                                                   grads[grad_input_field][0],
                                                                    ignore_tokens,
                                                                    self.beam_size,
                                                                    beam_tag_mask)
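For context on the trade-off this diff reverts: the removed block ran predictor.get_gradients once over the whole beam and then split the batched result dictionaries into per-instance ones, while the new code calls get_gradients once per candidate. Each per-instance call still returns tensors with a leading batch dimension of size 1, which is why grads[grad_input_field] now takes the extra [0] index. Below is a minimal standalone sketch of that splitting step; the tensor shapes and values are made up stand-ins for real model gradients, not AllenNLP's actual outputs.

import torch

# Hypothetical stand-in for what a batched get_gradients call returns for
# 3 beam candidates: one dict whose values are batched tensors.
all_grads = {"grad_input_1": torch.randn(3, 10, 5)}  # (batch, seq_len, emb_dim), shapes assumed

# Split the batched dict into one non-batched dict per instance, as the
# removed `split_grads` loop did.
batch_size = all_grads["grad_input_1"].size(0)
split_grads = [{key: value[i] for key, value in all_grads.items()}
               for i in range(batch_size)]
assert split_grads[0]["grad_input_1"].shape == (10, 5)

# After this change, gradients come from one-instance calls instead, so each
# returned tensor has a batch dimension of size 1 that must be indexed away:
single_grads = {"grad_input_1": torch.randn(1, 10, 5)}  # one-instance "batch"
per_instance_grad = single_grads["grad_input_1"][0]     # hence the new [0]
assert per_instance_grad.shape == (10, 5)

The per-instance version is simpler and avoids the bookkeeping above, at the cost of one forward/backward pass per beam candidate instead of a single batched pass.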