
Commit ef337d2

Compute attributions w.r.t the predicted logit, not the predicted loss

1 parent 1fff7ca · commit ef337d2

3 files changed, +6 -6 lines changed

CHANGELOG.md (+1 -1)

@@ -19,7 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Fixed
 
 - Fixed typo with `LabelField` string representation: removed trailing apostrophe.
-
+- Gradient attribution in AllenNLP Interpret now computed as a function of the predicted class' logit, not its loss.
 
 ## [v1.3.0](https://github.com/allenai/allennlp/releases/tag/v1.3.0) - 2020-12-15

allennlp/interpret/saliency_interpreters/simple_gradient.py (+1 -1)

@@ -17,7 +17,7 @@ class SimpleGradient(SaliencyInterpreter):
 
     def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
         """
-        Interprets the model's prediction for inputs. Gets the gradients of the loss with respect
+        Interprets the model's prediction for inputs. Gets the gradients of the logits with respect
         to the input and returns those gradients normalized and sanitized.
         """
         labeled_instances = self.predictor.json_to_labeled_instances(inputs)
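
For context, the gradients this interpreter retrieves are reduced to per-token saliency scores by dotting each token's gradient with its embedding and normalizing. A rough standalone sketch of that idea (not AllenNLP's exact implementation; the function name and shapes here are illustrative):

import torch


def simple_gradient_saliency(embeddings: torch.Tensor, grads: torch.Tensor) -> torch.Tensor:
    """Per-token saliency from the dot product of each token's embedding and its gradient.

    embeddings, grads: (num_tokens, embedding_dim) tensors, e.g. the values captured by
    the predictor's embedding hooks after back-propagating the predicted logit.
    """
    scores = (grads * embeddings).sum(dim=-1).abs()  # (num_tokens,)
    return scores / scores.sum()  # L1-normalize so the token scores sum to 1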

allennlp/predictors/predictor.py (+4 -4)

@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
 
     def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
-        Gets the gradients of the loss with respect to the model inputs.
+        Gets the gradients of the logits with respect to the model inputs.
 
         # Parameters
 
@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         Takes a `JsonDict` representing the inputs of the model and converts
         them to [`Instances`](../data/instance.md)), sends these through
         the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
-        layer of the model. Calls `backward` on the loss and then removes the
+        layer of the model. Calls `backward` on the logits and then removes the
         hooks.
         """
         # set requires_grad to true for all parameters, but save original values to
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
                 self._model.forward(**dataset_tensor_dict)  # type: ignore
             )
 
-            loss = outputs["loss"]
+            predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs['probs']))]
             # Zero gradients.
             # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
             # because it avoids a read op when the gradients are first updated below.
             for p in self._model.parameters():
                 p.grad = None
-            loss.backward()
+            predicted_logit.backward()
 
         for hook in hooks:
             hook.remove()
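
The core change above swaps `loss = outputs["loss"]` / `loss.backward()` for a backward pass from the logit of the predicted class, so the captured embedding gradients reflect the prediction itself rather than the loss. A minimal, self-contained sketch of that pattern (a toy embedding-plus-linear classifier, not AllenNLP's `Predictor`; names such as `embedder` and `classifier` are illustrative) might look like:

import torch
import torch.nn as nn

torch.manual_seed(0)
embedder = nn.Embedding(num_embeddings=100, embedding_dim=16)
classifier = nn.Linear(16, 3)

# Stand-in for the embedding-layer hooks: capture the embedding output so its
# gradient can be read after the backward pass.
captured = {}
hook = embedder.register_forward_hook(
    lambda module, inputs, output: captured.update(embeddings=output)
)

token_ids = torch.tensor([[5, 17, 42]])    # (batch=1, seq_len=3)
embedded = embedder(token_ids)             # fires the hook; shape (1, 3, 16)
captured["embeddings"].retain_grad()       # keep grads on this non-leaf tensor
logits = classifier(embedded.mean(dim=1))  # (1, num_classes)
probs = torch.softmax(logits, dim=-1)

# Previously: outputs["loss"].backward(). Now: backward from the predicted class' logit.
predicted_logit = logits.squeeze(0)[int(torch.argmax(probs))]
predicted_logit.backward()

grads = captured["embeddings"].grad        # (1, 3, 16) gradients w.r.t. token embeddings
hook.remove()
print(grads.abs().sum(dim=-1))             # rough per-token attribution magnitudes

These gradients are what the saliency interpreters (such as `SimpleGradient` above) then normalize into per-token attributions.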
