Skip to content

Commit 8997c1e

Browse files
authored
FEAT: RESTful API (#40)
1 parent 8e28844 commit 8997c1e

File tree

4 files changed

+636
-131
lines changed

4 files changed

+636
-131
lines changed

plexar/client.py

+99-5
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,22 @@
1414

1515
import asyncio
1616
import uuid
17-
from typing import List, Optional, Tuple
17+
from typing import Iterator, List, Optional, Tuple, Union
1818

19+
import requests
1920
import xoscar as xo
2021

2122
from .core.model import ModelActor
2223
from .core.service import SupervisorActor
2324
from .isolation import Isolation
2425
from .model import ModelSpec
26+
from .model.llm.types import (
27+
ChatCompletion,
28+
ChatCompletionChunk,
29+
ChatCompletionMessage,
30+
Completion,
31+
CompletionChunk,
32+
)
2533

2634

2735
class Client:
@@ -44,7 +52,7 @@ def launch_model(
4452
model_size_in_billions: Optional[int] = None,
4553
model_format: Optional[str] = None,
4654
quantization: Optional[str] = None,
47-
**kwargs
55+
**kwargs,
4856
) -> str:
4957
model_uid = self.gen_model_uid()
5058

@@ -54,7 +62,7 @@ def launch_model(
5462
model_size_in_billions=model_size_in_billions,
5563
model_format=model_format,
5664
quantization=quantization,
57-
**kwargs
65+
**kwargs,
5866
)
5967
self._isolation.call(coro)
6068

@@ -73,6 +81,92 @@ def get_model(self, model_uid: str) -> xo.ActorRefType["ModelActor"]:
7381
return self._isolation.call(coro)
7482

7583

84+
class RESTfulClient:
    """Blocking client that drives the service through its RESTful HTTP API.

    Each public method issues a single ``requests`` call against an endpoint
    below ``base_url`` and returns the decoded JSON body. When the server
    answers with an HTTP error status, a ``RuntimeError`` carrying the status
    code and the response text is raised instead of letting the failure
    surface later as a JSON decoding error or a ``KeyError``.
    """

    def __init__(self, base_url):
        """Remember ``base_url``, the root URL every endpoint is appended to."""
        self.base_url = base_url

    def _endpoint(self, path: str) -> str:
        """Return ``base_url`` joined with ``path`` without a doubled slash."""
        # A base URL written as "http://host:port/" would otherwise produce
        # "http://host:port//v1/...".
        base = str(self.base_url).rstrip("/")
        return f"{base}{path}"

    @staticmethod
    def _check_response(response, action: str) -> None:
        """Raise ``RuntimeError`` if ``response`` carries an HTTP error status.

        ``action`` is a short phrase naming the attempted operation; it is
        embedded in the error message.
        """
        if response.status_code >= 400:
            raise RuntimeError(
                f"Failed to {action}: HTTP {response.status_code}, "
                f"detail: {response.text}"
            )

    @classmethod
    def gen_model_uid(cls) -> str:
        """Return a new time-based (version 1) UUID as a string."""
        return str(uuid.uuid1())

    def list_models(self) -> List[str]:
        """Return the JSON body of ``GET /v1/models``.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
        """
        response = requests.get(self._endpoint("/v1/models"))
        self._check_response(response, "list models")
        return response.json()

    def launch_model(
        self,
        model_name: str,
        model_size_in_billions: Optional[int] = None,
        model_format: Optional[str] = None,
        quantization: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Ask the server to launch a model via ``POST /v1/models``.

        A fresh uid from :meth:`gen_model_uid` is sent together with the
        model description; extra keyword arguments are forwarded under the
        ``"kwargs"`` key of the JSON payload.

        Returns:
            The ``"model_uid"`` field of the server's JSON reply.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
        """
        payload = {
            "model_uid": self.gen_model_uid(),
            "model_name": model_name,
            "model_size_in_billions": model_size_in_billions,
            "model_format": model_format,
            "quantization": quantization,
            "kwargs": kwargs,
        }
        response = requests.post(self._endpoint("/v1/models"), json=payload)
        self._check_response(response, "launch the model")
        # The uid reported by the server is authoritative, not the one sent.
        return response.json()["model_uid"]

    def terminate_model(self, model_uid: str):
        """Terminate a model via ``DELETE /v1/models/{model_uid}``.

        Raises:
            RuntimeError: if the response status is anything other than 200.
        """
        response = requests.delete(self._endpoint(f"/v1/models/{model_uid}"))
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to terminate the model: HTTP {response.status_code}, "
                f"detail: {response.text}"
            )

    def generate(
        self, model_uid: str, prompt: str, **kwargs
    ) -> Union[Completion, Iterator[CompletionChunk]]:
        """Request a completion via ``POST /v1/completions``.

        The JSON body contains ``"model"``, ``"prompt"`` and every extra
        keyword argument as a top-level key.

        NOTE(review): the reply is always decoded with ``response.json()``,
        so a streamed reply is presumably not handled here despite the
        ``Iterator`` in the return annotation -- confirm.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
        """
        request_body = {"model": model_uid, "prompt": prompt, **kwargs}
        response = requests.post(
            self._endpoint("/v1/completions"), json=request_body
        )
        self._check_response(response, "generate a completion")
        return response.json()

    def chat(
        self,
        model_uid: str,
        prompt: str,
        system_prompt: Optional[str] = None,
        chat_history: Optional[List[ChatCompletionMessage]] = None,
        **kwargs,
    ) -> Union[ChatCompletion, Iterator[ChatCompletionChunk]]:
        """Request a chat completion via ``POST /v1/chat/completions``.

        The messages sent are ``chat_history`` with the system message set
        from ``system_prompt`` (when given) and the user ``prompt`` appended.

        Note that a ``chat_history`` list passed by the caller is modified in
        place: its leading system message is overwritten or a new one is
        inserted at index 0, and the user message is appended.

        Raises:
            RuntimeError: if the server answers with an HTTP error status.
        """
        if chat_history is None:
            chat_history = []

        if system_prompt is not None:
            if chat_history and chat_history[0]["role"] == "system":
                # An existing system message is replaced, not duplicated.
                chat_history[0]["content"] = system_prompt
            else:
                chat_history.insert(
                    0, ChatCompletionMessage(role="system", content=system_prompt)
                )

        chat_history.append(ChatCompletionMessage(role="user", content=prompt))
        request_body = {"model": model_uid, "messages": chat_history, **kwargs}
        response = requests.post(
            self._endpoint("/v1/chat/completions"), json=request_body
        )
        self._check_response(response, "generate a chat completion")
        return response.json()
168+
169+
76170
class AsyncClient:
77171
def __init__(self, supervisor_address: str):
78172
self._supervisor_address = supervisor_address
@@ -96,7 +190,7 @@ async def launch_model(
96190
model_size_in_billions: Optional[int] = None,
97191
model_format: Optional[str] = None,
98192
quantization: Optional[str] = None,
99-
**kwargs
193+
**kwargs,
100194
) -> str:
101195
model_uid = self.gen_model_uid()
102196

@@ -107,7 +201,7 @@ async def launch_model(
107201
model_size_in_billions=model_size_in_billions,
108202
model_format=model_format,
109203
quantization=quantization,
110-
**kwargs
204+
**kwargs,
111205
)
112206
return model_uid
113207

0 commit comments

Comments
 (0)