diff --git a/litellm/llms/sagemaker/chat/transformation.py b/litellm/llms/sagemaker/chat/transformation.py
index 14dde144af1b..af2bee0ade44 100644
--- a/litellm/llms/sagemaker/chat/transformation.py
+++ b/litellm/llms/sagemaker/chat/transformation.py
@@ -58,6 +58,13 @@ def validate_environment(
         api_key: Optional[str] = None,
         api_base: Optional[str] = None,
     ) -> dict:
+        model_id = optional_params.get("model_id", None)
+        if model_id is not None:
+            # Add model_id as InferenceComponentName header
+            # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html
+            headers.update(
+                {"X-Amzn-SageMaker-Inference-Component": model_id}
+            )
         return headers
 
     def get_complete_url(
@@ -108,6 +115,47 @@ def sign_request(
             fake_stream=fake_stream,
         )
 
+    def transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        messages = self._transform_messages(messages=messages, model=model)
+        model_id = optional_params.get("model_id", None)
+        request_data = {
+            "model": model,
+            "messages": messages,
+            **optional_params,
+        }
+        if model_id:
+            del request_data["model"]
+        return request_data
+
+    async def async_transform_request(
+        self,
+        model: str,
+        messages: List[AllMessageValues],
+        optional_params: dict,
+        litellm_params: dict,
+        headers: dict,
+    ) -> dict:
+        transformed_messages = await self._transform_messages(
+            messages=messages, model=model, is_async=True
+        )
+        model_id = optional_params.get("model_id", None)
+
+        request_data = {
+            "model": model,
+            "messages": transformed_messages,
+            **optional_params,
+        }
+        if model_id:
+            del request_data["model"]
+        return request_data
+
     @property
     def has_custom_stream_wrapper(self) -> bool:
         return True
diff --git a/tests/test_litellm/llms/sagemaker/test_sagemaker_common_utils.py b/tests/test_litellm/llms/sagemaker/test_sagemaker_common_utils.py
index 70a8d86cb1b7..b1e955dba62b 100644
--- a/tests/test_litellm/llms/sagemaker/test_sagemaker_common_utils.py
+++ b/tests/test_litellm/llms/sagemaker/test_sagemaker_common_utils.py
@@ -9,6 +9,8 @@
 sys.path.insert(0, os.path.abspath("../../../../.."))
 from litellm.llms.sagemaker.common_utils import AWSEventStreamDecoder
 from litellm.llms.sagemaker.completion.transformation import SagemakerConfig
+from litellm.llms.sagemaker.chat.transformation import SagemakerChatConfig
+
 
 
 @pytest.mark.asyncio
@@ -138,3 +140,37 @@ def test_mistral_max_tokens_backward_compat(self):
 
         # The function should properly map max_tokens if max_completion_tokens is not provided
         assert result == {"temperature": 0.7, "max_new_tokens": 200}
+
+
+class TestSagemakerChatTransform:
+    def setup_method(self):
+        self.config = SagemakerChatConfig()
+        self.model = "test"
+        self.messages = []
+        self.optional_params = {"model_id": "test"}
+        self.logging_obj = MagicMock()
+
+    def test_inference_component_header(self):
+        """Test that inference component headers are present"""
+
+        result = self.config.validate_environment(
+            headers={},
+            model=self.model,
+            messages=[],
+            optional_params=self.optional_params,
+            litellm_params={},
+        )
+        assert result == {"X-Amzn-SageMaker-Inference-Component": "test"}
+
+    def test_inference_component_model_not_in_request(self):
+        """Test that `model` is not part of the request body"""
+
+        result = self.config.transform_request(
+            model=self.model,
+            messages=self.messages,
+            optional_params=self.optional_params,
+            litellm_params=None,
+            headers=None,
+        )
+
+        assert "model" not in result
\ No newline at end of file
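
For context, a minimal caller-side sketch of what this patch enables (not part of the patch itself): passing `model_id` so it is forwarded as the `X-Amzn-SageMaker-Inference-Component` header and stripped from the request body. The endpoint and inference-component names below are placeholders, and this assumes the `sagemaker_chat/` provider route accepts `model_id` as a pass-through optional param, as the transformation code above reads it from `optional_params`.

```python
# Illustrative only — endpoint and inference-component names are placeholders.
import litellm

response = litellm.completion(
    model="sagemaker_chat/my-endpoint",     # SageMaker endpoint name (placeholder)
    model_id="my-inference-component",      # sent as the X-Amzn-SageMaker-Inference-Component
                                            # header; removed from the JSON body by
                                            # transform_request when model_id is set
    messages=[{"role": "user", "content": "Hello"}],
)
print(response.choices[0].message.content)
```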