Commit f08c264

feat: add response_format param to OllamaChatGenerator (#1326)
* Add response_format param to Ollama integration
* Add related tests
1 parent 232c537 commit f08c264
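For context, here is a minimal usage sketch of the new parameter. It is not part of the commit; the model name and schema are illustrative, and it assumes a local Ollama server with the model already pulled:

# Illustrative sketch (not from the commit): constrain the reply to a JSON object.
import json

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

generator = OllamaChatGenerator(
    model="llama3.2:3b",  # illustrative; any pulled model that supports structured outputs
    response_format={
        "type": "object",
        "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
    },
)

result = generator.run([ChatMessage.from_user("What's the capital of France and its population?")])
print(json.loads(result["replies"][0].text))  # e.g. {"capital": "Paris", "population": 2102650}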

3 files changed (+92 −4 lines)
integrations/ollama/pyproject.toml (+2 −2)
@@ -26,7 +26,7 @@ classifiers = [
   "Programming Language :: Python :: Implementation :: CPython",
   "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai", "ollama>=0.4.0"]
+dependencies = ["haystack-ai", "ollama>=0.4.0", "pydantic"]

 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"
@@ -165,5 +165,5 @@ markers = [
 addopts = ["--import-mode=importlib"]

 [[tool.mypy.overrides]]
-module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*"]
+module = ["haystack.*", "haystack_integrations.*", "pytest.*", "ollama.*", "pydantic.*"]
 ignore_missing_imports = true
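The new runtime dependency and the mypy override exist only to type the new parameter: in pydantic v2, JsonSchemaValue is defined as an alias for Dict[str, Any]. The sketch below (illustrative, not from the commit) shows the extent of the dependency's use:

# The integration imports a single type alias from pydantic to annotate
# response_format; JsonSchemaValue is an alias for Dict[str, Any] in pydantic v2.
from pydantic.json_schema import JsonSchemaValue

schema: JsonSchemaValue = {
    "type": "object",
    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
}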

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py (+19 −2)
@@ -1,9 +1,10 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union

 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
+from pydantic.json_schema import JsonSchemaValue

 from ollama import ChatResponse, Client
@@ -97,6 +98,7 @@ def __init__(
         keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
+        response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None,
     ):
         """
         :param model:
@@ -124,6 +126,11 @@ def __init__(
             A list of tools for which the model can prepare calls.
             Not all models support tools. For a list of models compatible with tools, see the
             [models page](https://ollama.com/search?c=tools).
+        :param response_format:
+            The format for structured model outputs. The value can be:
+            - None: No specific structure or format is applied to the response. The response is returned as-is.
+            - "json": The response is formatted as a JSON object.
+            - JSON Schema: The response is formatted as a JSON object that adheres to the specified JSON Schema.
         """

         _check_duplicate_tool_names(tools)
@@ -135,7 +142,7 @@ def __init__(
         self.keep_alive = keep_alive
         self.streaming_callback = streaming_callback
         self.tools = tools
-
+        self.response_format = response_format
         self._client = Client(host=self.url, timeout=self.timeout)

     def to_dict(self) -> Dict[str, Any]:
@@ -156,6 +163,7 @@ def to_dict(self) -> Dict[str, Any]:
             timeout=self.timeout,
             streaming_callback=callback_name,
             tools=serialized_tools,
+            response_format=self.response_format,
         )

     @classmethod
@@ -237,6 +245,14 @@ def run(
             msg = "Ollama does not support tools and streaming at the same time. Please choose one."
             raise ValueError(msg)

+        if self.response_format and tools:
+            msg = "Ollama does not support tools and response_format at the same time. Please choose one."
+            raise ValueError(msg)
+
+        if self.response_format and stream:
+            msg = "Ollama does not support streaming and response_format at the same time. Please choose one."
+            raise ValueError(msg)
+
         ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None

         ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages]
@@ -247,6 +263,7 @@ def run(
             stream=stream,
             keep_alive=self.keep_alive,
             options=generation_kwargs,
+            format=self.response_format,
         )

         if stream:
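Taken together, the guards above make response_format mutually exclusive with tools and with streaming. A hedged sketch of the caller-facing behavior (illustrative, not from the commit; model name is an assumption):

# Combining streaming with response_format now raises ValueError before any request is sent.
from haystack.components.generators.utils import print_streaming_chunk
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

generator = OllamaChatGenerator(
    model="llama3.2:3b",  # illustrative model name
    streaming_callback=print_streaming_chunk,
    response_format="json",  # plain JSON mode, no schema
)

try:
    generator.run([ChatMessage.from_user("Name one EU capital.")])
except ValueError as err:
    print(err)  # "Ollama does not support streaming and response_format at the same time. ..."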

integrations/ollama/tests/test_chat_generator.py (+71 −0)
@@ -165,6 +165,7 @@ def test_init_default(self):
         assert component.streaming_callback is None
         assert component.tools is None
         assert component.keep_alive is None
+        assert component.response_format is None

     def test_init(self, tools):
         component = OllamaChatGenerator(
@@ -175,6 +176,7 @@ def test_init(self, tools):
             keep_alive="10m",
             streaming_callback=print_streaming_chunk,
             tools=tools,
+            response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}},
         )

         assert component.model == "llama2"
@@ -184,6 +186,10 @@ def test_init(self, tools):
         assert component.keep_alive == "10m"
         assert component.streaming_callback is print_streaming_chunk
         assert component.tools == tools
+        assert component.response_format == {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+        }

     def test_init_fail_with_duplicate_tool_names(self, tools):
@@ -206,6 +212,7 @@ def test_to_dict(self):
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
             tools=[tool],
             keep_alive="5m",
+            response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}},
         )
         data = component.to_dict()
         assert data == {
@@ -235,6 +242,10 @@ def test_to_dict(self):
                     },
                 },
             ],
+            "response_format": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+            },
         },
     }
@@ -273,6 +284,10 @@ def test_from_dict(self):
                     },
                 },
             ],
+            "response_format": {
+                "type": "object",
+                "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+            },
         },
     }
     component = OllamaChatGenerator.from_dict(data)
@@ -286,6 +301,10 @@ def test_from_dict(self):
         }
         assert component.timeout == 120
         assert component.tools == [tool]
+        assert component.response_format == {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+        }

     @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
     def test_run(self, mock_client):
@@ -319,6 +338,7 @@ def test_run(self, mock_client):
             tools=None,
             options={},
             keep_alive=None,
+            format=None,
         )

         assert "replies" in result
@@ -456,3 +476,54 @@ def test_run_with_tools(self, tools):
         assert isinstance(tool_call, ToolCall)
         assert tool_call.tool_name == "weather"
         assert tool_call.arguments == {"city": "Paris"}
+
+    @pytest.mark.integration
+    def test_run_with_response_format(self):
+        response_format = {
+            "type": "object",
+            "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
+        }
+        chat_generator = OllamaChatGenerator(model="llama3.2:3b", response_format=response_format)
+
+        message = ChatMessage.from_user("What's the capital of France and its population?")
+        response = chat_generator.run([message])
+
+        assert isinstance(response, dict)
+        assert isinstance(response["replies"], list)
+
+        # Parse the response text as JSON and verify its structure
+        response_data = json.loads(response["replies"][0].text)
+        assert isinstance(response_data, dict)
+        assert "capital" in response_data
+        assert isinstance(response_data["capital"], str)
+        assert "population" in response_data
+        assert isinstance(response_data["population"], (int, float))
+        assert response_data["capital"] == "Paris"
+
+    def test_run_with_streaming_and_format(self):
+        response_format = {
+            "type": "object",
+            "properties": {"answer": {"type": "string"}},
+        }
+        streaming_callback = Mock()
+        chat_generator = OllamaChatGenerator(
+            model="llama3.2:3b", streaming_callback=streaming_callback, response_format=response_format
+        )
+
+        chat_messages = [
+            ChatMessage.from_user("What is the largest city in the United Kingdom by population?"),
+            ChatMessage.from_assistant("London is the largest city in the United Kingdom by population"),
+            ChatMessage.from_user("And what is the second largest?"),
+        ]
+        with pytest.raises(ValueError):
+            chat_generator.run([chat_messages])
+
+    def test_run_with_tools_and_format(self, tools):
+        response_format = {
+            "type": "object",
+            "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
+        }
+        chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools, response_format=response_format)
+        message = ChatMessage.from_user("What's the weather in Paris?")
+        with pytest.raises(ValueError):
+            chat_generator.run([message])
