Skip to content

Commit ddaa186

Browse files
authored
[GAIA] Add prompt improvement to alleviate solution parsing issue & support Tavily search tools (OpenHands#9057)
1 parent e6e0f46 commit ddaa186

File tree

11 files changed

+62
-32
lines changed

11 files changed

+62
-32
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
data/

evaluation/benchmarks/gaia/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,13 @@ This folder contains evaluation harness for evaluating agents on the [GAIA bench
66

77
Please follow instruction [here](../../README.md#setup) to setup your local development environment and LLM.
88

9+
To enable the Tavily MCP Server, you can add the Tavily API key under the `core` section of your `config.toml` file, like below:
10+
11+
```toml
12+
[core]
13+
search_api_key = "tvly-******"
14+
```
15+
916
## Run the evaluation
1017

1118
We are using the GAIA dataset hosted on [Hugging Face](https://huggingface.co/datasets/gaia-benchmark/GAIA).

evaluation/benchmarks/gaia/run_infer.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import asyncio
2+
import copy
23
import functools
34
import os
45
import re
56

67
import huggingface_hub
78
import pandas as pd
89
from datasets import load_dataset
10+
from pydantic import SecretStr
911

1012
from evaluation.benchmarks.gaia.scorer import question_scorer
1113
from evaluation.utils.shared import (
@@ -24,6 +26,7 @@
2426
OpenHandsConfig,
2527
get_llm_config_arg,
2628
get_parser,
29+
load_from_toml,
2730
)
2831
from openhands.core.config.utils import get_agent_config_arg
2932
from openhands.core.logger import openhands_logger as logger
@@ -41,15 +44,15 @@
4144
}
4245

4346
AGENT_CLS_TO_INST_SUFFIX = {
44-
'CodeActAgent': 'When you think you have solved the question, please first send your answer to user through message and then exit.\n'
47+
'CodeActAgent': 'When you think you have solved the question, please use the finish tool and include your final answer in the message parameter of the finish tool. Your final answer MUST be encapsulated within <solution> and </solution>.\n'
4548
}
4649

4750

4851
def get_config(
4952
metadata: EvalMetadata,
5053
) -> OpenHandsConfig:
5154
sandbox_config = get_default_sandbox_config_for_eval()
52-
sandbox_config.base_container_image = 'python:3.12-bookworm'
55+
sandbox_config.base_container_image = 'nikolaik/python-nodejs:python3.12-nodejs22'
5356
config = OpenHandsConfig(
5457
default_agent=metadata.agent_class,
5558
run_as_openhands=False,
@@ -67,6 +70,11 @@ def get_config(
6770
logger.info('Agent config not provided, using default settings')
6871
agent_config = config.get_agent_config(metadata.agent_class)
6972
agent_config.enable_prompt_extensions = False
73+
74+
config_copy = copy.deepcopy(config)
75+
load_from_toml(config_copy)
76+
if config_copy.search_api_key:
77+
config.search_api_key = SecretStr(config_copy.search_api_key)
7078
return config
7179

7280

@@ -134,16 +142,26 @@ def process_instance(
134142
dest_file = None
135143

136144
# Prepare instruction
137-
instruction = f'{instance["Question"]}\n'
145+
instruction = """You have one question to answer. It is paramount that you provide a correct answer.
146+
Give it all you can: I know for a fact that you have access to all the relevant tools to solve it and find the correct answer (the answer does exist). Failure or 'I cannot answer' or 'None found' will not be tolerated, success will be rewarded.
147+
You must make sure you find the correct answer! You MUST strictly follow the task-specific formatting instructions for your final answer.
148+
Here is the task:
149+
{task_question}
150+
""".format(
151+
task_question=instance['Question'],
152+
)
138153
logger.info(f'Instruction: {instruction}')
139154
if dest_file:
140155
instruction += f'\n\nThe mentioned file is provided in the workspace at: {dest_file.split("/")[-1]}'
141156

142-
instruction += 'IMPORTANT: You should ONLY interact with the environment provided to you AND NEVER ASK FOR HUMAN HELP.\n'
143-
instruction += 'Please encapsulate your final answer (answer ONLY) within <solution> and </solution>.\n'
157+
instruction += """IMPORTANT: When seeking information from a website, REFRAIN from arbitrary URL navigation. You should utilize the designated search engine tool with precise keywords to obtain relevant URLs or use the specific website's search interface. DO NOT navigate directly to specific URLs as they may not exist.\n\nFor example: if you want to search for a research paper on Arxiv, either use the search engine tool with specific keywords or navigate to arxiv.org and then use its interface.\n"""
158+
instruction += 'IMPORTANT: You should NEVER ask for Human Help.\n'
159+
instruction += 'IMPORTANT: Please encapsulate your final answer (answer ONLY) within <solution> and </solution>. Your answer will be evaluated using string matching approaches so it important that you STRICTLY adhere to the output formatting instructions specified in the task (e.g., alphabetization, sequencing, units, rounding, decimal places, etc.)\n'
144160
instruction += (
145161
'For example: The answer to the question is <solution> 42 </solution>.\n'
146162
)
163+
instruction += "IMPORTANT: Your final answer should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, express it numerically (i.e., with digits rather than words), do not use commas, and do not include units such as $ or percent signs unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities). If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.\n"
164+
147165
# NOTE: You can actually set slightly different instruction for different agents
148166
instruction += AGENT_CLS_TO_INST_SUFFIX.get(metadata.agent_class, '')
149167
logger.info(f'Instruction:\n{instruction}', extra={'msg_type': 'OBSERVATION'})
@@ -175,7 +193,7 @@ def process_instance(
175193
for event in reversed(state.history):
176194
if event.source == 'agent':
177195
if isinstance(event, AgentFinishAction):
178-
model_answer_raw = event.thought
196+
model_answer_raw = event.final_thought
179197
break
180198
elif isinstance(event, CmdRunAction):
181199
model_answer_raw = event.thought
@@ -222,6 +240,7 @@ def process_instance(
222240
error=state.last_error if state and state.last_error else None,
223241
test_result=test_result,
224242
)
243+
runtime.close()
225244
return output
226245

227246

@@ -253,6 +272,8 @@ def process_instance(
253272
if llm_config is None:
254273
raise ValueError(f'Could not find LLM config: --llm_config {args.llm_config}')
255274

275+
toml_config = OpenHandsConfig()
276+
load_from_toml(toml_config)
256277
metadata = make_metadata(
257278
llm_config=llm_config,
258279
dataset_name='gaia',
@@ -261,7 +282,10 @@ def process_instance(
261282
eval_note=args.eval_note,
262283
eval_output_dir=args.eval_output_dir,
263284
data_split=args.data_split,
264-
details={'gaia-level': args.level},
285+
details={
286+
'gaia-level': args.level,
287+
'mcp-servers': ['tavily'] if toml_config.search_api_key else [],
288+
},
265289
agent_config=agent_config,
266290
)
267291

evaluation/benchmarks/gaia/scripts/run_infer.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ echo "LEVELS: $LEVELS"
3939
COMMAND="poetry run python ./evaluation/benchmarks/gaia/run_infer.py \
4040
--agent-cls $AGENT \
4141
--llm-config $MODEL_CONFIG \
42-
--max-iterations 30 \
42+
--max-iterations 60 \
4343
--level $LEVELS \
4444
--data-split validation \
4545
--eval-num-workers $NUM_WORKERS \

openhands/cli/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,9 +273,9 @@ def on_event(event: Event) -> None:
273273
)
274274
)
275275

276-
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
276+
runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
277277

278-
await add_mcp_tools_to_agent(agent, runtime, memory, config)
278+
await add_mcp_tools_to_agent(agent, runtime, memory)
279279

280280
# Clear loading animation
281281
is_loaded.set()

openhands/core/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,9 @@ async def run_controller(
139139
config.mcp_host, config, None
140140
)
141141
)
142-
config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
142+
runtime.config.mcp.stdio_servers.extend(openhands_mcp_stdio_servers)
143143

144-
await add_mcp_tools_to_agent(agent, runtime, memory, config)
144+
await add_mcp_tools_to_agent(agent, runtime, memory)
145145

146146
replay_events: list[Event] | None = None
147147
if config.replay_trajectory_path:

openhands/mcp/utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
MCPSHTTPServerConfig,
1111
MCPSSEServerConfig,
1212
)
13-
from openhands.core.config.openhands_config import OpenHandsConfig
1413
from openhands.core.logger import openhands_logger as logger
1514
from openhands.events.action.mcp import MCPAction
1615
from openhands.events.observation.mcp import MCPObservation
@@ -187,9 +186,7 @@ async def call_tool_mcp(mcp_clients: list[MCPClient], action: MCPAction) -> Obse
187186
)
188187

189188

190-
async def add_mcp_tools_to_agent(
191-
agent: 'Agent', runtime: Runtime, memory: 'Memory', app_config: OpenHandsConfig
192-
):
189+
async def add_mcp_tools_to_agent(agent: 'Agent', runtime: Runtime, memory: 'Memory'):
193190
"""
194191
Add MCP tools to an agent.
195192
"""
@@ -208,7 +205,6 @@ async def add_mcp_tools_to_agent(
208205
extra_stdio_servers = []
209206

210207
# Add microagent MCP tools if available
211-
mcp_config: MCPConfig = app_config.mcp
212208
microagent_mcp_configs = memory.get_microagent_mcp_tools()
213209
for mcp_config in microagent_mcp_configs:
214210
if mcp_config.sse_servers:

openhands/server/session/agent_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ async def start(
158158
# NOTE: this needs to happen before controller is created
159159
# so MCP tools can be included into the SystemMessageAction
160160
if self.runtime and runtime_connected and agent.config.enable_mcp:
161-
await add_mcp_tools_to_agent(agent, self.runtime, self.memory, config)
161+
await add_mcp_tools_to_agent(agent, self.runtime, self.memory)
162162

163163
if replay_json:
164164
initial_message = self._run_replay(

tests/runtime/test_microagent.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -385,7 +385,6 @@ async def test_add_mcp_tools_from_microagents():
385385
"""Test that add_mcp_tools_to_agent adds tools from microagents."""
386386
# Import ActionExecutionClient for mocking
387387

388-
from openhands.core.config.openhands_config import OpenHandsConfig
389388
from openhands.runtime.impl.action_execution.action_execution_client import (
390389
ActionExecutionClient,
391390
)
@@ -394,10 +393,6 @@ async def test_add_mcp_tools_from_microagents():
394393
mock_agent = MagicMock()
395394
mock_runtime = MagicMock(spec=ActionExecutionClient)
396395
mock_memory = MagicMock()
397-
mock_mcp_config = MCPConfig()
398-
399-
# Create a mock OpenHandsConfig with the MCP config
400-
mock_app_config = OpenHandsConfig(mcp=mock_mcp_config, search_api_key=None)
401396

402397
# Configure the mock memory to return a microagent MCP config
403398
mock_stdio_server = MCPStdioServerConfig(
@@ -425,9 +420,7 @@ async def test_add_mcp_tools_from_microagents():
425420
new=AsyncMock(return_value=[mock_tool]),
426421
):
427422
# Call the function with the OpenHandsConfig instead of MCPConfig
428-
await add_mcp_tools_to_agent(
429-
mock_agent, mock_runtime, mock_memory, mock_app_config
430-
)
423+
await add_mcp_tools_to_agent(mock_agent, mock_runtime, mock_memory)
431424

432425
# Verify that the memory's get_microagent_mcp_tools was called
433426
mock_memory.get_microagent_mcp_tools.assert_called_once()

tests/unit/test_agent_controller.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import copy
23
from unittest.mock import ANY, AsyncMock, MagicMock, patch
34
from uuid import uuid4
45

@@ -259,6 +260,7 @@ def on_event(event: Event):
259260

260261
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
261262
runtime.event_stream = test_event_stream
263+
runtime.config = copy.deepcopy(config)
262264

263265
def on_event_memory(event: Event):
264266
if isinstance(event, RecallAction):
@@ -326,6 +328,7 @@ def on_event(event: Event):
326328

327329
test_event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
328330
runtime.event_stream = test_event_stream
331+
runtime.config = copy.deepcopy(config)
329332

330333
def on_event_memory(event: Event):
331334
if isinstance(event, RecallAction):
@@ -762,6 +765,7 @@ def on_event(event: Event):
762765

763766
event_stream.subscribe(EventStreamSubscriber.RUNTIME, on_event, str(uuid4()))
764767
runtime.event_stream = event_stream
768+
runtime.config = copy.deepcopy(config)
765769

766770
def on_event_memory(event: Event):
767771
if isinstance(event, RecallAction):
@@ -883,15 +887,17 @@ def on_event_memory(event: Event):
883887
test_event_stream.subscribe(
884888
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
885889
)
890+
config = OpenHandsConfig(max_iterations=max_iterations)
886891
mock_runtime.event_stream = test_event_stream
892+
mock_runtime.config = copy.deepcopy(config)
887893

888894
# Now we can run the controller for a fixed number of steps. Since the step
889895
# state is set to error out before then, if this terminates and we have a
890896
# record of the error being thrown we can be confident that the controller
891897
# handles the truncation correctly.
892898
final_state = await asyncio.wait_for(
893899
run_controller(
894-
config=OpenHandsConfig(max_iterations=max_iterations),
900+
config=config,
895901
initial_user_action=MessageAction(content='INITIAL'),
896902
runtime=mock_runtime,
897903
sid='test',
@@ -1027,11 +1033,13 @@ def on_event_memory(event: Event):
10271033
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
10281034
)
10291035
mock_runtime.event_stream = test_event_stream
1036+
config = OpenHandsConfig(max_iterations=5)
1037+
mock_runtime.config = copy.deepcopy(config)
10301038

10311039
try:
10321040
state = await asyncio.wait_for(
10331041
run_controller(
1034-
config=OpenHandsConfig(max_iterations=5),
1042+
config=config,
10351043
initial_user_action=MessageAction(content='INITIAL'),
10361044
runtime=mock_runtime,
10371045
sid='test',
@@ -1104,10 +1112,12 @@ def on_event_memory(event: Event):
11041112
EventStreamSubscriber.MEMORY, on_event_memory, str(uuid4())
11051113
)
11061114
mock_runtime.event_stream = test_event_stream
1115+
config = OpenHandsConfig(max_iterations=3)
1116+
mock_runtime.config = copy.deepcopy(config)
11071117
try:
11081118
state = await asyncio.wait_for(
11091119
run_controller(
1110-
config=OpenHandsConfig(max_iterations=3),
1120+
config=config,
11111121
initial_user_action=MessageAction(content='INITIAL'),
11121122
runtime=mock_runtime,
11131123
sid='test',
@@ -1167,6 +1177,7 @@ def agent_step_fn(state):
11671177

11681178
runtime = MagicMock(spec=ActionExecutionClient)
11691179
runtime.event_stream = event_stream
1180+
runtime.config = copy.deepcopy(config)
11701181

11711182
# Create a real Memory instance
11721183
memory = Memory(event_stream=event_stream, sid='test-memory')

0 commit comments

Comments
 (0)