Skip to content

Commit d584007

Browse files
committed
chore: use starlette built-in Route class
Use a more common pattern and well-known terminology from the ecosystem, where "Route" is more widely accepted than "Endpoint". Signed-off-by: Sébastien Han <[email protected]>
1 parent 448f009 commit d584007

File tree

7 files changed

+112
-72
lines changed

7 files changed

+112
-72
lines changed

llama_stack/distribution/inspect.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
VersionInfo,
1717
)
1818
from llama_stack.distribution.datatypes import StackRunConfig
19-
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
19+
from llama_stack.distribution.server.routes import get_all_api_routes
2020
from llama_stack.providers.datatypes import HealthStatus
2121

2222

@@ -42,15 +42,15 @@ async def list_routes(self) -> ListRoutesResponse:
4242
run_config: StackRunConfig = self.config.run_config
4343

4444
ret = []
45-
all_endpoints = get_all_api_endpoints()
45+
all_endpoints = get_all_api_routes()
4646
for api, endpoints in all_endpoints.items():
4747
# Always include provider and inspect APIs, filter others based on run config
4848
if api.value in ["providers", "inspect"]:
4949
ret.extend(
5050
[
5151
RouteInfo(
52-
route=e.route,
53-
method=e.method,
52+
route=e.path,
53+
method=next(iter([m for m in e.methods if m != "HEAD"])),
5454
provider_types=[], # These APIs don't have "real" providers - they're internal to the stack
5555
)
5656
for e in endpoints
@@ -62,8 +62,8 @@ async def list_routes(self) -> ListRoutesResponse:
6262
ret.extend(
6363
[
6464
RouteInfo(
65-
route=e.route,
66-
method=e.method,
65+
route=e.path,
66+
method=next(iter([m for m in e.methods if m != "HEAD"])),
6767
provider_types=[p.provider_type for p in providers],
6868
)
6969
for e in endpoints

llama_stack/distribution/library_client.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@
3737
request_provider_data_context,
3838
)
3939
from llama_stack.distribution.resolver import ProviderRegistry
40-
from llama_stack.distribution.server.endpoints import (
41-
find_matching_endpoint,
42-
initialize_endpoint_impls,
43-
)
40+
from llama_stack.distribution.server.routes import find_matching_route, initialize_route_impls
4441
from llama_stack.distribution.stack import (
4542
construct_stack,
4643
get_stack_run_config_from_template,
@@ -208,7 +205,7 @@ def __init__(
208205

209206
async def initialize(self) -> bool:
210207
try:
211-
self.endpoint_impls = None
208+
self.route_impls = None
212209
self.impls = await construct_stack(self.config, self.custom_provider_registry)
213210
except ModuleNotFoundError as _e:
214211
cprint(_e.msg, color="red", file=sys.stderr)
@@ -254,7 +251,7 @@ async def initialize(self) -> bool:
254251
safe_config = redact_sensitive_fields(self.config.model_dump())
255252
console.print(yaml.dump(safe_config, indent=2))
256253

257-
self.endpoint_impls = initialize_endpoint_impls(self.impls)
254+
self.route_impls = initialize_route_impls(self.impls)
258255
return True
259256

260257
async def request(
@@ -265,7 +262,7 @@ async def request(
265262
stream=False,
266263
stream_cls=None,
267264
):
268-
if not self.endpoint_impls:
265+
if not self.route_impls:
269266
raise ValueError("Client not initialized")
270267

271268
# Create headers with provider data if available
@@ -296,11 +293,14 @@ async def _call_non_streaming(
296293
cast_to: Any,
297294
options: Any,
298295
):
296+
if self.route_impls is None:
297+
raise ValueError("Client not initialized")
298+
299299
path = options.url
300300
body = options.params or {}
301301
body |= options.json_data or {}
302302

303-
matched_func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
303+
matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls)
304304
body |= path_params
305305
body = self._convert_body(path, options.method, body)
306306
await start_trace(route, {"__location__": "library_client"})
@@ -342,10 +342,13 @@ async def _call_streaming(
342342
options: Any,
343343
stream_cls: Any,
344344
):
345+
if self.route_impls is None:
346+
raise ValueError("Client not initialized")
347+
345348
path = options.url
346349
body = options.params or {}
347350
body |= options.json_data or {}
348-
func, path_params, route = find_matching_endpoint(options.method, path, self.endpoint_impls)
351+
func, path_params, route = find_matching_route(options.method, path, self.route_impls)
349352
body |= path_params
350353

351354
body = self._convert_body(path, options.method, body)
@@ -397,7 +400,10 @@ def _convert_body(self, path: str, method: str, body: dict | None = None) -> dic
397400
if not body:
398401
return {}
399402

400-
func, _, _ = find_matching_endpoint(method, path, self.endpoint_impls)
403+
if self.route_impls is None:
404+
raise ValueError("Client not initialized")
405+
406+
func, _, _ = find_matching_route(method, path, self.route_impls)
401407
sig = inspect.signature(func)
402408

403409
# Strip NOT_GIVENs to use the defaults in signature

llama_stack/distribution/server/endpoints.py renamed to llama_stack/distribution/server/routes.py

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,23 @@
66

77
import inspect
88
import re
9+
from collections.abc import Callable
10+
from typing import Any
911

10-
from pydantic import BaseModel
12+
from aiohttp import hdrs
13+
from starlette.routing import Route
1114

1215
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
1316
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
1417
from llama_stack.distribution.resolver import api_protocol_map
1518
from llama_stack.providers.datatypes import Api
1619

17-
18-
class ApiEndpoint(BaseModel):
19-
route: str
20-
method: str
21-
name: str
22-
descriptive_name: str | None = None
20+
EndpointFunc = Callable[..., Any]
21+
PathParams = dict[str, str]
22+
RouteInfo = tuple[EndpointFunc, str]
23+
PathImpl = dict[str, RouteInfo]
24+
RouteImpls = dict[str, PathImpl]
25+
RouteMatch = tuple[EndpointFunc, PathParams, str]
2326

2427

2528
def toolgroup_protocol_map():
@@ -28,13 +31,13 @@ def toolgroup_protocol_map():
2831
}
2932

3033

31-
def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
34+
def get_all_api_routes() -> dict[Api, list[Route]]:
3235
apis = {}
3336

3437
protocols = api_protocol_map()
3538
toolgroup_protocols = toolgroup_protocol_map()
3639
for api, protocol in protocols.items():
37-
endpoints = []
40+
routes = []
3841
protocol_methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
3942

4043
# HACK ALERT
@@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
5154
if not hasattr(method, "__webmethod__"):
5255
continue
5356

54-
webmethod = method.__webmethod__
55-
route = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
56-
if webmethod.method == "GET":
57-
method = "get"
58-
elif webmethod.method == "DELETE":
59-
method = "delete"
57+
# The __webmethod__ attribute is dynamically added by the @webmethod decorator
58+
# mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
59+
webmethod = method.__webmethod__ # type: ignore[attr-defined]
60+
path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
61+
if webmethod.method == hdrs.METH_GET:
62+
http_method = hdrs.METH_GET
63+
elif webmethod.method == hdrs.METH_DELETE:
64+
http_method = hdrs.METH_DELETE
6065
else:
61-
method = "post"
62-
endpoints.append(
63-
ApiEndpoint(route=route, method=method, name=name, descriptive_name=webmethod.descriptive_name)
64-
)
66+
http_method = hdrs.METH_POST
67+
routes.append(
68+
Route(path=path, methods=[http_method], name=name, endpoint=None)
69+
) # setting endpoint to None since we don't use a Router object
6570

66-
apis[api] = endpoints
71+
apis[api] = routes
6772

6873
return apis
6974

7075

71-
def initialize_endpoint_impls(impls):
72-
endpoints = get_all_api_endpoints()
73-
endpoint_impls = {}
76+
def initialize_route_impls(impls: dict[Api, Any]) -> RouteImpls:
77+
routes = get_all_api_routes()
78+
route_impls: RouteImpls = {}
7479

7580
def _convert_path_to_regex(path: str) -> str:
7681
# Convert {param} to named capture groups
@@ -83,37 +88,42 @@ def _convert_path_to_regex(path: str) -> str:
8388

8489
return f"^{pattern}$"
8590

86-
for api, api_endpoints in endpoints.items():
91+
for api, api_routes in routes.items():
8792
if api not in impls:
8893
continue
89-
for endpoint in api_endpoints:
94+
for route in api_routes:
9095
impl = impls[api]
91-
func = getattr(impl, endpoint.name)
92-
if endpoint.method not in endpoint_impls:
93-
endpoint_impls[endpoint.method] = {}
94-
endpoint_impls[endpoint.method][_convert_path_to_regex(endpoint.route)] = (
96+
func = getattr(impl, route.name)
97+
# Get the first (and typically only) method from the set, filtering out HEAD
98+
available_methods = [m for m in route.methods if m != "HEAD"]
99+
if not available_methods:
100+
continue # Skip if only HEAD method is available
101+
method = available_methods[0].lower()
102+
if method not in route_impls:
103+
route_impls[method] = {}
104+
route_impls[method][_convert_path_to_regex(route.path)] = (
95105
func,
96-
endpoint.descriptive_name or endpoint.route,
106+
route.path,
97107
)
98108

99-
return endpoint_impls
109+
return route_impls
100110

101111

102-
def find_matching_endpoint(method, path, endpoint_impls):
112+
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
103113
"""Find the matching endpoint implementation for a given method and path.
104114
105115
Args:
106116
method: HTTP method (GET, POST, etc.)
107117
path: URL path to match against
108-
endpoint_impls: A dictionary of endpoint implementations
118+
route_impls: A dictionary of endpoint implementations
109119
110120
Returns:
111121
A tuple of (endpoint_function, path_params, descriptive_name)
112122
113123
Raises:
114124
ValueError: If no matching endpoint is found
115125
"""
116-
impls = endpoint_impls.get(method.lower())
126+
impls = route_impls.get(method.lower())
117127
if not impls:
118128
raise ValueError(f"No endpoint found for {path}")
119129

0 commit comments

Comments
 (0)