Missing duplicate tokens from TokenClassificationExplainer #154

Open
swardiantara opened this issue Feb 16, 2025 · 0 comments

I tried to use the TokenClassificationExplainer for my fine-tuned BERT model. It returns a dictionary whose keys are the tokenized inputs.
When I processed the returned dict manually, a token was missing. It turned out that the same token had already appeared earlier in the input, and since a dictionary cannot have duplicate keys, the later occurrence did not show up in the final returned value. For those who use this class, I recommend modifying the return value so that all tokenized inputs are preserved.
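To illustrate, here is a minimal standalone sketch (not from the library, with made-up tokens and scores) of how keying the result by token text drops duplicates:

# Minimal sketch with made-up tokens and scores: keying a dict by token text
# silently drops repeated tokens, which is what happens to word_attributions
# when the input contains the same token more than once.
tokens = ["the", "cat", "sat", "on", "the", "mat"]
scores = [0.10, 0.80, 0.50, 0.20, 0.30, 0.70]

word_attr = {}
for token, score in zip(tokens, scores):
    word_attr[token] = {"attribution_scores": score}

print(len(tokens))     # 6 tokens in the input
print(len(word_attr))  # 5 keys: the second "the" overwrote the first
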
Here is the original implementation:

@property
def word_attributions(self) -> Dict:
    "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."

    if self.attributions is not None:
        word_attr = dict()
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        labels = self.predicted_class_names

        for index, attr in self.attributions.items():
            try:
                predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
            except KeyError:
                predicted_class = torch.argmax(self.pred_probs[index]).item()

            word_attr[tokens[index]] = {
                "label": predicted_class,
                "attribution_scores": attr.word_attributions,
            }

        return word_attr
    else:
        raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

Below are my modifications to the word_attributions property.

@property
def word_attributions(self) -> List:
    "Returns the word attributions for model and the text provided. Raises error if attributions not calculated."

    if self.attributions is not None:
        word_attr = []
        tokens = [token.replace("Ġ", "") for token in self.decode(self.input_ids)]
        labels = self.predicted_class_names
        for index in self._selected_indexes:
            try:
                predicted_class = self.id2label[torch.argmax(self.pred_probs[index]).item()]
            except KeyError:
                predicted_class = torch.argmax(self.pred_probs[index]).item()

            word_attr.append({
                "index": index,
                "token": tokens[index],
                "label": predicted_class,
                "attribution_scores": self.attributions[index].word_attributions,
            })

        return word_attr
    else:
        raise ValueError("Attributions have not yet been calculated. Please call the explainer on text first.")

Notes:

  1. I prefer iterating with for index in self._selected_indexes: for consistency with the other methods in the class.
  2. I have checked that the labels from self.predicted_class_names are consistent with the labels inferred in the try...except block. I think it is better to use the pre-inferred labels for consistency. However, please do re-check and verify when you try it.
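
For context, below is a rough usage sketch of the modified property; the checkpoint name and example sentence are placeholders, and it assumes the usual transformers-interpret calling pattern:

# Rough usage sketch (checkpoint and sentence are placeholders): with the
# list-based return value, every token is kept, including duplicates,
# because each entry carries its own index.
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers_interpret import TokenClassificationExplainer

model = AutoModelForTokenClassification.from_pretrained("dslim/bert-base-NER")
tokenizer = AutoTokenizer.from_pretrained("dslim/bert-base-NER")

ner_explainer = TokenClassificationExplainer(model, tokenizer)
ner_explainer("We flew from Paris to Berlin and back to Paris")

for entry in ner_explainer.word_attributions:
    print(entry["index"], entry["token"], entry["label"])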