Skip to content

Dev branch for the ToolUseAgent #239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 30 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9ee2367
moving the browsergym.experiment.benchmark module to agentlab
TLSDC Apr 23, 2025
c2e2b9c
added comment for new parameter
TLSDC Apr 23, 2025
596fcd2
BaseMessages take into account 'input_text' key too (for xray)
TLSDC Apr 23, 2025
f9d7b91
convenient array to base64 function
TLSDC Apr 23, 2025
73ba428
tool agent embryo
TLSDC Apr 23, 2025
c11db49
Merge branch 'main' of github.com:ServiceNow/AgentLab into tlsdc/tool…
TLSDC Apr 24, 2025
6604dbc
added the MessageBuilder class, which should help interfacing APIs
TLSDC Apr 24, 2025
ef6f648
claude
TLSDC Apr 30, 2025
4e973ac
adding markdown display for MessageBuilder in xray
TLSDC May 1, 2025
54ec412
changed LLM structure to be more versatile
TLSDC May 1, 2025
0fc43cc
unified claude and openai response apis
TLSDC May 2, 2025
19cdaf9
i dont think this is relevant anymore
TLSDC May 2, 2025
5b3f469
backtracking from moving bgym.benchmarks etc
TLSDC May 2, 2025
087ad75
defaulting to claude bc it's better
TLSDC May 2, 2025
8a17470
kind of forced to comment this to avoid circular imports atm
TLSDC May 2, 2025
5f675ba
Merge branch 'main' of github.com:ServiceNow/AgentLab into tlsdc/tool…
TLSDC May 2, 2025
234be09
parametrized env output to agent_args
TLSDC May 2, 2025
544908e
fixing broken import in test
TLSDC May 2, 2025
16cc3cd
Add pricing tracking for Anthropic model and refactor pricing functions
recursix May 8, 2025
c674094
Enhance ToolUseAgent with token counting and improved message handlin…
recursix May 9, 2025
528b513
Update action in ClaudeResponseModel to None for improved clarity
recursix May 9, 2025
417893c
typo
recursix May 13, 2025
c676eab
typo
recursix May 13, 2025
ab2d331
Remove unnecessary import of anthropic for cleaner code
recursix May 13, 2025
bf57591
moving some utils to agent_utils.py
amanjaiswal73892 May 14, 2025
ce72b41
Fix: Formatting and Darglint
amanjaiswal73892 May 15, 2025
fe05d75
Refactor: Simplify message builder methods and add support for chat c…
amanjaiswal73892 May 15, 2025
97a39cc
added vllm-support-for-tool-use-agent
amanjaiswal73892 May 17, 2025
7d8a08c
Moving some functions to llm utils.py
amanjaiswal73892 May 21, 2025
ffd5c5e
Merge pull request #248 from ServiceNow/aj/tool_use_agent_chat_comple…
amanjaiswal73892 May 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/agentlab/agents/agent_args.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import bgym
from bgym import AbstractAgentArgs

from agentlab.experiments.benchmark import Benchmark


