diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index 9c075a66f38f..62ad243e915b 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -90,15 +90,7 @@ def __init__( self.pending_actions: deque[Action] = deque() self.reset() - self.mock_function_calling = False - if not self.llm.is_function_calling_active(): - logger.info( - f'Function calling not enabled for model {self.llm.config.model}. ' - 'Mocking function calling via prompting.' - ) - self.mock_function_calling = True - - # Function calling mode + # Retrieve the enabled tools self.tools = codeact_function_calling.get_tools( codeact_enable_browsing=self.config.codeact_enable_browsing, codeact_enable_jupyter=self.config.codeact_enable_jupyter, @@ -311,10 +303,7 @@ def get_observation_message( and len(obs.set_of_marks) > 0 and self.config.enable_som_visual_browsing and self.llm.vision_is_active() - and ( - self.mock_function_calling - or self.llm.is_visual_browser_tool_active() - ) + and self.llm.is_visual_browser_tool_supported() ): text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n' message = Message( @@ -400,8 +389,6 @@ def step(self, state: State) -> Action: 'messages': self.llm.format_messages_for_llm(messages), } params['tools'] = self.tools - if self.mock_function_calling: - params['mock_function_calling'] = True response = self.llm.completion(**params) actions = codeact_function_calling.response_to_actions(response) for action in actions: diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index 34124a479d7f..9c82dd088691 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -197,7 +197,7 @@ def wrapper(*args, **kwargs): from openhands.core.utils import json messages: list[dict[str, Any]] | dict[str, Any] = [] - mock_function_calling = kwargs.pop('mock_function_calling', False) + mock_function_calling = not self.is_function_calling_active() # some callers might send the model and messages directly # litellm allows positional args, like completion(model, messages, **kwargs) @@ -216,18 +216,21 @@ def wrapper(*args, **kwargs): # ensure we work with a list of messages messages = messages if isinstance(messages, list) else [messages] + + # handle conversion of to non-function calling messages if needed original_fncall_messages = copy.deepcopy(messages) mock_fncall_tools = None - if mock_function_calling: - assert ( - 'tools' in kwargs - ), "'tools' must be in kwargs when mock_function_calling is True" + # if the agent or caller has defined tools, and we mock via prompting, convert the messages + if mock_function_calling and 'tools' in kwargs: messages = convert_fncall_messages_to_non_fncall_messages( messages, kwargs['tools'] ) kwargs['messages'] = messages + + # add stop words if the model supports it if self.config.model not in MODELS_WITHOUT_STOP_WORDS: kwargs['stop'] = STOP_WORDS + mock_fncall_tools = kwargs.pop('tools') # if we have no messages, something went very wrong @@ -256,9 +259,10 @@ def wrapper(*args, **kwargs): self.metrics.add_response_latency(latency, response_id) non_fncall_response = copy.deepcopy(resp) - if mock_function_calling: + + # if we mocked function calling, and we have tools, convert the response back to function calling format + if mock_function_calling and mock_fncall_tools is not None: assert len(resp.choices) == 1 - assert mock_fncall_tools is not None non_fncall_response_message = resp.choices[0].message fn_call_messages_with_response = ( convert_non_fncall_messages_to_fncall_messages( @@ -488,7 +492,7 @@ def is_function_calling_active(self) -> bool: """ return self._function_calling_active - def is_visual_browser_tool_active(self) -> bool: + def is_visual_browser_tool_supported(self) -> bool: return ( self.config.model in VISUAL_BROWSING_TOOL_SUPPORTED_MODELS or self.config.model.split('/')[-1] in VISUAL_BROWSING_TOOL_SUPPORTED_MODELS diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 39badebff046..58ce8d8329e2 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -464,16 +464,6 @@ def test_browser_tool(): assert 'description' in BrowserTool['function']['parameters']['properties']['code'] -def test_mock_function_calling(): - # Test mock function calling when LLM doesn't support it - llm = Mock() - llm.is_function_calling_active = lambda: False - config = AgentConfig() - config.enable_prompt_extensions = False - agent = CodeActAgent(llm=llm, config=config) - assert agent.mock_function_calling is True - - def test_response_to_actions_invalid_tool(): # Test response with invalid tool call mock_response = Mock()