@@ -8,7 +8,7 @@
 import re
 import tempfile
 import warnings

-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal

@@ -30,6 +30,7 @@

 if TYPE_CHECKING:
     from gradio.blocks import Blocks
+    from gradio.chat_interface import ChatInterface
     from gradio.interface import Interface

@@ -581,3 +582,66 @@ def fn(*data):
     kwargs["_api_mode"] = True
     interface = gradio.Interface(**kwargs)
     return interface
+
+
+@document()
+def load_chat(
+    base_url: str,
+    model: str,
+    token: str | None = None,
+    *,
+    system_message: str | None = None,
+    streaming: bool = True,
+) -> ChatInterface:
+    """
+    Load a chat interface from an OpenAI-API-compatible chat completions endpoint.
+    Parameters:
+        base_url: The base URL of the endpoint.
+        model: The model name.
+        token: The API token.
+        system_message: The system message for the conversation, if any.
+        streaming: Whether the response should be streamed.
+    """
+    try:
+        from openai import OpenAI
+    except ImportError as e:
+        raise ImportError(
+            "To use OpenAI API Client, you must install the `openai` package. You can install it with `pip install openai`."
+        ) from e
+    from gradio.chat_interface import ChatInterface
+
+    client = OpenAI(api_key=token, base_url=base_url)
+    start_message = (
+        [{"role": "system", "content": system_message}] if system_message else []
+    )
+
+    def open_api(message: str, history: list | None) -> str:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        return (
+            client.chat.completions.create(
+                model=model,
+                messages=history + [{"role": "user", "content": message}],
+            )
+            .choices[0]
+            .message.content
+        )
+
+    def open_api_stream(
+        message: str, history: list | None
+    ) -> Generator[str, None, None]:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        stream = client.chat.completions.create(
+            model=model,
+            messages=history + [{"role": "user", "content": message}],
+            stream=True,
+        )
+        response = ""
+        for chunk in stream:
+            response += chunk.choices[0].delta.content or ""  # final stream chunk can carry no content
+            yield response
+
+    return ChatInterface(open_api_stream if streaming else open_api, type="messages")
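
Below the diff, a minimal usage sketch of the new helper (not part of the commit). The base URL, model name, and token are placeholders for any OpenAI-compatible chat completions server, such as a locally running Ollama or vLLM instance; since this diff does not show whether `load_chat` is re-exported at the package level, it is imported from `gradio.external` directly.

```python
# Hypothetical usage of load_chat; the endpoint URL, model name, and token
# below are placeholders, not values defined by this commit.
from gradio.external import load_chat

demo = load_chat(
    base_url="http://localhost:11434/v1",  # assumption: a local OpenAI-compatible server
    model="llama3.2",                      # any model name the server exposes
    token="unused",                        # many local servers ignore the API key
    system_message="You are a concise assistant.",
    streaming=True,                        # stream partial responses into the chat UI
)

if __name__ == "__main__":
    demo.launch()
```

With `streaming=True`, the chat UI re-renders the accumulated `response` string on every yielded chunk, which is why `open_api_stream` yields the growing text rather than each delta on its own.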