@@ -175,6 +175,29 @@ def _get_arguments_input_schema(function, name):
175
175
class Toolbox :
176
176
_blocked = ("method_tools" , "introspect_methods" , "methods" )
177
177
name : Optional [str ] = None
178
+ instance_id : Optional [int ] = None
179
+
180
+ def __init_subclass__ (cls , ** kwargs ):
181
+ super ().__init_subclass__ (** kwargs )
182
+
183
+ original_init = cls .__init__
184
+
185
+ def wrapped_init (self , * args , ** kwargs ):
186
+ sig = inspect .signature (original_init )
187
+ bound = sig .bind (self , * args , ** kwargs )
188
+ bound .apply_defaults ()
189
+
190
+ self ._config = {
191
+ name : value
192
+ for name , value in bound .arguments .items ()
193
+ if name != "self"
194
+ and sig .parameters [name ].kind
195
+ not in (inspect .Parameter .VAR_POSITIONAL , inspect .Parameter .VAR_KEYWORD )
196
+ }
197
+
198
+ original_init (self , * args , ** kwargs )
199
+
200
+ cls .__init__ = wrapped_init
178
201
179
202
@classmethod
180
203
def methods (cls ):
@@ -197,10 +220,12 @@ def method_tools(self):
197
220
method = getattr (self , method_name )
198
221
# The attribute must be a bound method, i.e. inspect.ismethod()
199
222
if callable (method ) and inspect .ismethod (method ):
200
- yield Tool .function (
223
+ tool = Tool .function (
201
224
method ,
202
225
name = "{}_{}" .format (self .__class__ .__name__ , method_name ),
203
226
)
227
+ tool .plugin = getattr (self , "plugin" , None )
228
+ yield tool
204
229
205
230
@classmethod
206
231
def introspect_methods (cls ):
@@ -235,6 +260,7 @@ class ToolResult:
235
260
name : str
236
261
output : str
237
262
tool_call_id : Optional [str ] = None
263
+ instance : Optional [Toolbox ] = None
238
264
239
265
240
266
class CancelToolCall (Exception ):
@@ -834,13 +860,29 @@ def log_to_db(self, db):
834
860
}
835
861
)
836
862
for tool_result in self .prompt .tool_results :
863
+ instance_id = None
864
+ if tool_result .instance :
865
+ if not tool_result .instance .instance_id :
866
+ tool_result .instance .instance_id = (
867
+ db ["tool_instances" ]
868
+ .insert (
869
+ {
870
+ "plugin" : tool .plugin ,
871
+ "name" : tool .name .split ("_" )[0 ],
872
+ "arguments" : json .dumps (tool_result .instance ._config ),
873
+ }
874
+ )
875
+ .last_pk
876
+ )
877
+ instance_id = tool_result .instance .instance_id
837
878
db ["tool_results" ].insert (
838
879
{
839
880
"response_id" : response_id ,
840
881
"tool_id" : tool_ids_by_name .get (tool_result .name ) or None ,
841
882
"name" : tool_result .name ,
842
883
"output" : tool_result .output ,
843
884
"tool_call_id" : tool_result .tool_call_id ,
885
+ "instance_id" : instance_id ,
844
886
}
845
887
)
846
888
@@ -919,6 +961,7 @@ def execute_tool_calls(
919
961
name = tool_call .name ,
920
962
output = result ,
921
963
tool_call_id = tool_call .tool_call_id ,
964
+ instance = _get_instance (tool .implementation ),
922
965
)
923
966
924
967
if after_call :
@@ -1078,6 +1121,7 @@ async def run_async(tc=tc, tool=tool, idx=idx):
1078
1121
name = tc .name ,
1079
1122
output = output ,
1080
1123
tool_call_id = tc .tool_call_id ,
1124
+ instance = _get_instance (tool .implementation ),
1081
1125
)
1082
1126
1083
1127
# after_call inside the task
@@ -1111,6 +1155,7 @@ async def run_async(tc=tc, tool=tool, idx=idx):
1111
1155
name = tc .name ,
1112
1156
output = output ,
1113
1157
tool_call_id = tc .tool_call_id ,
1158
+ instance = _get_instance (tool .implementation ),
1114
1159
)
1115
1160
1116
1161
if after_call :
@@ -1844,3 +1889,9 @@ def _remove_titles_recursively(obj):
1844
1889
# Process each item in lists
1845
1890
for item in obj :
1846
1891
_remove_titles_recursively (item )
1892
+
1893
+
1894
+ def _get_instance (implementation ):
1895
+ if hasattr (implementation , "__self__" ):
1896
+ return implementation .__self__
1897
+ return None
0 commit comments