Skip to content

Commit c81a6c5

Browse files
anakin87Amnah199
authored andcommitted
fix: Cohere - fix chat message creation (#1289)
1 parent 4eca1d6 commit c81a6c5

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

integrations/cohere/src/haystack_integrations/components/generators/cohere/chat/chat_generator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def run(self, messages: List[ChatMessage], generation_kwargs: Optional[Dict[str,
172172
response_text += event.text
173173
elif event.event_type == "stream-end":
174174
finish_response = event.response
175-
chat_message = ChatMessage.from_assistant(content=response_text)
175+
chat_message = ChatMessage.from_assistant(response_text)
176176

177177
if finish_response and finish_response.meta:
178178
if finish_response.meta.billed_units:
@@ -219,7 +219,7 @@ def _build_message(self, cohere_response):
219219
# TODO revisit to see if we need to handle multiple tool calls
220220
message = ChatMessage.from_assistant(cohere_response.tool_calls[0].json())
221221
elif cohere_response.text:
222-
message = ChatMessage.from_assistant(content=cohere_response.text)
222+
message = ChatMessage.from_assistant(cohere_response.text)
223223
message.meta.update(
224224
{
225225
"model": self.model,

integrations/cohere/tests/test_cohere_chat_generator.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def test_message_to_dict(self, chat_messages):
144144
)
145145
@pytest.mark.integration
146146
def test_live_run(self):
147-
chat_messages = [ChatMessage.from_user(content="What's the capital of France")]
147+
chat_messages = [ChatMessage.from_user("What's the capital of France")]
148148
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
149149
results = component.run(chat_messages)
150150
assert len(results["replies"]) == 1
@@ -181,7 +181,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
181181

182182
callback = Callback()
183183
component = CohereChatGenerator(streaming_callback=callback)
184-
results = component.run([ChatMessage.from_user(content="What's the capital of France? answer in a word")])
184+
results = component.run([ChatMessage.from_user("What's the capital of France? answer in a word")])
185185

186186
assert len(results["replies"]) == 1
187187
message: ChatMessage = results["replies"][0]
@@ -202,7 +202,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
202202
)
203203
@pytest.mark.integration
204204
def test_live_run_with_connector(self):
205-
chat_messages = [ChatMessage.from_user(content="What's the capital of France")]
205+
chat_messages = [ChatMessage.from_user("What's the capital of France")]
206206
component = CohereChatGenerator(generation_kwargs={"temperature": 0.8})
207207
results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]})
208208
assert len(results["replies"]) == 1
@@ -227,7 +227,7 @@ def __call__(self, chunk: StreamingChunk) -> None:
227227
self.responses += chunk.content if chunk.content else ""
228228

229229
callback = Callback()
230-
chat_messages = [ChatMessage.from_user(content="What's the capital of France? answer in a word")]
230+
chat_messages = [ChatMessage.from_user("What's the capital of France? answer in a word")]
231231
component = CohereChatGenerator(streaming_callback=callback)
232232
results = component.run(chat_messages, generation_kwargs={"connectors": [{"id": "web-search"}]})
233233

0 commit comments

Comments
 (0)