Skip to content

genai: Fix the issue where the task_type passed into embed_query is ignored. #908

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion libs/genai/langchain_google_genai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def embed_query(
Returns:
Embedding for the text.
"""
task_type = self.task_type or "RETRIEVAL_QUERY"
task_type_to_use = task_type if task_type else self.task_type
if task_type_to_use is None:
task_type_to_use = "RETRIEVAL_QUERY" # Default to RETRIEVAL_QUERY
try:
request: EmbedContentRequest = self._prepare_request(
text=text,
Expand Down
20 changes: 20 additions & 0 deletions libs/genai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,23 @@ def test_embed_documents_quality() -> None:
np.array(dissimilar_embeddings[0]) - np.array(dissimilar_embeddings[1])
)
assert similar_distance < dissimilar_distance


def test_embed_query_task_type() -> None:
    """Check that ``embed_query`` honours a ``task_type`` given per call.

    The same query is embedded three times:

    * with ``task_type="clustering"`` set on the embeddings object,
    * with ``task_type="clustering"`` passed to ``embed_query`` itself,
    * with no task type supplied anywhere.

    The first two results must be equal to each other, and the first must
    differ from the third.
    """
    query = "How does alphafold work?"
    dims = 768

    # Task type configured once, at construction time.
    configured = GoogleGenerativeAIEmbeddings(model=_MODEL, task_type="clustering")
    from_constructor = configured.embed_query(query, output_dimensionality=dims)

    # Task type supplied only as an argument to the call.
    per_call = GoogleGenerativeAIEmbeddings(model=_MODEL)
    from_argument = per_call.embed_query(
        query, task_type="clustering", output_dimensionality=dims
    )

    # No task type anywhere: whatever embed_query falls back to.
    unconfigured = GoogleGenerativeAIEmbeddings(model=_MODEL)
    from_fallback = unconfigured.embed_query(query, output_dimensionality=dims)

    assert from_constructor == from_argument
    assert from_constructor != from_fallback
Loading