Skip to content

Commit e21ce0c

Browse files
authored
Add AnthropicVertexChatGenerator component (#1192)
* Created a model adapter * Create adapter class and add VertexAPI * Add chat generator for Anthropic Vertex * Add tests * Small fix * Improve doc_strings * Make project_id and region mandatory params * Small fix
1 parent 3b33958 commit e21ce0c

File tree

3 files changed

+334
-1
lines changed

3 files changed

+334
-1
lines changed

integrations/anthropic/src/haystack_integrations/components/generators/anthropic/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
from .chat.chat_generator import AnthropicChatGenerator
5+
from .chat.vertex_chat_generator import AnthropicVertexChatGenerator
56
from .generator import AnthropicGenerator
67

# Public API of the anthropic generators package; AnthropicVertexChatGenerator is
# re-exported here so users can import it without knowing the submodule layout.
__all__ = ["AnthropicGenerator", "AnthropicChatGenerator", "AnthropicVertexChatGenerator"]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import os
2+
from typing import Any, Callable, Dict, Optional
3+
4+
from haystack import component, default_from_dict, default_to_dict, logging
5+
from haystack.dataclasses import StreamingChunk
6+
from haystack.utils import deserialize_callable, serialize_callable
7+
8+
from anthropic import AnthropicVertex
9+
10+
from .chat_generator import AnthropicChatGenerator
11+
12+
# Module-scoped logger, created through haystack's logging wrapper.
logger = logging.getLogger(__name__)
13+
14+
15+
@component
class AnthropicVertexChatGenerator(AnthropicChatGenerator):
    """
    Enables text generation using state-of-the-art Claude 3 LLMs via the Anthropic Vertex AI API.
    It supports models such as `Claude 3.5 Sonnet`, `Claude 3 Opus`, `Claude 3 Sonnet`, and `Claude 3 Haiku`,
    accessible through the Vertex AI API endpoint.

    To use AnthropicVertexChatGenerator, you must have a GCP project with Vertex AI enabled.
    Additionally, ensure that the desired Anthropic model is activated in the Vertex AI Model Garden.
    Before making requests, you may need to authenticate with GCP using `gcloud auth login`.
    For more details, refer to the [guide](https://docs.anthropic.com/en/api/claude-on-vertex-ai).

    Any valid text generation parameters for the Anthropic messaging API can be passed to
    the AnthropicVertex API. Users can provide these parameters directly to the component via
    the `generation_kwargs` parameter in `__init__` or the `run` method.

    For more details on the parameters supported by the Anthropic API, refer to the
    Anthropic Message API [documentation](https://docs.anthropic.com/en/api/messages).

    ```python
    from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]
    client = AnthropicVertexChatGenerator(
        model="claude-3-sonnet@20240229",
        project_id="your-project-id", region="your-region"
    )
    response = client.run(messages)
    print(response)

    >> {'replies': [ChatMessage(content='Natural Language Processing (NLP) is a field of artificial intelligence that
    >> focuses on enabling computers to understand, interpret, and generate human language. It involves developing
    >> techniques and algorithms to analyze and process text or speech data, allowing machines to comprehend and
    >> communicate in natural languages like English, Spanish, or Chinese.', role=<ChatRole.ASSISTANT: 'assistant'>,
    >> name=None, meta={'model': 'claude-3-sonnet@20240229', 'index': 0, 'finish_reason': 'end_turn',
    >> 'usage': {'input_tokens': 15, 'output_tokens': 64}})]}
    ```

    For more details on supported models and their capabilities, refer to the Anthropic
    [documentation](https://docs.anthropic.com/claude/docs/intro-to-claude).

    """

    def __init__(
        self,
        region: str,
        project_id: str,
        model: str = "claude-3-5-sonnet@20240620",
        streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        ignore_tools_thinking_messages: bool = True,
    ):
        """
        Creates an instance of AnthropicVertexChatGenerator.

        :param region: The GCP region where the Anthropic model is deployed (e.g. `"us-central1"`).
            If a falsy value is passed, the `REGION` environment variable is used as a fallback.
        :param project_id: The GCP project ID where the Anthropic model is deployed.
            If a falsy value is passed, the `PROJECT_ID` environment variable is used as a fallback.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the AnthropicVertex endpoint. See Anthropic [documentation](https://docs.anthropic.com/claude/reference/messages_post)
            for more details.

            Supported generation_kwargs parameters are:
            - `system`: The system message to be passed to the model.
            - `max_tokens`: The maximum number of tokens to generate.
            - `metadata`: A dictionary of metadata to be passed to the model.
            - `stop_sequences`: A list of strings that the model should stop generating at.
            - `temperature`: The temperature to use for sampling.
            - `top_p`: The top_p value to use for nucleus sampling.
            - `top_k`: The top_k value to use for top-k sampling.
            - `extra_headers`: A dictionary of extra headers to be passed to the model (i.e. for beta features).
        :param ignore_tools_thinking_messages: Anthropic's approach to tools (function calling) resolution involves a
            "chain of thought" messages before returning the actual function names and parameters in a message. If
            `ignore_tools_thinking_messages` is `True`, the generator will drop so-called thinking messages when tool
            use is detected. See the Anthropic [tools](https://docs.anthropic.com/en/docs/tool-use#chain-of-thought-tool-use)
            for more details.
        """
        # NOTE(review): the parent AnthropicChatGenerator.__init__ is deliberately not called here —
        # the Vertex client authenticates via GCP credentials (no Anthropic API key is constructed).
        self.region = region or os.environ.get("REGION")
        self.project_id = project_id or os.environ.get("PROJECT_ID")
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.client = AnthropicVertex(region=self.region, project_id=self.project_id)
        self.ignore_tools_thinking_messages = ignore_tools_thinking_messages

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        # Callables cannot be serialized directly; store the callback's import path instead.
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            region=self.region,
            project_id=self.project_id,
            model=self.model,
            streaming_callback=callback_name,
            generation_kwargs=self.generation_kwargs,
            ignore_tools_thinking_messages=self.ignore_tools_thinking_messages,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AnthropicVertexChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        # Restore the streaming callback from its serialized import path, if one was stored.
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import os
2+
3+
import anthropic
4+
import pytest
5+
from haystack.components.generators.utils import print_streaming_chunk
6+
from haystack.dataclasses import ChatMessage, ChatRole
7+
8+
from haystack_integrations.components.generators.anthropic import AnthropicVertexChatGenerator
9+
10+
11+
@pytest.fixture
def chat_messages():
    """Two-message conversation: a brevity-demanding system prompt followed by a user question."""
    system_message = ChatMessage.from_system("\\nYou are a helpful assistant, be super brief in your responses.")
    user_message = ChatMessage.from_user("What's the capital of France?")
    return [system_message, user_message]
17+
18+
19+
class TestAnthropicVertexChatGenerator:
    """Unit and integration tests for AnthropicVertexChatGenerator."""

    def test_init_default(self):
        """Defaults are applied when only the mandatory region/project_id are given."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is None
        assert not component.generation_kwargs
        assert component.ignore_tools_thinking_messages

    def test_init_with_parameters(self):
        """All explicitly passed parameters are stored unchanged."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
            ignore_tools_thinking_messages=False,
        )
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}
        assert component.ignore_tools_thinking_messages is False

    def test_to_dict_default(self):
        """Serialization of a default-configured component."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": None,
                "generation_kwargs": {},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_parameters(self):
        """A named callback serializes to its fully qualified import path."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            streaming_callback=print_streaming_chunk,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_to_dict_with_lambda_streaming_callback(self):
        """A lambda callback serializes to a module-qualified '<lambda>' path."""
        component = AnthropicVertexChatGenerator(
            region="us-central1",
            project_id="test-project-id",
            model="claude-3-5-sonnet@20240620",
            streaming_callback=lambda x: x,
            generation_kwargs={"max_tokens": 10, "some_test_param": "test-params"},
        )
        data = component.to_dict()
        assert data == {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "tests.test_vertex_chat_generator.<lambda>",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }

    def test_from_dict(self):
        """Deserialization restores all init parameters, including the callback object."""
        data = {
            "type": (
                "haystack_integrations.components.generators."
                "anthropic.chat.vertex_chat_generator.AnthropicVertexChatGenerator"
            ),
            "init_parameters": {
                "region": "us-central1",
                "project_id": "test-project-id",
                "model": "claude-3-5-sonnet@20240620",
                "streaming_callback": "haystack.components.generators.utils.print_streaming_chunk",
                "generation_kwargs": {"max_tokens": 10, "some_test_param": "test-params"},
                "ignore_tools_thinking_messages": True,
            },
        }
        component = AnthropicVertexChatGenerator.from_dict(data)
        assert component.model == "claude-3-5-sonnet@20240620"
        assert component.region == "us-central1"
        assert component.project_id == "test-project-id"
        assert component.streaming_callback is print_streaming_chunk
        assert component.generation_kwargs == {"max_tokens": 10, "some_test_param": "test-params"}

    def test_run(self, chat_messages, mock_chat_completion):
        """run() returns a dict with a single ChatMessage reply (API mocked)."""
        component = AnthropicVertexChatGenerator(region="us-central1", project_id="test-project-id")
        response = component.run(chat_messages)

        # check that the component returns the correct ChatMessage response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # FIX: was `assert [isinstance(...) for ...]`, which is always truthy for a non-empty list.
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    def test_run_with_params(self, chat_messages, mock_chat_completion):
        """generation_kwargs from __init__ are forwarded to the Anthropic API call."""
        component = AnthropicVertexChatGenerator(
            region="us-central1", project_id="test-project-id", generation_kwargs={"max_tokens": 10, "temperature": 0.5}
        )
        response = component.run(chat_messages)

        # check that the component calls the Anthropic API with the correct parameters
        _, kwargs = mock_chat_completion.call_args
        assert kwargs["max_tokens"] == 10
        assert kwargs["temperature"] == 0.5

        # check that the component returns the correct response
        assert isinstance(response, dict)
        assert "replies" in response
        assert isinstance(response["replies"], list)
        assert len(response["replies"]) == 1
        # FIX: was `assert [isinstance(...) for ...]`, which is always truthy for a non-empty list.
        assert all(isinstance(reply, ChatMessage) for reply in response["replies"])

    # FIX: the skip condition used `or`, so the test ran (and failed at client creation)
    # when only one of REGION/PROJECT_ID was set; both are required, hence `and`.
    @pytest.mark.skipif(
        not (os.environ.get("REGION", None) and os.environ.get("PROJECT_ID", None)),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_live_run_wrong_model(self, chat_messages):
        """A nonexistent model name raises anthropic.NotFoundError (live API)."""
        component = AnthropicVertexChatGenerator(
            model="something-obviously-wrong", region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID")
        )
        with pytest.raises(anthropic.NotFoundError):
            component.run(chat_messages)

    # FIX: same `or` -> `and` correction as above — both env variables must be present.
    @pytest.mark.skipif(
        not (os.environ.get("REGION", None) and os.environ.get("PROJECT_ID", None)),
        reason="Authenticate with GCP and set env variables REGION and PROJECT_ID to run this test.",
    )
    @pytest.mark.integration
    def test_default_inference_params(self, chat_messages):
        """End-to-end sanity check of a live completion with default inference parameters."""
        client = AnthropicVertexChatGenerator(
            region=os.environ.get("REGION"), project_id=os.environ.get("PROJECT_ID"), model="claude-3-sonnet@20240229"
        )
        response = client.run(chat_messages)

        assert "replies" in response, "Response does not contain 'replies' key"
        replies = response["replies"]
        assert isinstance(replies, list), "Replies is not a list"
        assert len(replies) > 0, "No replies received"

        first_reply = replies[0]
        assert isinstance(first_reply, ChatMessage), "First reply is not a ChatMessage instance"
        assert first_reply.content, "First reply has no content"
        assert ChatMessage.is_from(first_reply, ChatRole.ASSISTANT), "First reply is not from the assistant"
        assert "paris" in first_reply.content.lower(), "First reply does not contain 'paris'"
        assert first_reply.meta, "First reply has no metadata"

    # Anthropic messages API is similar for AnthropicVertex and Anthropic endpoint,
    # remaining tests are skipped for AnthropicVertexChatGenerator as they are already tested in AnthropicChatGenerator.

0 commit comments

Comments
 (0)