
feat: add response_format param to OllamaChatGenerator #1326


Merged · 15 commits · Jan 29, 2025
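
This PR adds a `response_format` parameter to `OllamaChatGenerator` and forwards it as the `format` argument of the Ollama client's `chat` call, enabling JSON or JSON-Schema-constrained outputs. As a minimal usage sketch, assuming the integration's usual import path (the model name and schema below are illustrative, not part of this diff):

```python
import json

from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.ollama import OllamaChatGenerator

# Any JSON-schema dict (or the literal "json") is accepted as response_format.
schema = {
    "type": "object",
    "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
}

generator = OllamaChatGenerator(model="llama3.2:3b", response_format=schema)
result = generator.run([ChatMessage.from_user("What's the capital of France and its population?")])

# With a schema set, the reply text is a JSON document matching it.
print(json.loads(result["replies"][0].text))
```

Note that `response_format` is mutually exclusive with `tools`: `run()` raises a `ValueError` when both are set, mirroring the existing tools-vs-streaming check.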
integrations/ollama/pyproject.toml (1 addition, 1 deletion)

@@ -26,7 +26,7 @@ classifiers = [
     "Programming Language :: Python :: Implementation :: CPython",
     "Programming Language :: Python :: Implementation :: PyPy",
 ]
-dependencies = ["haystack-ai", "ollama>=0.4.0"]
+dependencies = ["haystack-ai", "ollama>=0.4.0", "pydantic"]
 
 [project.urls]
 Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"

integrations/ollama/src/haystack_integrations/components/generators/ollama/chat/chat_generator.py

@@ -1,9 +1,10 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Union
 
 from haystack import component, default_from_dict, default_to_dict
 from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
 from haystack.tools import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
 from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
+from pydantic.json_schema import JsonSchemaValue  # type: ignore
 
 from ollama import ChatResponse, Client
 
@@ -97,6 +98,7 @@ def __init__(
         keep_alive: Optional[Union[float, str]] = None,
         streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
         tools: Optional[List[Tool]] = None,
+        response_format: Optional[Union[None, Literal["json"], JsonSchemaValue]] = None,
     ):
         """
         :param model:
@@ -124,6 +126,11 @@ def __init__(
             A list of tools for which the model can prepare calls.
             Not all models support tools. For a list of models compatible with tools, see the
             [models page](https://ollama.com/search?c=tools).
+        :param response_format:
+            The format for structured model outputs. The value can be:
+            - None: The default response format is used.
+            - "json": The response is formatted as a JSON object.
+            - JsonSchemaValue: The response is formatted as a JSON object that adheres to the specified JSON Schema.
         """
 
         _check_duplicate_tool_names(tools)
@@ -135,7 +142,7 @@ def __init__(
         self.keep_alive = keep_alive
         self.streaming_callback = streaming_callback
         self.tools = tools
-
+        self.response_format = response_format
         self._client = Client(host=self.url, timeout=self.timeout)
 
     def to_dict(self) -> Dict[str, Any]:
@@ -156,6 +163,7 @@ def to_dict(self) -> Dict[str, Any]:
             timeout=self.timeout,
             streaming_callback=callback_name,
             tools=serialized_tools,
+            response_format=self.response_format,
         )
 
     @classmethod
@@ -237,6 +245,10 @@ def run(
             msg = "Ollama does not support tools and streaming at the same time. Please choose one."
             raise ValueError(msg)
 
+        if self.response_format and tools:
+            msg = "Ollama does not support tools and response_format at the same time. Please choose one."
+            raise ValueError(msg)
+
         ollama_tools = [{"type": "function", "function": {**t.tool_spec}} for t in tools] if tools else None
 
         ollama_messages = [_convert_chatmessage_to_ollama_format(msg) for msg in messages]
@@ -247,6 +259,7 @@ def run(
             stream=stream,
             keep_alive=self.keep_alive,
             options=generation_kwargs,
+            format=self.response_format,
         )
 
         if stream:
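Since `JsonSchemaValue` is simply a JSON-schema dict (which is also why `pydantic` becomes an explicit dependency above), the schema could equally be derived from a Pydantic model. A sketch under that assumption; the `CityInfo` model is hypothetical and not part of this PR:

```python
from pydantic import BaseModel

from haystack_integrations.components.generators.ollama import OllamaChatGenerator


class CityInfo(BaseModel):
    capital: str
    population: float


# model_json_schema() yields a plain JSON-schema dict, i.e. a valid
# JsonSchemaValue, so it can be passed straight to response_format.
generator = OllamaChatGenerator(
    model="llama3.2:3b",
    response_format=CityInfo.model_json_schema(),
)
```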
integrations/ollama/tests/test_chat_generator.py (52 additions, 0 deletions)

@@ -165,6 +165,7 @@ def test_init_default(self):
         assert component.streaming_callback is None
         assert component.tools is None
         assert component.keep_alive is None
+        assert component.response_format is None
 
     def test_init(self, tools):
         component = OllamaChatGenerator(
@@ -175,6 +176,7 @@ def test_init(self, tools):
             keep_alive="10m",
             streaming_callback=print_streaming_chunk,
             tools=tools,
+            response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}},
         )
 
         assert component.model == "llama2"
@@ -184,6 +186,10 @@ def test_init(self, tools):
         assert component.keep_alive == "10m"
         assert component.streaming_callback is print_streaming_chunk
         assert component.tools == tools
+        assert component.response_format == {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+        }
 
     def test_init_fail_with_duplicate_tool_names(self, tools):
 
@@ -206,6 +212,7 @@ def test_to_dict(self):
             generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
             tools=[tool],
             keep_alive="5m",
+            response_format={"type": "object", "properties": {"name": {"type": "string"}, "age": {"type": "number"}}},
         )
         data = component.to_dict()
         assert data == {
@@ -235,6 +242,10 @@ def test_to_dict(self):
                         },
                     },
                 ],
+                "response_format": {
+                    "type": "object",
+                    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+                },
             },
         }
 
@@ -273,6 +284,10 @@ def test_from_dict(self):
                         },
                     },
                 ],
+                "response_format": {
+                    "type": "object",
+                    "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+                },
             },
         }
         component = OllamaChatGenerator.from_dict(data)
@@ -286,6 +301,10 @@ def test_from_dict(self):
         }
         assert component.timeout == 120
         assert component.tools == [tool]
+        assert component.response_format == {
+            "type": "object",
+            "properties": {"name": {"type": "string"}, "age": {"type": "number"}},
+        }
 
     @patch("haystack_integrations.components.generators.ollama.chat.chat_generator.Client")
     def test_run(self, mock_client):
@@ -319,6 +338,7 @@ def test_run(self, mock_client):
             tools=None,
             options={},
             keep_alive=None,
+            format=None,
         )
 
         assert "replies" in result
@@ -456,3 +476,35 @@ def test_run_with_tools(self, tools):
         assert isinstance(tool_call, ToolCall)
         assert tool_call.tool_name == "weather"
         assert tool_call.arguments == {"city": "Paris"}
+
+    @pytest.mark.integration
+    def test_run_with_response_format(self):
+        response_format = {
+            "type": "object",
+            "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
+        }
+        chat_generator = OllamaChatGenerator(model="llama3.2:3b", response_format=response_format)
+
+        message = ChatMessage.from_user("What's the capital of France and its population?")
+        response = chat_generator.run([message])
+        assert isinstance(response, dict)
+        assert isinstance(response["replies"], list)
+
+        # Parse the response text as JSON and verify its structure
+        response_data = json.loads(response["replies"][0].text)
+        assert isinstance(response_data, dict)
+        assert "capital" in response_data
+        assert isinstance(response_data["capital"], str)
+        assert "population" in response_data
+        assert isinstance(response_data["population"], (int, float))
+        assert response_data["capital"] == "Paris"
+
+    def test_run_with_tools_and_format(self, tools):
+        response_format = {
+            "type": "object",
+            "properties": {"capital": {"type": "string"}, "population": {"type": "number"}},
+        }
+        chat_generator = OllamaChatGenerator(model="llama3.2:3b", tools=tools, response_format=response_format)
+        message = ChatMessage.from_user("What's the weather in Paris?")
+        with pytest.raises(ValueError):
+            chat_generator.run([message])