|
4 | 4 | import io
|
5 | 5 | import json
|
6 | 6 | import logging
|
7 |
| -from typing import TYPE_CHECKING, Any |
| 7 | +from typing import TYPE_CHECKING |
8 | 8 |
|
9 | 9 | import boto3 # type: ignore
|
10 | 10 | from llama_index.bridge.pydantic import Field
|
|
13 | 13 | CustomLLM,
|
14 | 14 | LLMMetadata,
|
15 | 15 | )
|
16 |
| -from llama_index.llms.base import llm_completion_callback |
| 16 | +from llama_index.llms.base import ( |
| 17 | + llm_chat_callback, |
| 18 | + llm_completion_callback, |
| 19 | +) |
| 20 | +from llama_index.llms.generic_utils import ( |
| 21 | + completion_response_to_chat_response, |
| 22 | + stream_completion_response_to_chat_response, |
| 23 | +) |
17 | 24 | from llama_index.llms.llama_utils import (
|
18 | 25 | completion_to_prompt as generic_completion_to_prompt,
|
19 | 26 | )
|
|
22 | 29 | )
|
23 | 30 |
|
24 | 31 | if TYPE_CHECKING:
|
| 32 | + from collections.abc import Sequence |
| 33 | + from typing import Any |
| 34 | + |
25 | 35 | from llama_index.callbacks import CallbackManager
|
26 | 36 | from llama_index.llms import (
|
| 37 | + ChatMessage, |
| 38 | + ChatResponse, |
| 39 | + ChatResponseGen, |
27 | 40 | CompletionResponseGen,
|
28 | 41 | )
|
29 | 42 |
|
@@ -247,3 +260,17 @@ def get_stream():
|
247 | 260 | yield CompletionResponse(delta=delta, text=text, raw=data)
|
248 | 261 |
|
249 | 262 | return get_stream()
|
| 263 | + |
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
    """Answer a chat conversation by delegating to the completion endpoint.

    The messages are flattened into a single prompt string with
    ``self.messages_to_prompt``; that prompt is sent through
    ``self.complete`` and the completion result is converted into a chat
    response.

    Args:
        messages: The conversation to respond to.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.complete``.

    Returns:
        The completion result wrapped by
        ``completion_response_to_chat_response``.
    """
    formatted_prompt = self.messages_to_prompt(messages)
    # formatted=True is passed because the prompt was already built from the
    # messages above -- presumably it tells ``complete`` to skip its own
    # prompt formatting (confirm against ``complete``).
    return completion_response_to_chat_response(
        self.complete(formatted_prompt, formatted=True, **kwargs)
    )
| 269 | + |
@llm_chat_callback()
def stream_chat(
    self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
    """Stream a chat reply by delegating to the streaming completion endpoint.

    The messages are flattened into a single prompt string with
    ``self.messages_to_prompt``; that prompt is sent through
    ``self.stream_complete`` and the resulting completion stream is adapted
    into a chat-response stream.

    Args:
        messages: The conversation to respond to.
        **kwargs: Extra keyword arguments forwarded unchanged to
            ``self.stream_complete``.

    Returns:
        The completion stream wrapped by
        ``stream_completion_response_to_chat_response``.
    """
    formatted_prompt = self.messages_to_prompt(messages)
    # formatted=True is passed because the prompt was already built from the
    # messages above -- presumably it tells ``stream_complete`` to skip its
    # own prompt formatting (confirm against ``stream_complete``).
    return stream_completion_response_to_chat_response(
        self.stream_complete(formatted_prompt, formatted=True, **kwargs)
    )
0 commit comments