Skip to content

Commit e4ecb86

Browse files
authored
Log tool_instances to database (#1098)
* Log tool_instances to database, closes #1089 * Tested for both sync and async models
1 parent c9e8593 commit e4ecb86

File tree

5 files changed

+376
-39
lines changed

5 files changed

+376
-39
lines changed

docs/logging.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,13 +397,14 @@ CREATE TABLE [tool_calls] (
397397
[arguments] TEXT,
398398
[tool_call_id] TEXT
399399
);
400-
CREATE TABLE [tool_results] (
400+
CREATE TABLE "tool_results" (
401401
[id] INTEGER PRIMARY KEY,
402402
[response_id] TEXT REFERENCES [responses]([id]),
403403
[tool_id] INTEGER REFERENCES [tools]([id]),
404404
[name] TEXT,
405405
[output] TEXT,
406-
[tool_call_id] TEXT
406+
[tool_call_id] TEXT,
407+
[instance_id] INTEGER REFERENCES [tool_instances]([id])
407408
);
408409
```
409410
<!-- [[[end]]] -->

llm/migrations.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -373,3 +373,20 @@ def m017_tools_tables(db):
373373
@migration
374374
def m017_tools_plugin(db):
375375
db["tools"].add_column("plugin")
376+
377+
378+
@migration
379+
def m018_tool_instances(db):
380+
# Used to track instances of Toolbox classes that may be
381+
# used multiple times by different tools
382+
db["tool_instances"].create(
383+
{
384+
"id": int,
385+
"plugin": str,
386+
"name": str,
387+
"arguments": str,
388+
},
389+
pk="id",
390+
)
391+
# We record which instance was used only on the results
392+
db["tool_results"].add_column("instance_id", fk="tool_instances")

llm/models.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,29 @@ def _get_arguments_input_schema(function, name):
175175
class Toolbox:
176176
_blocked = ("method_tools", "introspect_methods", "methods")
177177
name: Optional[str] = None
178+
instance_id: Optional[int] = None
179+
180+
def __init_subclass__(cls, **kwargs):
181+
super().__init_subclass__(**kwargs)
182+
183+
original_init = cls.__init__
184+
185+
def wrapped_init(self, *args, **kwargs):
186+
sig = inspect.signature(original_init)
187+
bound = sig.bind(self, *args, **kwargs)
188+
bound.apply_defaults()
189+
190+
self._config = {
191+
name: value
192+
for name, value in bound.arguments.items()
193+
if name != "self"
194+
and sig.parameters[name].kind
195+
not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
196+
}
197+
198+
original_init(self, *args, **kwargs)
199+
200+
cls.__init__ = wrapped_init
178201

179202
@classmethod
180203
def methods(cls):
@@ -197,10 +220,12 @@ def method_tools(self):
197220
method = getattr(self, method_name)
198221
# The attribute must be a bound method, i.e. inspect.ismethod()
199222
if callable(method) and inspect.ismethod(method):
200-
yield Tool.function(
223+
tool = Tool.function(
201224
method,
202225
name="{}_{}".format(self.__class__.__name__, method_name),
203226
)
227+
tool.plugin = getattr(self, "plugin", None)
228+
yield tool
204229

205230
@classmethod
206231
def introspect_methods(cls):
@@ -235,6 +260,7 @@ class ToolResult:
235260
name: str
236261
output: str
237262
tool_call_id: Optional[str] = None
263+
instance: Optional[Toolbox] = None
238264

239265

240266
class CancelToolCall(Exception):
@@ -834,13 +860,29 @@ def log_to_db(self, db):
834860
}
835861
)
836862
for tool_result in self.prompt.tool_results:
863+
instance_id = None
864+
if tool_result.instance:
865+
if not tool_result.instance.instance_id:
866+
tool_result.instance.instance_id = (
867+
db["tool_instances"]
868+
.insert(
869+
{
870+
"plugin": tool.plugin,
871+
"name": tool.name.split("_")[0],
872+
"arguments": json.dumps(tool_result.instance._config),
873+
}
874+
)
875+
.last_pk
876+
)
877+
instance_id = tool_result.instance.instance_id
837878
db["tool_results"].insert(
838879
{
839880
"response_id": response_id,
840881
"tool_id": tool_ids_by_name.get(tool_result.name) or None,
841882
"name": tool_result.name,
842883
"output": tool_result.output,
843884
"tool_call_id": tool_result.tool_call_id,
885+
"instance_id": instance_id,
844886
}
845887
)
846888

@@ -919,6 +961,7 @@ def execute_tool_calls(
919961
name=tool_call.name,
920962
output=result,
921963
tool_call_id=tool_call.tool_call_id,
964+
instance=_get_instance(tool.implementation),
922965
)
923966

924967
if after_call:
@@ -1078,6 +1121,7 @@ async def run_async(tc=tc, tool=tool, idx=idx):
10781121
name=tc.name,
10791122
output=output,
10801123
tool_call_id=tc.tool_call_id,
1124+
instance=_get_instance(tool.implementation),
10811125
)
10821126

10831127
# after_call inside the task
@@ -1111,6 +1155,7 @@ async def run_async(tc=tc, tool=tool, idx=idx):
11111155
name=tc.name,
11121156
output=output,
11131157
tool_call_id=tc.tool_call_id,
1158+
instance=_get_instance(tool.implementation),
11141159
)
11151160

11161161
if after_call:
@@ -1844,3 +1889,9 @@ def _remove_titles_recursively(obj):
18441889
# Process each item in lists
18451890
for item in obj:
18461891
_remove_titles_recursively(item)
1892+
1893+
1894+
def _get_instance(implementation):
1895+
if hasattr(implementation, "__self__"):
1896+
return implementation.__self__
1897+
return None

0 commit comments

Comments
 (0)