diff --git a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
index 8efa8cda7..dbcab619d 100644
--- a/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
+++ b/integrations/google_ai/src/haystack_integrations/components/generators/google_ai/chat/gemini.py
@@ -313,9 +313,24 @@ def _get_response(self, response_body: GenerateContentResponse) -> List[ChatMess
         """
         replies: List[ChatMessage] = []
         metadata = response_body.to_dict()
+
+        # currently Google only supports one candidate and usage metadata reflects this
+        # this should be refactored when multiple candidates are supported
+        usage_metadata_openai_format = {}
+
+        usage_metadata = metadata.get("usage_metadata")
+        if usage_metadata:
+            usage_metadata_openai_format = {
+                "prompt_tokens": usage_metadata["prompt_token_count"],
+                "completion_tokens": usage_metadata["candidates_token_count"],
+                "total_tokens": usage_metadata["total_token_count"],
+            }
+
         for idx, candidate in enumerate(response_body.candidates):
             candidate_metadata = metadata["candidates"][idx]
             candidate_metadata.pop("content", None)  # we remove content from the metadata
+            if usage_metadata_openai_format:
+                candidate_metadata["usage"] = usage_metadata_openai_format
 
             for part in candidate.content.parts:
                 if part.text != "":
diff --git a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
index c4372db0d..cb42f0ff8 100644
--- a/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
+++ b/integrations/google_ai/tests/generators/chat/test_chat_gemini.py
@@ -295,5 +295,11 @@ def test_past_conversation():
     ]
     response = gemini_chat.run(messages=messages)
     assert "replies" in response
-    assert len(response["replies"]) > 0
-    assert all(reply.role == ChatRole.ASSISTANT for reply in response["replies"])
+    replies = response["replies"]
+    assert len(replies) > 0
+    assert all(reply.role == ChatRole.ASSISTANT for reply in replies)
+
+    assert all("usage" in reply.meta for reply in replies)
+    assert all("prompt_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("completion_tokens" in reply.meta["usage"] for reply in replies)
+    assert all("total_tokens" in reply.meta["usage"] for reply in replies)
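
For reference, a minimal standalone sketch (not part of the patch) of the token-count mapping that `_get_response` now performs. The counts below are made up; the point is the renaming from Google's usage_metadata fields to the OpenAI-style keys the tests assert on under reply.meta["usage"]:

# Hypothetical usage_metadata, shaped like the dict produced by
# GenerateContentResponse.to_dict(); the numbers are illustrative only.
usage_metadata = {
    "prompt_token_count": 12,
    "candidates_token_count": 34,
    "total_token_count": 46,
}

# Same key mapping the patch introduces: Google field names -> OpenAI-style names.
usage_metadata_openai_format = {
    "prompt_tokens": usage_metadata["prompt_token_count"],
    "completion_tokens": usage_metadata["candidates_token_count"],
    "total_tokens": usage_metadata["total_token_count"],
}

# Each reply's meta would then carry this dict under the "usage" key.
assert usage_metadata_openai_format == {
    "prompt_tokens": 12,
    "completion_tokens": 34,
    "total_tokens": 46,
}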