Commit 9c6d83d

aliabid94 (Ali Abid) authored
gr.load_chat: Allow loading any openai-compatible server immediately as a ChatInterface (#10222)
* changes
* add changeset
* add changeset
* Update gradio/external.py (Co-authored-by: Abubakar Abid <[email protected]>)
* changes
* changes
* Update guides/05_chatbots/01_creating-a-chatbot-fast.md (Co-authored-by: Abubakar Abid <[email protected]>)
* changes

---------

Co-authored-by: Ali Abid <[email protected]>
Co-authored-by: gradio-pr-bot <[email protected]>
Co-authored-by: Abubakar Abid <[email protected]>
1 parent 64d1864 commit 9c6d83d
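
For orientation before the diff: the new entry point can be used as shown below. This is a minimal sketch mirroring the example this commit adds to the guide; the local Ollama URL and model name are illustrative placeholders.

```python
import gradio as gr

# Point gr.load_chat at any OpenAI-compatible chat endpoint; here, a local
# Ollama server (illustrative URL and model name). token=None works for
# servers that don't require authentication.
gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token=None).launch()
```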

File tree

5 files changed: +85 -3 lines changed

.changeset/thick-dingos-help.md
gradio/__init__.py
gradio/chat_interface.py
gradio/external.py
guides/05_chatbots/01_creating-a-chatbot-fast.md

.changeset/thick-dingos-help.md (+5)

```diff
@@ -0,0 +1,5 @@
+---
+"gradio": minor
+---
+
+feat:gr.load_chat: Allow loading any openai-compatible server immediately as a ChatInterface
```

gradio/__init__.py (+1 -1)

```diff
@@ -77,7 +77,7 @@
     on,
 )
 from gradio.exceptions import Error
-from gradio.external import load
+from gradio.external import load, load_chat
 from gradio.flagging import (
     CSVLogger,
     FlaggingCallback,
```

gradio/chat_interface.py (+4 -1)

```diff
@@ -155,7 +155,10 @@ def __init__(
         self.type = type
         self.multimodal = multimodal
         self.concurrency_limit = concurrency_limit
-        self.fn = fn
+        if isinstance(fn, ChatInterface):
+            self.fn = fn.fn
+        else:
+            self.fn = fn
         self.is_async = inspect.iscoroutinefunction(
             self.fn
         ) or inspect.isasyncgenfunction(self.fn)
```
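
This new branch means a `ChatInterface` can itself be passed as the `fn` argument, with its underlying chat function reused rather than the whole app being wrapped. A hedged sketch of the pattern this enables, reusing the endpoint from the guide example; the title and placeholders are illustrative:

```python
import gradio as gr

# gr.load_chat returns a ChatInterface. Passing it back into ChatInterface
# now unwraps it to its underlying chat function (fn.fn), so a loaded chat
# can be re-wrapped with customized options.
loaded = gr.load_chat("http://localhost:11434/v1/", model="llama3.2")
demo = gr.ChatInterface(loaded, type="messages", title="Local Llama Chat")
demo.launch()
```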

gradio/external.py (+65 -1)

```diff
@@ -8,7 +8,7 @@
 import re
 import tempfile
 import warnings
-from collections.abc import Callable
+from collections.abc import Callable, Generator
 from pathlib import Path
 from typing import TYPE_CHECKING, Literal
 
@@ -30,6 +30,7 @@
 
 if TYPE_CHECKING:
     from gradio.blocks import Blocks
+    from gradio.chat_interface import ChatInterface
     from gradio.interface import Interface
 
 
@@ -581,3 +582,66 @@ def fn(*data):
     kwargs["_api_mode"] = True
     interface = gradio.Interface(**kwargs)
     return interface
+
+
+@document()
+def load_chat(
+    base_url: str,
+    model: str,
+    token: str | None = None,
+    *,
+    system_message: str | None = None,
+    streaming: bool = True,
+) -> ChatInterface:
+    """
+    Load a chat interface from an OpenAI API chat compatible endpoint.
+    Parameters:
+        base_url: The base URL of the endpoint.
+        model: The model name.
+        token: The API token.
+        system_message: The system message for the conversation, if any.
+        streaming: Whether the response should be streamed.
+    """
+    try:
+        from openai import OpenAI
+    except ImportError as e:
+        raise ImportError(
+            "To use OpenAI API Client, you must install the `openai` package. You can install it with `pip install openai`."
+        ) from e
+    from gradio.chat_interface import ChatInterface
+
+    client = OpenAI(api_key=token, base_url=base_url)
+    start_message = (
+        [{"role": "system", "content": system_message}] if system_message else []
+    )
+
+    def open_api(message: str, history: list | None) -> str:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        return (
+            client.chat.completions.create(
+                model=model,
+                messages=history + [{"role": "user", "content": message}],
+            )
+            .choices[0]
+            .message.content
+        )
+
+    def open_api_stream(
+        message: str, history: list | None
+    ) -> Generator[str, None, None]:
+        history = history or start_message
+        if len(history) > 0 and isinstance(history[0], (list, tuple)):
+            history = ChatInterface._tuples_to_messages(history)
+        stream = client.chat.completions.create(
+            model=model,
+            messages=history + [{"role": "user", "content": message}],
+            stream=True,
+        )
+        response = ""
+        for chunk in stream:
+            response += chunk.choices[0].delta.content
+            yield response
+
+    return ChatInterface(open_api_stream if streaming else open_api, type="messages")
```

guides/05_chatbots/01_creating-a-chatbot-fast.md (+10)

````diff
@@ -14,6 +14,16 @@ This tutorial uses `gr.ChatInterface()`, which is a high-level abstraction that
 $ pip install --upgrade gradio
 ```
 
+## Quickly loading from Ollama or any OpenAI-API compatible endpoint
+
+If you have a chat server serving an OpenAI API compatible endpoint (skip ahead if you don't), you can spin up a ChatInterface in a single line. First, also run `pip install openai`. Then, with your own URL, model, and optional token:
+
+```python
+import gradio as gr
+
+gr.load_chat("http://localhost:11434/v1/", model="llama3.2", token=None).launch()
+```
+
 ## Defining a chat function
 
 When working with `gr.ChatInterface()`, the first thing you should do is define your **chat function**. In the simplest case, your chat function should accept two arguments: `message` and `history` (the arguments can be named anything, but must be in this order).
````
