6
6
7
7
import inspect
8
8
import re
9
+ from collections .abc import Callable
10
+ from typing import Any
9
11
10
- from pydantic import BaseModel
12
+ from aiohttp import hdrs
13
+ from starlette .routing import Route
11
14
12
15
from llama_stack .apis .tools import RAGToolRuntime , SpecialToolGroup
13
16
from llama_stack .apis .version import LLAMA_STACK_API_VERSION
14
17
from llama_stack .distribution .resolver import api_protocol_map
15
18
from llama_stack .providers .datatypes import Api
16
19
17
-
18
- class ApiEndpoint ( BaseModel ):
19
- route : str
20
- method : str
21
- name : str
22
- descriptive_name : str | None = None
20
+ EndpointFunc = Callable [..., Any ]
21
+ PathParams = dict [ str , str ]
22
+ RouteInfo = tuple [ EndpointFunc , str ]
23
+ PathImpl = dict [ str , RouteInfo ]
24
+ RouteImpls = dict [ str , PathImpl ]
25
+ RouteMatch = tuple [ EndpointFunc , PathParams , str ]
23
26
24
27
25
28
def toolgroup_protocol_map ():
@@ -28,13 +31,13 @@ def toolgroup_protocol_map():
28
31
}
29
32
30
33
31
- def get_all_api_endpoints () -> dict [Api , list [ApiEndpoint ]]:
34
+ def get_all_api_routes () -> dict [Api , list [Route ]]:
32
35
apis = {}
33
36
34
37
protocols = api_protocol_map ()
35
38
toolgroup_protocols = toolgroup_protocol_map ()
36
39
for api , protocol in protocols .items ():
37
- endpoints = []
40
+ routes = []
38
41
protocol_methods = inspect .getmembers (protocol , predicate = inspect .isfunction )
39
42
40
43
# HACK ALERT
@@ -51,26 +54,28 @@ def get_all_api_endpoints() -> dict[Api, list[ApiEndpoint]]:
51
54
if not hasattr (method , "__webmethod__" ):
52
55
continue
53
56
54
- webmethod = method .__webmethod__
55
- route = f"/{ LLAMA_STACK_API_VERSION } /{ webmethod .route .lstrip ('/' )} "
56
- if webmethod .method == "GET" :
57
- method = "get"
58
- elif webmethod .method == "DELETE" :
59
- method = "delete"
57
+ # The __webmethod__ attribute is dynamically added by the @webmethod decorator
58
+ # mypy doesn't know about this dynamic attribute, so we ignore the attr-defined error
59
+ webmethod = method .__webmethod__ # type: ignore[attr-defined]
60
+ path = f"/{ LLAMA_STACK_API_VERSION } /{ webmethod .route .lstrip ('/' )} "
61
+ if webmethod .method == hdrs .METH_GET :
62
+ http_method = hdrs .METH_GET
63
+ elif webmethod .method == hdrs .METH_DELETE :
64
+ http_method = hdrs .METH_DELETE
60
65
else :
61
- method = "post"
62
- endpoints .append (
63
- ApiEndpoint ( route = route , method = method , name = name , descriptive_name = webmethod . descriptive_name )
64
- )
66
+ http_method = hdrs . METH_POST
67
+ routes .append (
68
+ Route ( path = path , methods = [ http_method ] , name = name , endpoint = None )
69
+ ) # setting endpoint to None since don't use a Router object
65
70
66
- apis [api ] = endpoints
71
+ apis [api ] = routes
67
72
68
73
return apis
69
74
70
75
71
- def initialize_endpoint_impls (impls ) :
72
- endpoints = get_all_api_endpoints ()
73
- endpoint_impls = {}
76
+ def initialize_route_impls (impls : dict [ Api , Any ]) -> RouteImpls :
77
+ routes = get_all_api_routes ()
78
+ route_impls : RouteImpls = {}
74
79
75
80
def _convert_path_to_regex (path : str ) -> str :
76
81
# Convert {param} to named capture groups
@@ -83,37 +88,42 @@ def _convert_path_to_regex(path: str) -> str:
83
88
84
89
return f"^{ pattern } $"
85
90
86
- for api , api_endpoints in endpoints .items ():
91
+ for api , api_routes in routes .items ():
87
92
if api not in impls :
88
93
continue
89
- for endpoint in api_endpoints :
94
+ for route in api_routes :
90
95
impl = impls [api ]
91
- func = getattr (impl , endpoint .name )
92
- if endpoint .method not in endpoint_impls :
93
- endpoint_impls [endpoint .method ] = {}
94
- endpoint_impls [endpoint .method ][_convert_path_to_regex (endpoint .route )] = (
96
+ func = getattr (impl , route .name )
97
+ # Get the first (and typically only) method from the set, filtering out HEAD
98
+ available_methods = [m for m in route .methods if m != "HEAD" ]
99
+ if not available_methods :
100
+ continue # Skip if only HEAD method is available
101
+ method = available_methods [0 ].lower ()
102
+ if method not in route_impls :
103
+ route_impls [method ] = {}
104
+ route_impls [method ][_convert_path_to_regex (route .path )] = (
95
105
func ,
96
- endpoint . descriptive_name or endpoint . route ,
106
+ route . path ,
97
107
)
98
108
99
- return endpoint_impls
109
+ return route_impls
100
110
101
111
102
- def find_matching_endpoint (method , path , endpoint_impls ) :
112
+ def find_matching_route (method : str , path : str , route_impls : RouteImpls ) -> RouteMatch :
103
113
"""Find the matching endpoint implementation for a given method and path.
104
114
105
115
Args:
106
116
method: HTTP method (GET, POST, etc.)
107
117
path: URL path to match against
108
- endpoint_impls : A dictionary of endpoint implementations
118
+ route_impls : A dictionary of endpoint implementations
109
119
110
120
Returns:
111
121
A tuple of (endpoint_function, path_params, descriptive_name)
112
122
113
123
Raises:
114
124
ValueError: If no matching endpoint is found
115
125
"""
116
- impls = endpoint_impls .get (method .lower ())
126
+ impls = route_impls .get (method .lower ())
117
127
if not impls :
118
128
raise ValueError (f"No endpoint found for { path } " )
119
129
0 commit comments