@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:
 
     def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         """
-        Gets the gradients of the loss with respect to the model inputs.
+        Gets the gradients of the logits with respect to the model inputs.
 
         # Parameters
 
@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         Takes a `JsonDict` representing the inputs of the model and converts
         them to [`Instances`](../data/instance.md)), sends these through
         the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
-        layer of the model. Calls `backward` on the loss and then removes the
+        layer of the model. Calls `backward` on the logits and then removes the
         hooks.
         """
         # set requires_grad to true for all parameters, but save original values to
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
             self._model.forward(**dataset_tensor_dict)  # type: ignore
         )
 
-        loss = outputs["loss"]
+        predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
         # Zero gradients.
         # NOTE: this is actually more efficient than calling `self._model.zero_grad()`
        # because it avoids a read op when the gradients are first updated below.
         for p in self._model.parameters():
             p.grad = None
-        loss.backward()
+        predicted_logit.backward()
 
         for hook in hooks:
             hook.remove()
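
For readers following the change: the diff switches the backward pass from the training loss to the logit of the predicted (argmax) class, so input gradients can be computed without a gold label. Below is a minimal, self-contained PyTorch sketch of that pattern. The toy classifier, tensor names, and the forward-hook mechanism are illustrative assumptions for this sketch, not AllenNLP's actual `Predictor` internals or its hook registration code.

```python
# Illustrative sketch only: a toy model standing in for the real one.
import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    def __init__(self, vocab_size: int = 100, embedding_dim: int = 16, num_classes: int = 3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, num_classes)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        embedded = self.embedding(token_ids)   # (batch, seq_len, embedding_dim)
        pooled = embedded.mean(dim=1)          # (batch, embedding_dim)
        return self.linear(pooled)             # (batch, num_classes) logits

model = ToyClassifier()
model.eval()

# Capture the gradient flowing back into the embedding output by attaching a
# tensor-level hook to that output during the forward pass.
embedding_gradients = []

def capture_embedding_grad(module, inputs, output):
    output.register_hook(lambda grad: embedding_gradients.append(grad))

hook = model.embedding.register_forward_hook(capture_embedding_grad)

token_ids = torch.tensor([[1, 5, 7, 2]])
logits = model(token_ids)  # (1, num_classes)

# As in the change above: backprop from the predicted class's logit,
# not from a loss value.
predicted_logit = logits.squeeze(0)[int(torch.argmax(logits))]

# Zero gradients by clearing .grad directly, then backprop from the logit.
for p in model.parameters():
    p.grad = None
predicted_logit.backward()

hook.remove()
print(embedding_gradients[0].shape)  # (1, seq_len, embedding_dim)
```

One appeal of backpropagating from the predicted logit is that saliency-style input gradients can be produced for the model's own prediction even when no label, and hence no loss, is available.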