Skip to content

Commit 526a911

Browse files
authored
Fix the type hints in CGISTEmbedLoss (#3272)
1 parent f20e75d commit 526a911

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

sentence_transformers/losses/CachedGISTEmbedLoss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def embed_minibatch(
175175
with_grad: bool,
176176
copy_random_state: bool,
177177
random_state: RandContext | None = None,
178-
) -> tuple[Tensor, RandContext | None]:
178+
) -> tuple[Tensor, Tensor, RandContext | None]:
179179
"""Do forward pass on a minibatch of the input features and return corresponding embeddings."""
180180
grad_context = nullcontext if with_grad else torch.no_grad
181181
random_state_context = nullcontext() if random_state is None else random_state
@@ -203,7 +203,7 @@ def embed_minibatch_iter(
203203
with_grad: bool,
204204
copy_random_state: bool,
205205
random_states: list[RandContext] | None = None,
206-
) -> Iterator[tuple[Tensor, RandContext | None]]:
206+
) -> Iterator[tuple[Tensor, Tensor, RandContext | None]]:
207207
"""Do forward pass on all the minibatches of the input features and yield corresponding embeddings."""
208208
input_ids: Tensor = sentence_feature["input_ids"]
209209
bsz, _ = input_ids.shape

0 commit comments

Comments
 (0)