class AgentArgs(AbstractAgentArgs):
"""Base class for agent arguments for instantiating an agent.
Expand All @@ -14,7 +16,7 @@ class MyAgentArgs(AgentArgs):
Note: for working properly with AgentXRay, the arguments need to be serializable and hasable.
"""

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode: bool):
def set_benchmark(self, benchmark: Benchmark, demo_mode: bool):
"""Optional method to set benchmark specific flags.

This allows the agent to have minor adjustments based on the benchmark.
Expand Down
10 changes: 3 additions & 7 deletions src/agentlab/agents/dynamic_prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,9 @@

import bgym
from browsergym.core.action.base import AbstractActionSet
from browsergym.utils.obs import (
flatten_axtree_to_str,
flatten_dom_to_str,
overlay_som,
prune_html,
)
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, overlay_som, prune_html

from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_utils import (
BaseMessage,
ParseError,
Expand Down Expand Up @@ -99,7 +95,7 @@ class ObsFlags(Flags):

@dataclass
class ActionFlags(Flags):
action_set: bgym.HighLevelActionSetArgs = None # should be set by the set_benchmark method
action_set: HighLevelActionSetArgs = None # should be set by the set_benchmark method
long_description: bool = True
individual_examples: bool = False

Expand Down
13 changes: 7 additions & 6 deletions src/agentlab/agents/generic_agent/agent_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from agentlab.agents import dynamic_prompting as dp
from agentlab.experiments import args
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .generic_agent import GenericAgentArgs
Expand All @@ -31,7 +32,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
Expand Down Expand Up @@ -79,7 +80,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
Expand Down Expand Up @@ -126,7 +127,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
Expand Down Expand Up @@ -176,7 +177,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=True,
),
Expand Down Expand Up @@ -231,7 +232,7 @@
filter_visible_elements_only=False,
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
),
Expand Down Expand Up @@ -319,7 +320,7 @@
filter_visible_elements_only=args.Choice([True, False], p=[0.3, 0.7]),
),
action=dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(
action_set=HighLevelActionSetArgs(
subsets=args.Choice([["bid"], ["bid", "coord"]]),
multiaction=args.Choice([True, False], p=[0.7, 0.3]),
),
Expand Down
5 changes: 3 additions & 2 deletions src/agentlab/agents/generic_agent/generic_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,20 @@

from copy import deepcopy
from dataclasses import asdict, dataclass
from functools import partial
from warnings import warn

import bgym
from browsergym.experiments.agent import Agent, AgentInfo

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import Benchmark
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .generic_agent_prompt import GenericPromptFlags, MainPrompt
from functools import partial


@dataclass
Expand All @@ -37,7 +38,7 @@ def __post_init__(self):
except AttributeError:
pass

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
def set_benchmark(self, benchmark: Benchmark, demo_mode):
"""Override Some flags based on the benchmark."""
if benchmark.name.startswith("miniwob"):
self.flags.obs.use_html = True
Expand Down
3 changes: 2 additions & 1 deletion src/agentlab/agents/generic_agent/reproducibility_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from bs4 import BeautifulSoup

from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.experiments.loop import ExpArgs, ExpResult, yield_all_exp_results
from agentlab.experiments.study import Study
from agentlab.llm.chat_api import make_assistant_message
Expand Down Expand Up @@ -144,7 +145,7 @@ def _make_backward_compatible(agent_args: GenericAgentArgs):
if isinstance(action_set, str):
action_set = action_set.split("+")

agent_args.flags.action.action_set = bgym.HighLevelActionSetArgs(
agent_args.flags.action.action_set = HighLevelActionSetArgs(
subsets=action_set,
multiaction=agent_args.flags.action.multi_actions,
)
Expand Down
Empty file.
184 changes: 184 additions & 0 deletions src/agentlab/agents/tool_use_agent/agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
import json
import logging
from copy import deepcopy as copy
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any

import bgym
from browsergym.core.observation import extract_screenshot

from agentlab.agents.agent_args import AgentArgs
from agentlab.llm.llm_utils import image_to_png_base64_url
from agentlab.llm.response_api import OpenAIResponseModelArgs
from agentlab.llm.tracking import cost_tracker_decorator

if TYPE_CHECKING:
from openai.types.responses import Response


@dataclass
class ToolUseAgentArgs(AgentArgs):
    """Serializable arguments for building a :class:`ToolUseAgent`.

    The ``agent_name`` is derived from ``model_args.model_name`` when one is
    provided, so experiment results are labeled by model.
    """

    # Sampling temperature forwarded to the LLM on every call.
    temperature: float = 0.1
    # Model configuration; must be set before make_agent() is called.
    model_args: OpenAIResponseModelArgs = None

    def __post_init__(self):
        # Derive a readable agent name from the model name; silently keep the
        # default name when model_args is absent (e.g. during deserialization).
        try:
            self.agent_name = "ToolUse-{}".format(self.model_args.model_name).replace("/", "_")
        except AttributeError:
            pass

    def make_agent(self) -> bgym.Agent:
        """Instantiate the agent described by these arguments."""
        return ToolUseAgent(temperature=self.temperature, model_args=self.model_args)

    def set_reproducibility_mode(self):
        """Make sampling deterministic by zeroing the temperature."""
        self.temperature = 0

    def prepare(self):
        """Start any backend server required by the model."""
        return self.model_args.prepare_server()

    def close(self):
        """Tear down any backend server started by :meth:`prepare`."""
        return self.model_args.close_server()


class ToolUseAgent(bgym.Agent):
    """Agent that acts on a web page through LLM function (tool) calls.

    The whole conversation is accumulated in ``self.messages`` (Responses-API
    message format) and replayed to the model on every step.  The first
    ``function_call`` item of each response is translated into a python-like
    action string for the action set; ``reasoning`` items are surfaced as the
    agent's "think" trace.
    """

    def __init__(
        self,
        temperature: float,
        model_args: OpenAIResponseModelArgs,
    ):
        self.temperature = temperature
        self.chat = model_args.make_model()
        self.model_args = model_args

        self.action_set = bgym.HighLevelActionSet(["coord"], multiaction=False)
        self.tools = self.action_set.to_tool_description()
        # TODO(review): a dedicated "chain_of_thought" tool (forcing explicit
        # reasoning before every action) was prototyped here and may be worth
        # re-adding once the basic loop is stable.

        self.llm = model_args.make_model(extra_kwargs={"tools": self.tools})

        self.messages = []  # running conversation, replayed on every call
        # call_id of the last emitted function call; None until the model has
        # produced one (prevents AttributeError on the second step when the
        # first response contained no function_call).
        self.previous_call_id = None

    def obs_preprocessor(self, obs):
        """Replace the live page in ``obs`` with a fresh screenshot.

        Raises:
            ValueError: if the observation carries no page to screenshot.
        """
        page = obs.pop("page", None)
        if page is None:
            raise ValueError("No page found in the observation.")
        obs["screenshot"] = extract_screenshot(page)
        return obs

    @cost_tracker_decorator
    def get_action(self, obs: Any) -> tuple[str, dict]:
        """Advance the conversation one step and return the next action.

        Args:
            obs: Pre-processed observation; must contain "screenshot",
                "goal_object" and "last_action_error".

        Returns:
            A pair ``(action, agent_info)`` where ``action`` is a python-like
            call string for the action set ("noop()" when the model produced
            no function call).
        """
        if not self.messages:
            self._start_conversation(obs)
        else:
            self._append_step_feedback(obs)

        response: "Response" = self.llm(
            messages=self.messages,
            temperature=self.temperature,
        )

        action = "noop()"
        think = ""
        for output in response.output:
            if output.type == "function_call":
                arguments = json.loads(output.arguments)
                # repr() the values so string arguments come out quoted and
                # the call string stays parsable as python-like code.
                arg_str = ", ".join(f"{k}={v!r}" for k, v in arguments.items())
                action = f"{output.name}({arg_str})"
                self.previous_call_id = output.call_id
                self.messages.append(output)
                break
            elif output.type == "reasoning":
                if len(output.summary) > 0:
                    think += output.summary[0].text + "\n"
                self.messages.append(output)

        return (
            action,
            bgym.AgentInfo(
                think=think,
                chat_messages=[],
                stats={},
            ),
        )

    def _start_conversation(self, obs: Any) -> None:
        """Seed ``self.messages`` with the system prompt, goal and screenshot."""
        system_message = {
            "role": "system",
            "content": "You are an agent. Based on the observation, you will decide which action to take to accomplish your goal.",
        }
        # Deep-copy the goal content so retyping the items below does not
        # mutate the caller's observation in place.
        goal_object = copy(obs["goal_object"])
        for content in goal_object:
            # Rename content types to the Responses-API input variants.
            if content["type"] == "text":
                content["type"] = "input_text"
            elif content["type"] == "image_url":
                content["type"] = "input_image"
        goal_message = {"role": "user", "content": goal_object}
        goal_message["content"].append(
            {
                "type": "input_image",
                "image_url": image_to_png_base64_url(obs["screenshot"]),
            }
        )
        self.messages.append(system_message)
        self.messages.append(goal_message)

    def _append_step_feedback(self, obs: Any) -> None:
        """Report the outcome of the previous function call to the model."""
        if obs["last_action_error"] == "":
            if self.previous_call_id is not None:
                self.messages.append(
                    {
                        "type": "function_call_output",
                        "call_id": self.previous_call_id,
                        "output": "Function call executed, see next observation.",
                    }
                )
            self.messages.append(
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "input_image",
                            "image_url": image_to_png_base64_url(obs["screenshot"]),
                        }
                    ],
                }
            )
        else:
            if self.previous_call_id is not None:
                self.messages.append(
                    {
                        "type": "function_call_output",
                        "call_id": self.previous_call_id,
                        "output": f"Function call failed: {obs['last_action_error']}",
                    }
                )


# Default model configuration: OpenAI o4-mini through the Responses API,
# with vision enabled so screenshots can be sent as input images.
MODEL_CONFIG = OpenAIResponseModelArgs(
    model_name="o4-mini-2025-04-16",
    max_total_tokens=200_000,
    max_input_tokens=200_000,
    max_new_tokens=100_000,
    temperature=0.1,
    vision_support=True,
)

# Ready-to-use agent arguments built on MODEL_CONFIG.
AGENT_CONFIG = ToolUseAgentArgs(
    temperature=0.1,
    model_args=MODEL_CONFIG,
)
8 changes: 5 additions & 3 deletions src/agentlab/agents/visual_agent/agent_configs.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import bgym

import agentlab.agents.dynamic_prompting as dp
from agentlab.experiments.benchmark import HighLevelActionSetArgs
from agentlab.llm.llm_configs import CHAT_MODEL_ARGS_DICT

from .visual_agent import VisualAgentArgs
from .visual_agent_prompts import PromptFlags
import agentlab.agents.dynamic_prompting as dp
import bgym

# the other flags are ignored for this agent.
DEFAULT_OBS_FLAGS = dp.ObsFlags(
Expand All @@ -16,7 +18,7 @@
)

DEFAULT_ACTION_FLAGS = dp.ActionFlags(
action_set=bgym.HighLevelActionSetArgs(subsets=["coord"]),
action_set=HighLevelActionSetArgs(subsets=["coord"]),
long_description=True,
individual_examples=False,
)
Expand Down
5 changes: 3 additions & 2 deletions src/agentlab/agents/visual_agent/visual_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

from agentlab.agents import dynamic_prompting as dp
from agentlab.agents.agent_args import AgentArgs
from agentlab.experiments.benchmark import Benchmark
from agentlab.llm.chat_api import BaseModelArgs
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, retry
from agentlab.llm.tracking import cost_tracker_decorator

from .visual_agent_prompts import PromptFlags, MainPrompt
from .visual_agent_prompts import MainPrompt, PromptFlags


@dataclass
Expand All @@ -34,7 +35,7 @@ def __post_init__(self):
except AttributeError:
pass

def set_benchmark(self, benchmark: bgym.Benchmark, demo_mode):
def set_benchmark(self, benchmark: Benchmark, demo_mode):
"""Override Some flags based on the benchmark."""
self.flags.obs.use_tabs = benchmark.is_multi_tab

Expand Down
2 changes: 2 additions & 0 deletions src/agentlab/experiments/benchmark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .base import Benchmark, HighLevelActionSetArgs
from .configs import DEFAULT_BENCHMARKS
Loading
Loading