
Commit 7715267

fix(router.py): simplify scheduler
Move the scheduler's queue-polling logic into the Router class, making it easier to use.
1 parent 27087f6 · commit 7715267

5 files changed: +177 −131 lines

docs/my-website/docs/scheduler.md (+9 −47)
````diff
@@ -22,9 +22,7 @@ Prioritize LLM API requests in high-traffic.
 ## Quick Start
 
 ```python
-from litellm import Scheduler, FlowItem, Router
-
-scheduler = Scheduler()
+from litellm import Router
 
 router = Router(
     model_list=[
@@ -39,53 +37,17 @@ router = Router(
     ],
     timeout=2, # timeout request if takes > 2s
     routing_strategy="usage-based-routing-v2",
+    polling_interval=0.03 # poll queue every 3ms if no healthy deployments
 )
 
-scheduler.update_variables(llm_router=router)
-
-### 🚨 IMPORTANT ###
-
-item = FlowItem(
-    priority=0, # 👈 SET PRIORITY FOR REQUEST
-    request_id=str(uuid.uuid4()), # 👈 SET REQUEST ID
-    model_name="gpt-3.5-turbo" # 👈 SAME as 'Router'
-)
-
-### [fin] IMPORTANT ###
-
-## ADDS REQUEST TO QUEUE ##
-await scheduler.add_request(request=item)
-
-## POLL QUEUE
-default_timeout = router.timeout
-end_time = time.time() + default_timeout
-poll_interval = 0.03 # poll every 3ms
-curr_time = time.time()
-
-make_request = False
-
-while curr_time < end_time:
-    make_request = await scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
-        id=item.request_id, model_name=item.model_name
+try:
+    _response = await router.schedule_acompletion( # 👈 ADDS TO QUEUE + POLLS + MAKES CALL
+        model=item.model_name,
+        messages=[{"role": "user", "content": "Hey!"}],
+        priority=0, # 👈 LOWER IS BETTER
     )
-    if make_request: ## IF TRUE -> MAKE REQUEST
-        break
-    else: ## ELSE -> loop till default_timeout
-        await asyncio.sleep(poll_interval)
-        curr_time = time.time()
-
-if make_request:
-    try:
-        _response = await router.acompletion(
-            model=item.model_name,
-            messages=[{"role": "user", "content": "Hey!"}],
-        )
-    except Exception as e:
-        print("{}, {}, {}".format(item.priority, item.request_id, "Error occurred"))
-
-    print("{}, {}, {}".format(item.priority, item.request_id, time.time()))
-
-    print("didn't make request")
+except Exception as e:
+    print("didn't make request")
 ```
 
 ## LiteLLM Proxy
````
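For reference, here is a self-contained sketch of the updated quick start. The added doc lines still pass `model=item.model_name` even though the `FlowItem` was removed from the example, so the sketch below passes the model group name directly. The `mock_response` deployment is an illustrative assumption (the full `model_list` entry is elided from this hunk), not something this commit prescribes:

```python
import asyncio
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "mock_response": "Hello world!",  # assumed: mock the upstream call for local testing
            },
        }
    ],
    timeout=2,  # timeout request if takes > 2s
    routing_strategy="usage-based-routing-v2",
    polling_interval=0.03,  # poll queue every 3ms if no healthy deployments
)


async def main():
    try:
        # adds the request to the router's queue, polls until a healthy
        # deployment is available (or router.timeout is hit), then makes the call
        response = await router.schedule_acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            priority=0,  # lower value = higher priority
        )
        print(response)
    except Exception:
        print("didn't make request")


asyncio.run(main())
```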

litellm/proxy/proxy_server.py (+3 −57)
```diff
@@ -398,8 +398,6 @@ async def openai_exception_handler(request: Request, exc: ProxyException):
 async_result = None
 celery_app_conn = None
 celery_fn = None # Redis Queue for handling requests
-### SIMPLE QUEUE ###
-simple_scheduler = Scheduler()
 ### DB WRITER ###
 db_writer_client: Optional[HTTPHandler] = None
 ### logger ###
@@ -3705,7 +3703,7 @@ def on_backoff(details):
 
 @router.on_event("startup")
 async def startup_event():
-    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db, simple_scheduler
+    global prisma_client, master_key, use_background_health_checks, llm_router, llm_model_list, general_settings, proxy_budget_rescheduler_min_time, proxy_budget_rescheduler_max_time, litellm_proxy_admin_name, db_writer_client, store_model_in_db
     import json
 
     ### LOAD MASTER KEY ###
@@ -3741,10 +3739,6 @@ async def startup_event():
     ## Error Tracking ##
     error_tracking()
 
-    ## Priority Workload Scheduler ##
-    if llm_router is not None:
-        simple_scheduler.update_variables(llm_router=llm_router)
-
     ## UPDATE SLACK ALERTING ##
     proxy_logging_obj.slack_alerting_instance.update_values(llm_router=llm_router)
 
@@ -12183,47 +12177,12 @@ async def async_queue_request(
         if user_api_base:
             data["api_base"] = user_api_base
 
-        ## FLOW ITEM ##
-        request_id = str(uuid.uuid4())
-        flow_item = FlowItem(
-            priority=data.pop("priority", DefaultPriorities.Medium.value),
-            request_id=request_id,
-            model_name=data["model"],
-        )
-        # [TODO] only allow premium users to set non default priorities
-
-        ## ADD REQUEST TO QUEUE
-        response = await simple_scheduler.add_request(request=flow_item)
-
-        if llm_router is None:
-            raise HTTPException(
-                status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
-            )
-        ## POLL QUEUE
-        default_timeout = llm_router.timeout
-        end_time = time.time() + default_timeout
-        poll_interval = 0.03 # poll every 3ms
-        curr_time = time.time()
-
-        make_request = False
-
         if llm_router is None:
             raise HTTPException(
                 status_code=500, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
 
-        while curr_time < end_time:
-            make_request = await simple_scheduler.poll(
-                id=request_id, model_name=data["model"]
-            )
-            if make_request: ## IF TRUE -> MAKE REQUEST
-                break
-            else: ## ELSE -> loop till default_timeout
-                await asyncio.sleep(poll_interval)
-                curr_time = time.time()
-
-        if make_request:
-            response = await llm_router.acompletion(**data)
+        response = await llm_router.schedule_acompletion(**data)
 
         if (
             "stream" in data and data["stream"] == True
@@ -12237,7 +12196,7 @@ async def async_queue_request(
                 media_type="text/event-stream",
             )
 
-        fastapi_response.headers.update({"x-litellm-priority": str(flow_item.priority)})
+        fastapi_response.headers.update({"x-litellm-priority": str(data["priority"])})
         return response
     except Exception as e:
         await proxy_logging_obj.post_call_failure_hook(
@@ -12260,19 +12219,6 @@ async def async_queue_request(
 )
 
 
-@router.get(
-    "/queue/info",
-    tags=["experimental"],
-    dependencies=[Depends(user_api_key_auth)],
-)
-async def queue_info(
-    request: Request,
-    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
-) -> List:
-    """Help user know the status of an item in the queue"""
-    return simple_scheduler.get_queue_status()
-
-
 @router.get(
     "/ollama_logs", dependencies=[Depends(user_api_key_auth)], tags=["experimental"]
 )
```
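With `simple_scheduler` gone, the queue endpoint now just forwards the parsed request body (including `priority`) to `llm_router.schedule_acompletion(**data)` and echoes the priority back in the `x-litellm-priority` header. A client-side sketch, assuming the proxy listens on localhost:4000 with key `sk-1234` and that `async_queue_request` is mounted at `/queue/chat/completions` (the route decorator is outside this diff):

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/queue/chat/completions",  # assumed route for async_queue_request
    headers={"Authorization": "Bearer sk-1234"},   # assumed proxy key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hey!"}],
        "priority": 0,  # forwarded to schedule_acompletion via **data; lower is better
    },
    timeout=10,
)
print(resp.headers.get("x-litellm-priority"), resp.json())
```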

litellm/router.py (+89 −2)
```diff
@@ -62,6 +62,7 @@
     Run,
     AssistantToolParam,
 )
+from litellm.scheduler import Scheduler, FlowItem
 from typing import Iterable
 
 
@@ -87,6 +88,8 @@ def __init__(
             List[tuple]
         ] = None, # if you want to cache across model groups
         client_ttl: int = 3600, # ttl for cached clients - will re-initialize after this time in seconds
+        ## SCHEDULER ##
+        polling_interval: Optional[float] = None,
         ## RELIABILITY ##
         num_retries: Optional[int] = None,
         timeout: Optional[float] = None,
@@ -141,7 +144,8 @@ def __init__(
             cache_kwargs (dict): Additional kwargs to pass to RedisCache. Defaults to {}.
             caching_groups (Optional[List[tuple]]): List of model groups for caching across model groups. Defaults to None.
             client_ttl (int): Time-to-live for cached clients in seconds. Defaults to 3600.
-            num_retries (int): Number of retries for failed requests. Defaults to 0.
+            polling_interval: (Optional[float]): frequency of polling queue. Only for '.scheduler_acompletion()'. Default is 3ms.
+            num_retries (Optional[int]): Number of retries for failed requests. Defaults to 2.
             timeout (Optional[float]): Timeout for requests. Defaults to None.
             default_litellm_params (dict): Default parameters for Router.chat.completion.create. Defaults to {}.
             set_verbose (bool): Flag to set verbose mode. Defaults to False.
@@ -208,6 +212,8 @@ def __init__(
             []
         ) # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
+        ### SCHEDULER ###
+        self.scheduler = Scheduler(polling_interval=polling_interval)
         ### CACHING ###
         cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
         redis_cache = None
@@ -533,11 +539,17 @@ async def acompletion(
     ) -> ModelResponse:
         ...
 
+    @overload
+    async def acompletion(
+        self, model: str, messages: List[Dict[str, str]], stream: Union[Literal[True], Literal[False]] = False, **kwargs
+    ) -> Union[CustomStreamWrapper, ModelResponse]:
+        ...
+
     # fmt: on
 
     # The actual implementation of the function
     async def acompletion(
-        self, model: str, messages: List[Dict[str, str]], stream=False, **kwargs
+        self, model: str, messages: List[Dict[str, str]], stream: bool = False, **kwargs
     ):
         try:
             kwargs["model"] = model
@@ -905,6 +917,81 @@ async def check_response(task: asyncio.Task):
         # If we exit the loop without returning, all tasks failed
         raise Exception("All tasks failed")
 
+    ### SCHEDULER ###
+
+    # fmt: off
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[False] = False, **kwargs
+    ) -> ModelResponse:
+        ...
+
+    @overload
+    async def schedule_acompletion(
+        self, model: str, messages: List[Dict[str, str]], priority: int, stream: Literal[True], **kwargs
+    ) -> CustomStreamWrapper:
+        ...
+
+    # fmt: on
+
+    async def schedule_acompletion(
+        self,
+        model: str,
+        messages: List[Dict[str, str]],
+        priority: int,
+        stream=False,
+        **kwargs,
+    ):
+        ### FLOW ITEM ###
+        _request_id = str(uuid.uuid4())
+        item = FlowItem(
+            priority=priority, # 👈 SET PRIORITY FOR REQUEST
+            request_id=_request_id, # 👈 SET REQUEST ID
+            model_name="gpt-3.5-turbo", # 👈 SAME as 'Router'
+        )
+        ### [fin] ###
+
+        ## ADDS REQUEST TO QUEUE ##
+        await self.scheduler.add_request(request=item)
+
+        ## POLL QUEUE
+        end_time = time.time() + self.timeout
+        curr_time = time.time()
+        poll_interval = self.scheduler.polling_interval # poll every 3ms
+        make_request = False
+
+        while curr_time < end_time:
+            _healthy_deployments = await self._async_get_healthy_deployments(
+                model=model
+            )
+            make_request = await self.scheduler.poll( ## POLL QUEUE ## - returns 'True' if there's healthy deployments OR if request is at top of queue
+                id=item.request_id,
+                model_name=item.model_name,
+                health_deployments=_healthy_deployments,
+            )
+            if make_request: ## IF TRUE -> MAKE REQUEST
+                break
+            else: ## ELSE -> loop till default_timeout
+                await asyncio.sleep(poll_interval)
+                curr_time = time.time()
+
+        if make_request:
+            try:
+                _response = await self.acompletion(
+                    model=model, messages=messages, stream=stream, **kwargs
+                )
+                return _response
+            except Exception as e:
+                setattr(e, "priority", priority)
+                raise e
+        else:
+            raise litellm.Timeout(
+                message="Request timed out while polling queue",
+                model=model,
+                llm_provider="openai",
+            )
+
     def image_generation(self, prompt: str, model: str, **kwargs):
         try:
             kwargs["model"] = model
```
