
Add support for SageMaker Inference Components in sagemaker chat #10603


Open
bobbywlindsey wants to merge 31 commits into BerriAI:main from bobbywlindsey:main
Commits
8633d1b
Add support for SageMaker Inference Components in sagemaker chat
bobbywlindsey May 6, 2025
507f25e
Merge branch 'BerriAI:main' into main
bobbywlindsey May 12, 2025
c8db5f0
Merge branch 'BerriAI:main' into main
bobbywlindsey May 12, 2025
cd9ab8e
Integrate with migration to common base_llm_http_handler
bobbywlindsey May 12, 2025
b24c9e4
Remove extraneous new line
bobbywlindsey May 12, 2025
4b2afe2
Merge branch 'BerriAI:main' into main
bobbywlindsey May 12, 2025
2275ab0
Merge branch 'BerriAI:main' into main
bobbywlindsey May 13, 2025
e4e54d1
Add test for inference components
bobbywlindsey May 13, 2025
0668d0c
Move request logic to transform_request
bobbywlindsey May 13, 2025
6995b8d
Keep request transformation logic in sagemaker chat
bobbywlindsey May 13, 2025
b3b122e
Fix formatting errors
bobbywlindsey May 13, 2025
64d31a1
I can't see
bobbywlindsey May 13, 2025
3a3123f
Merge branch 'BerriAI:main' into main
bobbywlindsey May 14, 2025
53cab3b
Merge branch 'BerriAI:main' into main
bobbywlindsey May 15, 2025
a6f5c42
Merge branch 'BerriAI:main' into main
bobbywlindsey May 16, 2025
6b76154
Merge branch 'BerriAI:main' into main
bobbywlindsey May 19, 2025
848ba7e
Merge branch 'BerriAI:main' into main
bobbywlindsey May 20, 2025
e8ab570
Merge branch 'BerriAI:main' into main
bobbywlindsey May 21, 2025
8c63bda
Merge branch 'BerriAI:main' into main
bobbywlindsey May 21, 2025
76cd1d5
Merge branch 'BerriAI:main' into main
bobbywlindsey May 21, 2025
b27cce8
Merge branch 'BerriAI:main' into main
bobbywlindsey May 22, 2025
5b70e0d
Merge branch 'BerriAI:main' into main
bobbywlindsey May 23, 2025
daf3a4f
Resolve merge conflict in test
bobbywlindsey May 27, 2025
13e52e5
Merge remote-tracking branch 'upstream/main'
bobbywlindsey May 27, 2025
8755ee9
Merge branch 'BerriAI:main' into main
bobbywlindsey May 28, 2025
ad591d6
Merge branch 'BerriAI:main' into main
bobbywlindsey May 30, 2025
71f1575
Merge branch 'BerriAI:main' into main
bobbywlindsey Jun 1, 2025
d96bb61
Merge branch 'BerriAI:main' into main
bobbywlindsey Jun 2, 2025
d2d7bc9
Merge branch 'BerriAI:main' into main
bobbywlindsey Jun 3, 2025
e907cca
Merge branch 'BerriAI:main' into main
bobbywlindsey Jun 5, 2025
de2c6df
Merge branch 'BerriAI:main' into main
bobbywlindsey Jun 11, 2025
14 changes: 12 additions & 2 deletions litellm/llms/openai/chat/gpt_transformation.py
@@ -383,11 +383,16 @@ def transform_request(
             dict: The transformed request. Sent as the body of the API call.
         """
         messages = self._transform_messages(messages=messages, model=model)
-        return {
+        model_id = optional_params.get("model_id", None)
+        request_data = {
             "model": model,
             "messages": messages,
             **optional_params,
         }
+        if model_id:
+            del request_data["model"]
+
+        return request_data
 
     async def async_transform_request(
         self,
@@ -401,11 +406,16 @@ async def async_transform_request(
             messages=messages, model=model, is_async=True
         )
 
-        return {
+        model_id = optional_params.get("model_id", None)
+        request_data = {
             "model": model,
             "messages": transformed_messages,
             **optional_params,
         }
+        if model_id:
+            del request_data["model"]
+
+        return request_data
 
     def _passed_in_tools(self, optional_params: dict) -> bool:
         return optional_params.get("tools", None) is not None
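Summary of the hunks above: when `optional_params` carries a `model_id` (as it does for SageMaker Inference Components), the OpenAI-style transformation still builds the body as before but drops the top-level `model` key, since the target component is selected by header rather than by a `model` field in the body. A minimal standalone sketch of that behaviour, not the actual `OpenAIGPTConfig` method, with made-up endpoint and component names:

```python
# Illustrative sketch of the transform_request change (not the real method).
def transform_request(model: str, messages: list, optional_params: dict) -> dict:
    model_id = optional_params.get("model_id", None)
    request_data = {
        "model": model,
        "messages": messages,
        **optional_params,
    }
    if model_id:
        # An Inference Component is addressed via a request header,
        # so the JSON body must not carry a "model" field.
        del request_data["model"]
    return request_data


# Hypothetical values, for illustration only:
body = transform_request(
    model="my-sagemaker-endpoint",
    messages=[{"role": "user", "content": "Hello"}],
    optional_params={"model_id": "my-inference-component"},
)
assert "model" not in body
```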
3 changes: 0 additions & 3 deletions litellm/llms/sagemaker/chat/transformation.py
@@ -104,9 +104,6 @@ def sign_request(
         stream: Optional[bool] = None,
         fake_stream: Optional[bool] = None,
     ) -> Tuple[dict, Optional[bytes]]:
-        model_id = optional_params.get("model_id", None)
-        if model_id:
-            del request_data["model"]
         return self._sign_request(
             service_name="sagemaker",
             headers=headers,
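The three deleted lines show that stripping `model` from the body no longer happens at signing time; it now lives in the request transformation shown above. For context, the Inference Component itself is targeted through the `X-Amzn-SageMaker-Inference-Component` header (see the test below). A hypothetical helper, not code from this PR, sketching that header wiring:

```python
# Hypothetical helper (not part of this PR): the Inference Component is
# selected via a header, never via the request body.
def build_sagemaker_headers(optional_params: dict) -> dict:
    headers = {"Content-Type": "application/json"}
    model_id = optional_params.get("model_id", None)
    if model_id:
        headers["X-Amzn-SageMaker-Inference-Component"] = model_id
    return headers


print(build_sagemaker_headers({"model_id": "my-inference-component"}))
```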
30 changes: 9 additions & 21 deletions tests/litellm/llms/sagemaker/test_sagemaker_common_utils.py
@@ -153,25 +153,13 @@ def test_inference_component_header(self):
 
     def test_inference_component_model_not_in_request(self):
         """Test that `model` is not part of the request body"""
-        test_params = {"model_id": "test"}
-
-        with patch(
-            "litellm.llms.sagemaker.chat.transformation.SagemakerChatConfig._sign_request"
-        ) as mock_sign_request:
-            self.config.sign_request(
-                headers={"X-Amzn-SageMaker-Inference-Component": "test"},
-                optional_params=test_params,
-                request_data={"model": self.model},
-                api_base="",
-            )
-
-            mock_sign_request.assert_called_once_with(
-                service_name="sagemaker",
-                headers={"X-Amzn-SageMaker-Inference-Component": "test"},
-                optional_params=test_params,
-                request_data={},
-                api_base="",
-                model=None,
-                stream=None,
-                fake_stream=None,
-            )
+        result = self.config.transform_request(
+            model=self.model,
+            messages=self.messages,
+            optional_params=self.optional_params,
+            litellm_params=None,
+            headers=None,
+        )
+
+        assert "model" not in result
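Taken together, the rewritten test checks the public behaviour rather than mocking the signer: transforming a request whose `optional_params` include `model_id` yields a body without a `model` key. At the caller level, usage would presumably look something like the sketch below; the endpoint and component names are placeholders, and the exact way `model_id` is passed should be verified against the litellm SageMaker documentation:

```python
import litellm

# Placeholder names; verify the parameter spelling against the litellm docs.
response = litellm.completion(
    model="sagemaker_chat/my-endpoint",
    messages=[{"role": "user", "content": "Hello from an Inference Component"}],
    model_id="my-inference-component",  # routed to the SageMaker Inference Component
)
print(response.choices[0].message.content)
```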