Skip to content
This repository was archived by the owner on Dec 16, 2022. It is now read-only.

Commit dc3a4f6

Browse files
authored
clean up forward hooks on exception (#4778)
1 parent fcc3a70 commit dc3a4f6

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1515

1616
### Fixed
1717

- Fixed a bug where forward hooks were not cleaned up with saliency interpreters if there
  was an exception.
1820
- Fixed the computation of saliency maps in the Interpret code when using mismatched indexing.
1921
Previously, we would compute gradients from the top of the transformer, after aggregation from
2022
wordpieces to tokens, which gives results that are not very informative. Now, we compute gradients

allennlp/interpret/saliency_interpreters/integrated_gradient.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -88,9 +88,11 @@ def _integrate_gradients(self, instance: Instance) -> Dict[str, numpy.ndarray]:
8888
# Hook for modifying embedding value
8989
handles = self._register_hooks(alpha, embeddings_list, token_offsets)
9090

91-
grads = self.predictor.get_gradients([instance])[0]
92-
for handle in handles:
93-
handle.remove()
91+
try:
92+
grads = self.predictor.get_gradients([instance])[0]
93+
finally:
94+
for handle in handles:
95+
handle.remove()
9496

9597
# Running sum of gradients
9698
if ig_grads == {}:

allennlp/interpret/saliency_interpreters/simple_gradient.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
3030

3131
# Hook used for saving embeddings
3232
handles = self._register_hooks(embeddings_list, token_offsets)
33-
grads = self.predictor.get_gradients([instance])[0]
34-
for handle in handles:
35-
handle.remove()
33+
try:
34+
grads = self.predictor.get_gradients([instance])[0]
35+
finally:
36+
for handle in handles:
37+
handle.remove()
3638

3739
# Gradients come back in the reverse order that they were sent into the network
3840
embeddings_list.reverse()

allennlp/interpret/saliency_interpreters/smooth_gradient.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,10 @@ def _smooth_grads(self, instance: Instance) -> Dict[str, numpy.ndarray]:
7272
total_gradients: Dict[str, Any] = {}
7373
for _ in range(self.num_samples):
7474
handle = self._register_forward_hook(self.stdev)
75-
grads = self.predictor.get_gradients([instance])[0]
76-
handle.remove()
75+
try:
76+
grads = self.predictor.get_gradients([instance])[0]
77+
finally:
78+
handle.remove()
7779

7880
# Sum gradients
7981
if total_gradients == {}:

0 commit comments

Comments (0)