
Commit 3e17629

feat: Add streaming_callback to run methods of OllamaGenerator and OllamaChatGenerator (#1636)
* Add streaming_callback to runtime of ollama
* Add tests
1 parent d514c7b commit 3e17629

4 files changed: +105 -16 lines changed

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

+12 -5
@@ -258,16 +258,18 @@ def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
        chunk_message = StreamingChunk(content, meta)
        return chunk_message

-    def _handle_streaming_response(self, response) -> Dict[str, List[Any]]:
+    def _handle_streaming_response(
+        self, response: Any, streaming_callback: Optional[Callable[[StreamingChunk], None]]
+    ) -> Dict[str, List[Any]]:
        """
        Handles streaming response and converts it to Haystack format
        """
        chunks: List[StreamingChunk] = []
        for chunk in response:
            chunk_delta = self._build_chunk(chunk)
            chunks.append(chunk_delta)
-            if self.streaming_callback is not None:
-                self.streaming_callback(chunk_delta)
+            if streaming_callback is not None:
+                streaming_callback(chunk_delta)

        replies = [ChatMessage.from_assistant("".join([c.content for c in chunks]))]
        meta = {key: value for key, value in chunks[0].meta.items() if key != "message"}
@@ -280,6 +282,8 @@ def run(
        messages: List[ChatMessage],
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
+        *,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Runs an Ollama Model on a given chat history.
@@ -293,12 +297,15 @@ def run(
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
        :returns: A dictionary with the following keys:
            - `replies`: The responses from the model
        """
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}
+        resolved_streaming_callback = streaming_callback or self.streaming_callback

-        stream = self.streaming_callback is not None
+        stream = resolved_streaming_callback is not None
        tools = tools or self.tools
        _check_duplicate_tool_names(tools)

@@ -328,6 +335,6 @@ def run(
        )

        if stream:
-            return self._handle_streaming_response(response)
+            return self._handle_streaming_response(response, resolved_streaming_callback)

        return {"replies": [_convert_ollama_response_to_chatmessage(response)]}
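
With this change, OllamaChatGenerator.run accepts a streaming_callback at call time that takes precedence over the one configured at init. A minimal usage sketch, assuming the integration is installed; the model name and prompt are illustrative and not part of this commit:

from haystack.dataclasses import ChatMessage, StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

def print_chunk(chunk: StreamingChunk) -> None:
    # Called once for every chunk streamed back from Ollama.
    print(chunk.content, end="", flush=True)

generator = OllamaChatGenerator(model="llama3.2")  # no callback configured at init
result = generator.run(
    messages=[ChatMessage.from_user("What is the capital of the Netherlands?")],
    streaming_callback=print_chunk,  # per-call callback enables streaming for this run only
)
print(result["replies"][0].text)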

integrations/ollama/src/haystack_integrations/components/generators/ollama/generator.py

+17 -6
@@ -137,16 +137,18 @@ def _convert_to_streaming_response(self, chunks: List[StreamingChunk]) -> Dict[s

        return {"replies": replies, "meta": [meta]}

-    def _handle_streaming_response(self, response) -> List[StreamingChunk]:
+    def _handle_streaming_response(
+        self, response: Any, streaming_callback: Optional[Callable[[StreamingChunk], None]]
+    ) -> List[StreamingChunk]:
        """
        Handles Streaming response cases
        """
        chunks: List[StreamingChunk] = []
        for chunk in response:
            chunk_delta: StreamingChunk = self._build_chunk(chunk)
            chunks.append(chunk_delta)
-            if self.streaming_callback is not None:
-                self.streaming_callback(chunk_delta)
+            if streaming_callback is not None:
+                streaming_callback(chunk_delta)
        return chunks

    def _build_chunk(self, chunk_response: Any) -> StreamingChunk:
@@ -165,6 +167,8 @@ def run(
        self,
        prompt: str,
        generation_kwargs: Optional[Dict[str, Any]] = None,
+        *,
+        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
    ):
        """
        Runs an Ollama Model on the given prompt.
@@ -175,20 +179,27 @@ def run(
            Optional arguments to pass to the Ollama generation endpoint, such as temperature,
            top_p, and others. See the available arguments in
            [Ollama docs](https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values).
+        :param streaming_callback:
+            A callback function that is called when a new token is received from the stream.
        :returns: A dictionary with the following keys:
            - `replies`: The responses from the model
            - `meta`: The metadata collected during the run
        """
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

-        stream = self.streaming_callback is not None
+        resolved_streaming_callback = streaming_callback or self.streaming_callback
+        stream = resolved_streaming_callback is not None

        response = self._client.generate(
-            model=self.model, prompt=prompt, stream=stream, keep_alive=self.keep_alive, options=generation_kwargs
+            model=self.model,
+            prompt=prompt,
+            stream=stream,
+            keep_alive=self.keep_alive,
+            options=generation_kwargs,
        )

        if stream:
-            chunks: List[StreamingChunk] = self._handle_streaming_response(response)
+            chunks: List[StreamingChunk] = self._handle_streaming_response(response, resolved_streaming_callback)
            return self._convert_to_streaming_response(chunks)

        return self._convert_to_response(response)
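
The same runtime override now applies to OllamaGenerator.run. A minimal sketch, mirroring the new integration test further below; the model name and prompt are illustrative:

from haystack.dataclasses import StreamingChunk
from haystack_integrations.components.generators.ollama import OllamaGenerator

def collect(chunk: StreamingChunk) -> None:
    # Receives each streamed chunk; the full reply is still assembled and returned by run().
    print(chunk.content, end="", flush=True)

generator = OllamaGenerator(model="llama3.2:3b", streaming_callback=None)
results = generator.run(prompt="What's the capital of Netherlands?", streaming_callback=collect)
print(results["replies"][0])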

integrations/ollama/tests/test_chat_generator.py

+52 -3
@@ -418,6 +418,51 @@ def streaming_callback(_: StreamingChunk) -> None:
        assert result["replies"][0].text == "first chunk second chunk"
        assert result["replies"][0].role == "assistant"

+    @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
+    def test_run_streaming_at_runtime(self, mock_client):
+        streaming_callback_called = False
+
+        def streaming_callback(_: StreamingChunk) -> None:
+            nonlocal streaming_callback_called
+            streaming_callback_called = True
+
+        generator = OllamaChatGenerator(streaming_callback=None)
+
+        mock_response = iter(
+            [
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "first chunk "},
+                    done=False,
+                ),
+                ChatResponse(
+                    model="llama3.2",
+                    created_at="2023-12-12T14:13:43.416799Z",
+                    message={"role": "assistant", "content": "second chunk"},
+                    done=True,
+                    total_duration=4883583458,
+                    load_duration=1334875,
+                    prompt_eval_count=26,
+                    prompt_eval_duration=342546000,
+                    eval_count=282,
+                    eval_duration=4535599000,
+                ),
+            ]
+        )
+
+        mock_client_instance = mock_client.return_value
+        mock_client_instance.chat.return_value = mock_response
+
+        result = generator.run(messages=[ChatMessage.from_user("irrelevant")], streaming_callback=streaming_callback)
+
+        assert streaming_callback_called
+
+        assert "replies" in result
+        assert len(result["replies"]) == 1
+        assert result["replies"][0].text == "first chunk second chunk"
+        assert result["replies"][0].role == "assistant"
+
    def test_run_fail_with_tools_and_streaming(self, tools):
        component = OllamaChatGenerator(tools=tools, streaming_callback=print_streaming_chunk)

@@ -459,7 +504,9 @@ def test_run_with_chat_history(self):
        assert isinstance(response, dict)
        assert isinstance(response["replies"], list)

-        assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"])
+        assert any(
+            city.lower() in response["replies"][-1].text.lower() for city in ["Manchester", "Birmingham", "Glasgow"]
+        )

    @pytest.mark.integration
    def test_run_model_unavailable(self):
@@ -486,7 +533,9 @@ def test_run_with_streaming(self):

        assert isinstance(response, dict)
        assert isinstance(response["replies"], list)
-        assert any(city in response["replies"][-1].text for city in ["Manchester", "Birmingham", "Glasgow"])
+        assert any(
+            city.lower() in response["replies"][-1].text.lower() for city in ["Manchester", "Birmingham", "Glasgow"]
+        )

    @pytest.mark.integration
    def test_run_with_tools(self, tools):
@@ -525,7 +574,7 @@ def test_run_with_response_format(self):
        assert isinstance(response_data["capital"], str)
        assert "population" in response_data
        assert isinstance(response_data["population"], (int, float))
-        assert response_data["capital"] == "Paris"
+        assert response_data["capital"].lower() == "paris"

    def test_run_with_streaming_and_format(self):
        response_format = {

integrations/ollama/tests/test_generator.py

+24 -2
@@ -135,7 +135,7 @@ def test_from_dict(self):
        assert component.keep_alive == "5m"

    @pytest.mark.integration
-    def test_ollama_generator_run_streaming(self):
+    def test_ollama_generator_streaming(self):
        class Callback:
            def __init__(self):
                self.responses = ""
@@ -151,7 +151,29 @@ def __call__(self, chunk):
        results = component.run(prompt="What's the capital of Netherlands?")

        assert len(results["replies"]) == 1
-        assert "Amsterdam" in results["replies"][0]
+        assert "amsterdam" in results["replies"][0].lower()
+        assert len(results["meta"]) == 1
+        assert callback.responses == results["replies"][0]
+        assert callback.count_calls > 1
+
+    @pytest.mark.integration
+    def test_ollama_generator_streaming_in_run(self):
+        class Callback:
+            def __init__(self):
+                self.responses = ""
+                self.count_calls = 0
+
+            def __call__(self, chunk):
+                self.responses += chunk.content
+                self.count_calls += 1
+                return chunk
+
+        callback = Callback()
+        component = OllamaGenerator(model="llama3.2:3b", streaming_callback=None)
+        results = component.run(prompt="What's the capital of Netherlands?", streaming_callback=callback)
+
+        assert len(results["replies"]) == 1
+        assert "amsterdam" in results["replies"][0].lower()
        assert len(results["meta"]) == 1
        assert callback.responses == results["replies"][0]
        assert callback.count_calls > 1
