Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit 2a95022

Browse files
authored
Revert batching for input reduction (#3276)
1 parent 052e8d3 commit 2a95022

File tree

1 file changed

+5 -24 lines changed

1 file changed

+5 -24 lines changed

allennlp/interpret/attackers/input_reduction.py

+5 -24
Original file line number | Diff line number | Diff line change
@@ -80,31 +80,12 @@ def get_length(input_instance: Instance):
80 80
return len(input_text_field.tokens)
81 81
candidates = heapq.nsmallest(self.beam_size, candidates, key=lambda x: get_length(x[0]))
82 82

83-
# predictor.get_gradients is where the most expensive computation happens, so we're
84-
# going to do it in a batch, up front, before iterating over the results.
85-
copied_candidates = deepcopy(candidates)
86-
all_grads, all_outputs = self.predictor.get_gradients([x[0] for x in copied_candidates])
87-
88-
# The output in `all_grads` and `all_outputs` is batched in a dictionary (e.g.,
89-
# {'grad_output_1': batched_tensor}). We need to split this into a list of non-batched
90-
# dictionaries that we can iterate over.
91-
split_grads = []
92-
for i in range(len(copied_candidates)):
93-
split_grads.append({key: value[i] for key, value in all_grads.items()})
94-
split_outputs = []
95-
for i in range(len(copied_candidates)):
96-
instance_outputs = {}
97-
for key, value in all_outputs.items():
98-
if key == 'loss':
99-
continue
100-
instance_outputs[key] = value[i]
101-
split_outputs.append(instance_outputs)
102-
beam_candidates = [(x[0], x[1], x[2], split_grads[i], split_outputs[i])
103-
for i, x in enumerate(copied_candidates)]
104-
83+
beam_candidates = deepcopy(candidates)
105 84
candidates = []
106-
for beam_instance, smallest_idx, tag_mask, grads, outputs in beam_candidates:
85+
for beam_instance, smallest_idx, tag_mask in beam_candidates:
86+
# get gradients and predictions
107 87
beam_tag_mask = deepcopy(tag_mask)
88+
grads, outputs = self.predictor.get_gradients([beam_instance])
108 89

109 90
for output in outputs:
110 91
if isinstance(outputs[output], torch.Tensor):
@@ -133,7 +114,7 @@ def get_length(input_instance: Instance):
133 114
current_tokens = deepcopy(text_field.tokens)
134 115
reduced_instances_and_smallest = _remove_one_token(beam_instance,
135 116
input_field_to_attack,
136-
grads[grad_input_field],
117+
grads[grad_input_field][0],
137 118
ignore_tokens,
138 119
self.beam_size,
139 120
beam_tag_mask)

0 commit comments

Comments
 (0)