[fix] Use return_dict=True in Transformer; improve how all_layer_embeddings are determined #3320

Merged

18 changes: 6 additions & 12 deletions sentence_transformers/models/Transformer.py
@@ -439,8 +439,9 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
             if key in ["input_ids", "attention_mask", "token_type_ids", "inputs_embeds"]
         }

-        output_states = self.auto_model(**trans_features, **kwargs, return_dict=False)
-        output_tokens = output_states[0]
+        outputs = self.auto_model(**trans_features, **kwargs, return_dict=True)
+        token_embeddings = outputs[0]
+        features["token_embeddings"] = token_embeddings

         # If the AutoModel is wrapped with a PeftModelForFeatureExtraction, then it may have added virtual tokens
         # We need to extend the attention mask to include these virtual tokens, or the pooling will fail
@@ -451,22 +452,15 @@ def forward(self, features: dict[str, torch.Tensor], **kwargs) -> dict[str, torch.Tensor]:
                 isinstance(self.auto_model, PeftModelForFeatureExtraction)
                 and self.auto_model.active_peft_config.is_prompt_learning
             ):
-                batch_size = output_tokens.size(0)
+                batch_size = token_embeddings.size(0)
                 attention_mask = features["attention_mask"]
                 prefix_attention_mask = torch.ones(
                     batch_size, self.auto_model.active_peft_config.num_virtual_tokens, device=attention_mask.device
                 )
                 features["attention_mask"] = torch.cat((prefix_attention_mask, attention_mask), dim=1)

-        features["token_embeddings"] = output_tokens
-
-        if self.auto_model.config.output_hidden_states and len(output_states) > 2:
-            all_layer_idx = 2  # I.e. after last_hidden_states and pooler_output
-            if len(output_states) < 3:  # Some models only output last_hidden_states and all_hidden_states
-                all_layer_idx = 1
-
-            hidden_states = output_states[all_layer_idx]
-            features["all_layer_embeddings"] = hidden_states
+        if self.auto_model.config.output_hidden_states and "hidden_states" in outputs:
+            features["all_layer_embeddings"] = outputs["hidden_states"]

         return features

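Note: the gist of the change is that with `return_dict=False` the underlying Hugging Face model returns a plain tuple, so the old code had to guess which position held the all-layer hidden states (index 2 when a `pooler_output` is present, index 1 otherwise). With `return_dict=True` the model returns a `ModelOutput`, so the optional `hidden_states` field can simply be looked up by name. Below is a minimal sketch of the difference, using `bert-base-uncased` purely as an example checkpoint (not part of this PR):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Example checkpoint only; any Hugging Face encoder behaves the same way here.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)

features = tokenizer("Hello world", return_tensors="pt")

with torch.no_grad():
    # Old style: a tuple whose layout depends on the architecture, e.g.
    # (last_hidden_state, pooler_output, hidden_states) for BERT, but
    # (last_hidden_state, hidden_states) for models without a pooler.
    output_states = model(**features, return_dict=False)

    # New style: a ModelOutput; optional fields can be tested by name,
    # and fields that are None are simply absent from the mapping.
    outputs = model(**features, return_dict=True)

token_embeddings = outputs[0]  # positional access to last_hidden_state still works
if "hidden_states" in outputs:  # only present when output_hidden_states=True
    all_layer_embeddings = outputs["hidden_states"]  # tuple: embedding output + one tensor per layer
    print(len(all_layer_embeddings))  # e.g. 13 for bert-base-uncased (embeddings + 12 layers)
```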
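The PEFT branch shown in the diff is untouched by this PR apart from the renamed variable: when the wrapped model uses prompt learning, it prepends `num_virtual_tokens` virtual tokens to each sequence, so the attention mask has to be extended with a matching prefix of ones before pooling. A small standalone illustration of that concatenation (the tensor values below are made up for the example):

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])  # one sequence of length 4 with one padding token
num_virtual_tokens = 2  # would come from auto_model.active_peft_config.num_virtual_tokens

# Prefix of ones covering the virtual tokens, on the same device as the mask
prefix_attention_mask = torch.ones(
    attention_mask.size(0), num_virtual_tokens, device=attention_mask.device
)
extended_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
print(extended_mask)  # tensor([[1., 1., 1., 1., 1., 0.]])
```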