Commit a517a58

fix: sagemaker config and chat methods (#1142)
1 parent b0e2582 commit a517a58

File tree

2 files changed: 29 additions, 4 deletions

private_gpt/components/llm/custom/sagemaker.py

Lines changed: 29 additions & 2 deletions
@@ -4,7 +4,7 @@
 import io
 import json
 import logging
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING
 
 import boto3  # type: ignore
 from llama_index.bridge.pydantic import Field
@@ -13,7 +13,14 @@
     CustomLLM,
     LLMMetadata,
 )
-from llama_index.llms.base import llm_completion_callback
+from llama_index.llms.base import (
+    llm_chat_callback,
+    llm_completion_callback,
+)
+from llama_index.llms.generic_utils import (
+    completion_response_to_chat_response,
+    stream_completion_response_to_chat_response,
+)
 from llama_index.llms.llama_utils import (
     completion_to_prompt as generic_completion_to_prompt,
 )
@@ -22,8 +29,14 @@
 )
 
 if TYPE_CHECKING:
+    from collections.abc import Sequence
+    from typing import Any
+
     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (
+        ChatMessage,
+        ChatResponse,
+        ChatResponseGen,
         CompletionResponseGen,
     )
@@ -247,3 +260,17 @@ def get_stream():
                 yield CompletionResponse(delta=delta, text=text, raw=data)
 
         return get_stream()
+
+    @llm_chat_callback()
+    def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.complete(prompt, formatted=True, **kwargs)
+        return completion_response_to_chat_response(completion_response)
+
+    @llm_chat_callback()
+    def stream_chat(
+        self, messages: Sequence[ChatMessage], **kwargs: Any
+    ) -> ChatResponseGen:
+        prompt = self.messages_to_prompt(messages)
+        completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
+        return stream_completion_response_to_chat_response(completion_response)

private_gpt/components/llm/llm_component.py

Lines changed: 0 additions & 2 deletions
@@ -37,8 +37,6 @@ def __init__(self) -> None:
 
                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.endpoint_name,
-                    messages_to_prompt=messages_to_prompt,
-                    completion_to_prompt=completion_to_prompt,
                 )
             case "openai":
                 from llama_index.llms import OpenAI
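Why these two keyword arguments can simply be dropped: sagemaker.py aliases the llama_utils formatters as generic_messages_to_prompt / generic_completion_to_prompt, which suggests SagemakerLLM falls back to them internally when no formatter is passed. A hedged, standalone sketch of that fallback pattern (the actual logic would live in SagemakerLLM itself, outside the hunks shown above; resolve_formatters is an illustrative helper, not part of the codebase):

from llama_index.llms.llama_utils import (
    completion_to_prompt as generic_completion_to_prompt,
    messages_to_prompt as generic_messages_to_prompt,
)

def resolve_formatters(messages_to_prompt=None, completion_to_prompt=None):
    # Fall back to the llama-style defaults when the caller supplies nothing,
    # which is what lets llm_component.py omit both keyword arguments.
    return (
        messages_to_prompt or generic_messages_to_prompt,
        completion_to_prompt or generic_completion_to_prompt,
    )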
