Skip to content

Commit dfa5467

Browse files
ryanhoangt, openhands-agent, enyst
authored
[OH-Versa] Add remaining browsing & GAIA eval improvement (#9015)
Co-authored-by: openhands <[email protected]> Co-authored-by: Engel Nyst <[email protected]>
1 parent 76914e3 commit dfa5467

File tree

16 files changed

+384
-30
lines changed

16 files changed

+384
-30
lines changed

evaluation/benchmarks/gaia/run_infer.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,20 @@
33
import functools
44
import os
55
import re
6+
import shutil
7+
import zipfile
68

79
import huggingface_hub
810
import pandas as pd
911
from datasets import load_dataset
12+
from PIL import Image
1013
from pydantic import SecretStr
1114

1215
from evaluation.benchmarks.gaia.scorer import question_scorer
16+
from evaluation.benchmarks.gaia.utils import (
17+
image_to_jpg_base64_url,
18+
image_to_png_base64_url,
19+
)
1320
from evaluation.utils.shared import (
1421
EvalMetadata,
1522
EvalOutput,
@@ -97,27 +104,44 @@ def initialize_runtime(
97104
if instance['file_name'] != '':
98105
# if this question comes with a file, we need to save it to the workspace
99106
assert metadata.data_split is not None
107+
extension_name = instance['file_name'].split('.')[-1]
100108
src_file = os.path.join(
101109
DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
102110
)
103111
assert os.path.exists(src_file)
104-
dest_file = os.path.join('/workspace', instance['file_name'])
105-
runtime.copy_to(src_file, dest_file)
106-
107-
# rename to file.extension_name
108-
extension_name = instance['file_name'].split('.')[-1]
109-
action = CmdRunAction(
110-
command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
111-
)
112-
logger.info(action, extra={'msg_type': 'ACTION'})
113-
obs = runtime.run_action(action)
114-
assert obs.exit_code == 0
112+
if extension_name == 'zip':
113+
temp_dir = os.path.join(
114+
DATASET_CACHE_DIR, '2023', metadata.data_split, 'tmp_file'
115+
)
116+
os.makedirs(temp_dir, exist_ok=True)
117+
with zipfile.ZipFile(src_file, 'r') as zip_ref:
118+
zip_ref.extractall(temp_dir)
119+
for root, dirs, files in os.walk(temp_dir):
120+
for file in files:
121+
dest_file = '/workspace'
122+
runtime.copy_to(os.path.join(root, file), dest_file)
123+
shutil.rmtree(temp_dir)
124+
elif extension_name not in ['jpg', 'png']:
125+
dest_file = '/workspace'
126+
runtime.copy_to(src_file, dest_file)
127+
128+
# rename to file.extension_name
129+
action = CmdRunAction(
130+
command=f'mv /workspace/{instance["file_name"]} /workspace/file.{extension_name}'
131+
)
132+
logger.info(action, extra={'msg_type': 'ACTION'})
133+
obs = runtime.run_action(action)
134+
assert obs.exit_code == 0
115135

116136
action = CmdRunAction(command='cd /workspace')
117137
logger.info(action, extra={'msg_type': 'ACTION'})
118138
obs = runtime.run_action(action)
119139
assert obs.exit_code == 0
120140

141+
action = CmdRunAction(
142+
command='apt-get update && apt-get install -y ffmpeg && apt-get install -y ffprobe'
143+
)
144+
runtime.run_action(action)
121145
logger.info(f'{"-" * 50} END Runtime Initialization Fn {"-" * 50}')
122146

123147

@@ -151,8 +175,31 @@ def process_instance(
151175
task_question=instance['Question'],
152176
)
153177
logger.info(f'Instruction: {instruction}')
178+
image_urls = []
154179
if dest_file:
155-
instruction += f'\n\nThe mentioned file is provided in the workspace at: {dest_file.split("/")[-1]}'
180+
if extension_name not in ['jpg', 'png', 'zip']:
181+
instruction += f'To solve this task you will have to use the attached file provided in the workspace at location: {dest_file}\n\n'
182+
elif extension_name == 'zip':
183+
filenames = []
184+
src_file = os.path.join(
185+
DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
186+
)
187+
with zipfile.ZipFile(src_file, 'r') as zip_ref:
188+
filenames = zip_ref.namelist()
189+
190+
filenames = [f'/workspace/{file}' for file in filenames]
191+
filenames = ', '.join(filenames)
192+
instruction += f'To solve this task you will have to use the attached files provided in the workspace at locations: {filenames}\n\n'
193+
else: # Image files: jpg, png
194+
src_file = os.path.join(
195+
DATASET_CACHE_DIR, '2023', metadata.data_split, instance['file_name']
196+
)
197+
instruction += 'Image: To solve this task you will have to use the image shown below.\n\n'
198+
image = Image.open(src_file)
199+
if extension_name == 'jpg':
200+
image_urls.append(image_to_jpg_base64_url(image))
201+
else:
202+
image_urls.append(image_to_png_base64_url(image))
156203

157204
instruction += """IMPORTANT: When seeking information from a website, REFRAIN from arbitrary URL navigation. You should utilize the designated search engine tool with precise keywords to obtain relevant URLs or use the specific website's search interface. DO NOT navigate directly to specific URLs as they may not exist.\n\nFor example: if you want to search for a research paper on Arxiv, either use the search engine tool with specific keywords or navigate to arxiv.org and then use its interface.\n"""
158205
instruction += 'IMPORTANT: You should NEVER ask for Human Help.\n'
@@ -174,7 +221,9 @@ def process_instance(
174221
state: State | None = asyncio.run(
175222
run_controller(
176223
config=config,
177-
initial_user_action=MessageAction(content=instruction),
224+
initial_user_action=MessageAction(
225+
content=instruction, image_urls=image_urls
226+
),
178227
runtime=runtime,
179228
fake_user_response_fn=AGENT_CLS_TO_FAKE_USER_RESPONSE_FN[
180229
metadata.agent_class

evaluation/benchmarks/gaia/utils.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import base64
2+
import io
3+
4+
import numpy as np
5+
from PIL import Image
6+
7+
8+
def image_to_png_base64_url(
    image: np.ndarray | Image.Image, add_data_prefix: bool = True
):
    """Encode an image as base64 PNG data, optionally wrapped as a data URL.

    Args:
        image: Either a numpy array (turned into a PIL image through
            ``Image.fromarray``) or an existing PIL image.
        add_data_prefix: If True, the result is prefixed with
            ``data:image/png;base64,``; if False, only the base64 payload
            is returned.

    Returns:
        The base64 text of the PNG-encoded image, with or without the prefix.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image

    # Images carrying an alpha channel are flattened to plain RGB before encoding.
    if pil_image.mode == 'RGBA' or pil_image.mode == 'LA':
        pil_image = pil_image.convert('RGB')

    png_buffer = io.BytesIO()
    pil_image.save(png_buffer, format='PNG')
    encoded = base64.b64encode(png_buffer.getvalue()).decode()

    if not add_data_prefix:
        return encoded
    return f'data:image/png;base64,{encoded}'
25+
26+
27+
def image_to_jpg_base64_url(
    image: np.ndarray | Image.Image, add_data_prefix: bool = True
) -> str:
    """Encode an image as base64 JPEG data, optionally wrapped as a data URL.

    Args:
        image: Either a numpy array (turned into a PIL image through
            ``Image.fromarray``) or an existing PIL image.
        add_data_prefix: If True, the result is prefixed with
            ``data:image/jpeg;base64,``; if False, only the base64 payload
            is returned.

    Returns:
        The base64 text of the JPEG-encoded image, with or without the prefix.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # Pillow's JPEG writer raises OSError for modes it cannot encode. Besides
    # the alpha modes (RGBA, LA) this includes palette ('P'), 'PA', 'I', 'F',
    # etc., so convert everything outside the writable set to RGB instead of
    # special-casing only the alpha modes.
    if image.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
        image = image.convert('RGB')

    buffered = io.BytesIO()
    image.save(buffered, format='JPEG')

    image_base64 = base64.b64encode(buffered.getvalue()).decode()
    if add_data_prefix:
        return f'data:image/jpeg;base64,{image_base64}'
    return image_base64

evaluation/utils/shared.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,15 @@ def codeact_user_response(
109109
) -> str:
110110
encaps_str = (
111111
(
112-
'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
112+
'Your final answer MUST be encapsulated within <solution> and </solution>.\n'
113113
'For example: The answer to the question is <solution> 42 </solution>.\n'
114114
)
115115
if encapsulate_solution
116116
else ''
117117
)
118118
msg = (
119119
'Please continue working on the task on whatever approach you think is suitable.\n'
120-
'If you think you have solved the task, please first send your answer to user through message and then finish the interaction.\n'
120+
'When you think you have solved the question, please use the finish tool and include your final answer in the message parameter of the finish tool.\n'
121121
f'{encaps_str}'
122122
'IMPORTANT: YOU SHOULD NEVER ASK FOR HUMAN HELP.\n'
123123
)

openhands/core/schema/observation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,6 @@ class ObservationType(str, Enum):
5252

5353
MCP = 'mcp'
5454
"""Result of a MCP Server operation"""
55+
56+
DOWNLOAD = 'download'
57+
"""Result of downloading/opening a file via the browser"""

openhands/events/observation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
NullObservation,
1717
)
1818
from openhands.events.observation.error import ErrorObservation
19+
from openhands.events.observation.file_download import FileDownloadObservation
1920
from openhands.events.observation.files import (
2021
FileEditObservation,
2122
FileReadObservation,
@@ -46,4 +47,5 @@
4647
'RecallObservation',
4748
'RecallType',
4849
'MCPObservation',
50+
'FileDownloadObservation',
4951
]

openhands/events/observation/browse.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class BrowserOutputObservation(Observation):
3232
last_browser_action: str = ''
3333
last_browser_action_error: str = ''
3434
focused_element_bid: str = ''
35+
filter_visible_only: bool = False
3536

3637
@property
3738
def message(self) -> str:
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from dataclasses import dataclass
2+
3+
from openhands.core.schema import ObservationType
4+
from openhands.events.observation.observation import Observation
5+
6+
7+
@dataclass
class FileDownloadObservation(Observation):
    """Observation reporting that a file was downloaded.

    Carries the location of the downloaded file in ``file_path`` and is
    tagged with ``ObservationType.DOWNLOAD``.
    """

    # Location of the downloaded file, as reported in the messages below.
    file_path: str
    observation: str = ObservationType.DOWNLOAD

    @property
    def message(self) -> str:
        """Return a one-line summary naming the downloaded file's location."""
        return f'Downloaded the file at location: {self.file_path}'

    def __str__(self) -> str:
        """Return a two-line representation: a header plus the file location."""
        header = '**FileDownloadObservation**\n'
        location = f'Location of downloaded file: {self.file_path}\n'
        return header + location

openhands/events/serialization/observation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
NullObservation,
2121
)
2222
from openhands.events.observation.error import ErrorObservation
23+
from openhands.events.observation.file_download import FileDownloadObservation
2324
from openhands.events.observation.files import (
2425
FileEditObservation,
2526
FileReadObservation,
@@ -47,6 +48,7 @@
4748
AgentThinkObservation,
4849
RecallObservation,
4950
MCPObservation,
51+
FileDownloadObservation,
5052
)
5153

5254
OBSERVATION_TYPE_TO_CLASS = {

openhands/memory/conversation_memory.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
AgentThinkObservation,
2929
BrowserOutputObservation,
3030
CmdOutputObservation,
31+
FileDownloadObservation,
3132
FileEditObservation,
3233
FileReadObservation,
3334
IPythonRunCellObservation,
@@ -288,7 +289,12 @@ def _process_action(
288289
role = 'user' if action.source == 'user' else 'assistant'
289290
content = [TextContent(text=action.content or '')]
290291
if vision_is_active and action.image_urls:
291-
content.append(ImageContent(image_urls=action.image_urls))
292+
if role == 'user':
293+
for idx, url in enumerate(action.image_urls):
294+
content.append(TextContent(text=f'Image {idx + 1}:'))
295+
content.append(ImageContent(image_urls=[url]))
296+
else:
297+
content.append(ImageContent(image_urls=action.image_urls))
292298
if role not in ('user', 'system', 'assistant', 'tool'):
293299
raise ValueError(f'Invalid role: {role}')
294300
return [
@@ -339,6 +345,7 @@ def _process_observation(
339345
- AgentDelegateObservation: Formats results from delegated agent tasks
340346
- ErrorObservation: Formats error messages from failed actions
341347
- UserRejectObservation: Formats user rejection messages
348+
- FileDownloadObservation: Formats the result of a browsing action that opened/downloaded a file
342349
343350
In function calling mode, observations with tool_call_metadata are stored in
344351
tool_call_id_to_message for later processing instead of being returned immediately.
@@ -429,7 +436,7 @@ def _process_observation(
429436
and enable_som_visual_browsing
430437
and vision_is_active
431438
):
432-
text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. You may need to scroll to view the remaining portion of the web-page.)\n'
439+
text += 'Image: Current webpage screenshot (Note that only visible portion of webpage is present in the screenshot. However, the Accessibility tree contains information from the entire webpage.)\n'
433440

434441
# Determine which image to use and validate it
435442
image_url = None
@@ -492,6 +499,9 @@ def _process_observation(
492499
elif isinstance(obs, AgentCondensationObservation):
493500
text = truncate_content(obs.content, max_message_chars)
494501
message = Message(role='user', content=[TextContent(text=text)])
502+
elif isinstance(obs, FileDownloadObservation):
503+
text = truncate_content(obs.content, max_message_chars)
504+
message = Message(role='user', content=[TextContent(text=text)])
495505
elif (
496506
isinstance(obs, RecallObservation)
497507
and self.agent_config.enable_prompt_extensions

openhands/runtime/action_execution_server.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pathlib import Path
2121
from zipfile import ZipFile
2222

23+
import puremagic
2324
from binaryornot.check import is_binary
2425
from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile
2526
from fastapi.exceptions import RequestValidationError
@@ -51,6 +52,7 @@
5152
from openhands.events.observation import (
5253
CmdOutputObservation,
5354
ErrorObservation,
55+
FileDownloadObservation,
5456
FileEditObservation,
5557
FileReadObservation,
5658
FileWriteObservation,
@@ -193,6 +195,8 @@ def __init__(
193195
self.start_time = time.time()
194196
self.last_execution_time = self.start_time
195197
self._initialized = False
198+
self.downloaded_files: list[str] = []
199+
self.downloads_directory = '/workspace/.downloads'
196200

197201
self.max_memory_gb: int | None = None
198202
if _override_max_memory_gb := os.environ.get('RUNTIME_MAX_MEMORY_GB', None):
@@ -603,7 +607,45 @@ async def browse_interactive(self, action: BrowseInteractiveAction) -> Observati
603607
'Browser functionality is not supported on Windows.'
604608
)
605609
await self._ensure_browser_ready()
606-
return await browse(action, self.browser, self.initial_cwd)
610+
browser_observation = await browse(action, self.browser, self.initial_cwd)
611+
if not browser_observation.error:
612+
return browser_observation
613+
else:
614+
curr_files = os.listdir(self.downloads_directory)
615+
new_download = False
616+
for file in curr_files:
617+
if file not in self.downloaded_files:
618+
new_download = True
619+
self.downloaded_files.append(file)
620+
break # FIXME: assuming only one file will be downloaded for simplicity
621+
622+
if not new_download:
623+
return browser_observation
624+
else:
625+
# A new file is downloaded in self.downloads_directory, shift file to /workspace
626+
src_path = os.path.join(
627+
self.downloads_directory, self.downloaded_files[-1]
628+
)
629+
# Guess extension of file using puremagic and add it to tgt_path file name
630+
file_ext = ''
631+
try:
632+
guesses = puremagic.magic_file(src_path)
633+
if len(guesses) > 0:
634+
ext = guesses[0].extension.strip()
635+
if len(ext) > 0:
636+
file_ext = ext
637+
except Exception as _:
638+
pass
639+
640+
tgt_path = os.path.join(
641+
'/workspace', f'file_{len(self.downloaded_files)}{file_ext}'
642+
)
643+
shutil.copy(src_path, tgt_path)
644+
file_download_obs = FileDownloadObservation(
645+
content=f'Execution of the previous action {action.browser_actions} resulted in a file download. The downloaded file is saved at location: {tgt_path}',
646+
file_path=tgt_path,
647+
)
648+
return file_download_obs
607649

608650
def close(self):
609651
self.memory_monitor.stop_monitoring()

openhands/runtime/browser/browser_env.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def browser_process(self) -> None:
9494
headless=True,
9595
disable_env_checker=True,
9696
tags_to_mark='all',
97+
timeout=100000,
98+
pw_context_kwargs={'accept_downloads': True},
99+
pw_chromium_kwargs={'downloads_path': '/workspace/.downloads/'},
97100
)
98101
obs, info = env.reset()
99102

@@ -105,6 +108,7 @@ def browser_process(self) -> None:
105108
if self.eval_mode:
106109
self.eval_goal = obs['goal']
107110
if 'goal_object' in obs:
111+
obs['goal_object'] = list(obs['goal_object'])
108112
if len(obs['goal_object']) > 0:
109113
self.eval_goal = obs['goal_object'][0]['text']
110114
for message in obs['goal_object']:
@@ -182,7 +186,7 @@ def browser_process(self) -> None:
182186
pass
183187
return
184188

185-
def step(self, action_str: str, timeout: float = 100) -> dict:
189+
def step(self, action_str: str, timeout: float = 120) -> dict:
186190
"""Execute an action in the browser environment and return the observation."""
187191
unique_request_id = str(uuid.uuid4())
188192
self.agent_side.send((unique_request_id, {'action': action_str}))

0 commit comments

Comments
 (0)