
refactor!: AmazonBedrockGenerator - remove truncation #1314

Merged
merged 8 commits into from Jan 23, 2025

Changes from 6 commits
1 change: 0 additions & 1 deletion integrations/amazon_bedrock/pydoc/config.yml
@@ -5,7 +5,6 @@ loaders:
"haystack_integrations.components.generators.amazon_bedrock.generator",
"haystack_integrations.components.generators.amazon_bedrock.adapters",
"haystack_integrations.common.amazon_bedrock.errors",
"haystack_integrations.components.generators.amazon_bedrock.handlers",
"haystack_integrations.components.generators.amazon_bedrock.chat.chat_generator",
"haystack_integrations.components.embedders.amazon_bedrock.text_embedder",
"haystack_integrations.components.embedders.amazon_bedrock.document_embedder",
3 changes: 1 addition & 2 deletions integrations/amazon_bedrock/pyproject.toml
@@ -23,7 +23,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dependencies = ["haystack-ai", "boto3>=1.28.57", "transformers!=4.48.*"]
dependencies = ["haystack-ai", "boto3>=1.28.57"]

[project.urls]
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/amazon_bedrock#readme"
@@ -155,7 +155,6 @@ exclude_lines = ["no cov", "if __name__ == .__main__.:", "if TYPE_CHECKING:"]
[[tool.mypy.overrides]]
module = [
"botocore.*",
"transformers.*",
"boto3.*",
"haystack.*",
"haystack_integrations.*",
@@ -1,6 +1,7 @@
import json
import logging
import re
import warnings
from typing import Any, Callable, ClassVar, Dict, List, Literal, Optional, Type, get_args

from botocore.config import Config
@@ -25,9 +26,6 @@
MetaLlamaAdapter,
MistralAdapter,
)
from .handlers import (
DefaultPromptHandler,
)

logger = logging.getLogger(__name__)

@@ -105,8 +103,8 @@ def __init__(
aws_session_token: Optional[Secret] = Secret.from_env_var("AWS_SESSION_TOKEN", strict=False), # noqa: B008
aws_region_name: Optional[Secret] = Secret.from_env_var("AWS_DEFAULT_REGION", strict=False), # noqa: B008
aws_profile_name: Optional[Secret] = Secret.from_env_var("AWS_PROFILE", strict=False), # noqa: B008
max_length: Optional[int] = 100,
truncate: Optional[bool] = True,
max_length: Optional[int] = None,
truncate: Optional[bool] = None,
streaming_callback: Optional[Callable[[StreamingChunk], None]] = None,
boto3_config: Optional[Dict[str, Any]] = None,
model_family: Optional[MODEL_FAMILIES] = None,
@@ -121,8 +119,8 @@ def __init__(
:param aws_session_token: The AWS session token.
:param aws_region_name: The AWS region name. Make sure the region you set supports Amazon Bedrock.
:param aws_profile_name: The AWS profile name.
:param max_length: The maximum length of the generated text.
:param truncate: Whether to truncate the prompt or not.
:param max_length: Deprecated. This parameter no longer has any effect.
:param truncate: Deprecated. This parameter no longer has any effect.
:param streaming_callback: A callback function that is called when a new token is received from the stream.
The callback function accepts StreamingChunk as an argument.
:param boto3_config: The configuration for the boto3 client.
@@ -140,6 +138,13 @@ def __init__(
self.model = model
self.max_length = max_length
self.truncate = truncate

if max_length is not None or truncate is not None:
warnings.warn(
"The 'max_length' and 'truncate' parameters have been removed and no longer have any effect. "
"No truncation will be performed.",
)

self.aws_access_key_id = aws_access_key_id
self.aws_secret_access_key = aws_secret_access_key
self.aws_session_token = aws_session_token
@@ -173,44 +178,10 @@ def resolve_secret(secret: Optional[Secret]) -> Optional[str]:
raise AmazonBedrockConfigurationError(msg) from exception

model_input_kwargs = kwargs
# We pop the model_max_length as it is not sent to the model but used to truncate the prompt if needed
model_max_length = kwargs.get("model_max_length", 4096)

# we initialize the prompt handler only if truncate is True: we avoid unnecessarily downloading the tokenizer
if self.truncate:
# Truncate prompt if prompt tokens > model_max_length-max_length
# (max_length is the length of the generated text)
# we use GPT2 tokenizer which will likely provide good token count approximation
self.prompt_handler = DefaultPromptHandler(
tokenizer="gpt2",
model_max_length=model_max_length,
max_length=self.max_length or 100,
)

model_adapter_cls = self.get_model_adapter(model=model, model_family=model_family)
self.model_adapter = model_adapter_cls(model_kwargs=model_input_kwargs, max_length=self.max_length)

def _ensure_token_limit(self, prompt: str) -> str:
"""
Ensures that the prompt and answer token lengths together are within the model_max_length specified during
the initialization of the component.

:param prompt: The prompt to be sent to the model.
:returns: The resized prompt.
"""
resize_info = self.prompt_handler(prompt)
if resize_info["prompt_length"] != resize_info["new_prompt_length"]:
logger.warning(
"The prompt was truncated from %s tokens to %s tokens so that the prompt length and "
"the answer length (%s tokens) fit within the model's max token limit (%s tokens). "
"Shorten the prompt or it will be cut off.",
resize_info["prompt_length"],
max(0, resize_info["model_max_length"] - resize_info["max_length"]), # type: ignore
resize_info["max_length"],
resize_info["model_max_length"],
)
return str(resize_info["resized_prompt"])

@component.output_types(replies=List[str])
def run(
self,
@@ -235,9 +206,6 @@ def run(
streaming_callback = streaming_callback or self.streaming_callback
generation_kwargs["stream"] = streaming_callback is not None

if self.truncate:
prompt = self._ensure_token_limit(prompt)

body = self.model_adapter.prepare_body(prompt=prompt, **generation_kwargs)
try:
if streaming_callback:

This file was deleted.
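
For illustration only (not part of this PR's diff): a minimal sketch of how the deprecated parameters behave after this change. It assumes the integration's public import path and patches boto3.Session to avoid real AWS credentials, mirroring the fixture in tests/conftest.py below.

import warnings
from unittest.mock import patch

from haystack_integrations.components.generators.amazon_bedrock import AmazonBedrockGenerator

with patch("boto3.Session"):  # no real AWS credentials needed, as in the test suite
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # max_length and truncate are now no-ops; the prompt is sent to Bedrock unmodified
        generator = AmazonBedrockGenerator(
            model="anthropic.claude-v2",
            max_length=100,
            truncate=True,
        )
    assert any("no longer have any effect" in str(w.message) for w in caught)
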

8 changes: 0 additions & 8 deletions integrations/amazon_bedrock/tests/conftest.py
@@ -17,11 +17,3 @@ def set_env_variables(monkeypatch):
def mock_boto3_session():
with patch("boto3.Session") as mock_client:
yield mock_client


@pytest.fixture
def mock_prompt_handler():
with patch(
"haystack_integrations.components.generators.amazon_bedrock.handlers.DefaultPromptHandler"
) as mock_prompt_handler:
yield mock_prompt_handler
126 changes: 1 addition & 125 deletions integrations/amazon_bedrock/tests/test_generator.py
@@ -1,5 +1,5 @@
from typing import Any, Dict, Optional, Type
from unittest.mock import MagicMock, call, patch
from unittest.mock import MagicMock, call

import pytest
from haystack.dataclasses import StreamingChunk
Expand Down Expand Up @@ -107,9 +107,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables):
assert layer.max_length == 99
assert layer.model == "anthropic.claude-v2"

assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096

# assert mocked boto3 client called exactly once
mock_boto3_session.assert_called_once()

@@ -123,23 +120,6 @@ def test_default_constructor(mock_boto3_session, set_env_variables):
)


def test_constructor_prompt_handler_initialized(mock_boto3_session, mock_prompt_handler):
"""
Test that the constructor sets the prompt_handler correctly, with the correct model_max_length for llama-2
"""
layer = AmazonBedrockGenerator(model="anthropic.claude-v2", prompt_handler=mock_prompt_handler)
assert layer.prompt_handler is not None
assert layer.prompt_handler.model_max_length == 4096


def test_prompt_handler_absent_when_truncate_false(mock_boto3_session):
"""
Test that the prompt_handler is not initialized when truncate is set to False.
"""
generator = AmazonBedrockGenerator(model="anthropic.claude-v2", truncate=False)
assert not hasattr(generator, "prompt_handler")


def test_constructor_with_model_kwargs(mock_boto3_session):
"""
Test that model_kwargs are correctly set in the constructor
@@ -159,110 +139,6 @@ def test_constructor_with_empty_model():
AmazonBedrockGenerator(model="")


def test_short_prompt_is_not_truncated(mock_boto3_session):
"""
Test that a short prompt is not truncated
"""
# Define a short mock prompt and its tokenized version
mock_prompt_text = "I am a tokenized prompt"
mock_prompt_tokens = mock_prompt_text.split()

# Mock the tokenizer so it returns our predefined tokens
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = mock_prompt_tokens

# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Since our mock prompt is 5 tokens long, it doesn't exceed the
# total limit (5 prompt tokens + 3 generated tokens < 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = AmazonBedrockGenerator(
"anthropic.claude-v2",
max_length=max_length_generated_text,
model_max_length=total_model_max_length,
)
prompt_after_resize = layer._ensure_token_limit(mock_prompt_text)

# The prompt doesn't exceed the limit, _ensure_token_limit doesn't truncate it
assert prompt_after_resize == mock_prompt_text


def test_long_prompt_is_truncated(mock_boto3_session):
"""
Test that a long prompt is truncated
"""
# Define a long mock prompt and its tokenized version
long_prompt_text = "I am a tokenized prompt of length eight"
long_prompt_tokens = long_prompt_text.split()

# _ensure_token_limit will truncate the prompt to make it fit into the model's max token limit
truncated_prompt_text = "I am a tokenized prompt of length"

# Mock the tokenizer to return our predefined tokens
# convert tokens to our predefined truncated text
mock_tokenizer = MagicMock()
mock_tokenizer.tokenize.return_value = long_prompt_tokens
mock_tokenizer.convert_tokens_to_string.return_value = truncated_prompt_text

# We set a small max_length for generated text (3 tokens) and a total model_max_length of 10 tokens
# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=mock_tokenizer):
layer = AmazonBedrockGenerator(
"anthropic.claude-v2",
max_length=max_length_generated_text,
model_max_length=total_model_max_length,
)
prompt_after_resize = layer._ensure_token_limit(long_prompt_text)

# The prompt exceeds the limit, _ensure_token_limit truncates it
assert prompt_after_resize == truncated_prompt_text


def test_long_prompt_is_not_truncated_when_truncate_false(mock_boto3_session):
"""
Test that a long prompt is not truncated and _ensure_token_limit is not called when truncate is set to False
"""
long_prompt_text = "I am a tokenized prompt of length eight"

# Our mock prompt is 8 tokens long, so it exceeds the total limit (8 prompt tokens + 3 generated tokens > 10 tokens)
max_length_generated_text = 3
total_model_max_length = 10

with patch("transformers.AutoTokenizer.from_pretrained", return_value=MagicMock()):
generator = AmazonBedrockGenerator(
model="anthropic.claude-v2",
max_length=max_length_generated_text,
model_max_length=total_model_max_length,
truncate=False,
)

# Mock the _ensure_token_limit method to track if it is called
with patch.object(
generator, "_ensure_token_limit", wraps=generator._ensure_token_limit
) as mock_ensure_token_limit:
# Mock the model adapter to avoid actual invocation
generator.model_adapter.prepare_body = MagicMock(return_value={})
generator.client = MagicMock()
generator.client.invoke_model = MagicMock(
return_value={"body": MagicMock(read=MagicMock(return_value=b'{"generated_text": "response"}'))}
)
generator.model_adapter.get_responses = MagicMock(return_value=["response"])

# Invoke the generator
generator.run(prompt=long_prompt_text)

# Ensure _ensure_token_limit was not called
mock_ensure_token_limit.assert_not_called()

# Check the prompt passed to prepare_body
generator.model_adapter.prepare_body.assert_called_with(prompt=long_prompt_text, stream=False)


@pytest.mark.parametrize(
"model, expected_model_adapter",
[