Skip to content

Commit 4df3ee9

Browse files
authored
(refactor): Update MCP Client to use FastMCP (OpenHands#8931)
1 parent aa54a25 commit 4df3ee9

File tree

7 files changed

+103
-256
lines changed

7 files changed

+103
-256
lines changed

openhands/mcp/client.py

Lines changed: 55 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1-
import asyncio
2-
import datetime
3-
from contextlib import AsyncExitStack
41
from typing import Optional
52

6-
from mcp import ClientSession
7-
from mcp.client.sse import sse_client
8-
from mcp.client.streamable_http import streamablehttp_client
3+
from fastmcp import Client
4+
from fastmcp.client.transports import SSETransport, StreamableHttpTransport
5+
from mcp import McpError
6+
from mcp.types import CallToolResult
97
from pydantic import BaseModel, Field
108

9+
from openhands.core.config.mcp_config import MCPSHTTPServerConfig, MCPSSEServerConfig
1110
from openhands.core.logger import openhands_logger as logger
1211
from openhands.mcp.tool import MCPClientTool
1312

@@ -17,198 +16,95 @@ class MCPClient(BaseModel):
1716
A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
1817
"""
1918

20-
session: Optional[ClientSession] = None
21-
exit_stack: AsyncExitStack = AsyncExitStack()
19+
client: Optional[Client] = None
2220
description: str = 'MCP client tools for server interaction'
2321
tools: list[MCPClientTool] = Field(default_factory=list)
2422
tool_map: dict[str, MCPClientTool] = Field(default_factory=dict)
2523

2624
class Config:
2725
arbitrary_types_allowed = True
2826

29-
async def connect_sse(
30-
self,
31-
server_url: str,
32-
api_key: str | None = None,
33-
conversation_id: str | None = None,
34-
timeout: float = 30.0,
35-
) -> None:
36-
"""Connect to an MCP server using SSE transport.
37-
38-
Args:
39-
server_url: The URL of the SSE server to connect to.
40-
timeout: Connection timeout in seconds. Default is 30 seconds.
41-
"""
42-
if not server_url:
43-
raise ValueError('Server URL is required.')
44-
if self.session:
45-
await self.disconnect()
46-
47-
try:
48-
# Use asyncio.wait_for to enforce the timeout
49-
async def connect_with_timeout():
50-
headers = (
51-
{
52-
'Authorization': f'Bearer {api_key}',
53-
's': api_key, # We need this for action execution server's MCP Router
54-
'X-Session-API-Key': api_key, # We need this for Remote Runtime
55-
}
56-
if api_key
57-
else {}
58-
)
59-
60-
if conversation_id:
61-
headers['X-OpenHands-Conversation-ID'] = conversation_id
62-
63-
# Convert float timeout to datetime.timedelta for consistency
64-
timeout_delta = datetime.timedelta(seconds=timeout)
65-
66-
streams_context = sse_client(
67-
url=server_url,
68-
headers=headers if headers else None,
69-
timeout=timeout,
70-
)
71-
streams = await self.exit_stack.enter_async_context(streams_context)
72-
# For SSE client, we only get read_stream and write_stream (2 values)
73-
read_stream, write_stream = streams
74-
self.session = await self.exit_stack.enter_async_context(
75-
ClientSession(
76-
read_stream, write_stream, read_timeout_seconds=timeout_delta
77-
)
78-
)
79-
await self._initialize_and_list_tools()
80-
81-
# Apply timeout to the entire connection process
82-
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
83-
except asyncio.TimeoutError:
84-
logger.error(
85-
f'Connection to {server_url} timed out after {timeout} seconds'
86-
)
87-
await self.disconnect() # Clean up resources
88-
raise # Re-raise the TimeoutError
89-
except Exception as e:
90-
logger.error(f'Error connecting to {server_url}: {str(e)}')
91-
await self.disconnect() # Clean up resources
92-
raise
93-
9427
async def _initialize_and_list_tools(self) -> None:
9528
"""Initialize session and populate tool map."""
96-
if not self.session:
29+
if not self.client:
9730
raise RuntimeError('Session not initialized.')
9831

99-
await self.session.initialize()
100-
response = await self.session.list_tools()
32+
async with self.client:
33+
tools = await self.client.list_tools()
10134

10235
# Clear existing tools
10336
self.tools = []
10437

10538
# Create proper tool objects for each server tool
106-
for tool in response.tools:
39+
for tool in tools:
10740
server_tool = MCPClientTool(
10841
name=tool.name,
10942
description=tool.description,
11043
inputSchema=tool.inputSchema,
111-
session=self.session,
44+
session=self.client,
11245
)
11346
self.tool_map[tool.name] = server_tool
11447
self.tools.append(server_tool)
11548

116-
logger.info(
117-
f'Connected to server with tools: {[tool.name for tool in response.tools]}'
118-
)
49+
logger.info(f'Connected to server with tools: {[tool.name for tool in tools]}')
11950

120-
async def call_tool(self, tool_name: str, args: dict):
121-
"""Call a tool on the MCP server."""
122-
if tool_name not in self.tool_map:
123-
raise ValueError(f'Tool {tool_name} not found.')
124-
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
125-
if not self.session:
126-
raise RuntimeError('Client session is not available.')
127-
return await self.session.call_tool(name=tool_name, arguments=args)
128-
129-
async def connect_shttp(
51+
async def connect_http(
13052
self,
131-
server_url: str,
132-
api_key: str | None = None,
53+
server: MCPSSEServerConfig | MCPSHTTPServerConfig,
13354
conversation_id: str | None = None,
13455
timeout: float = 30.0,
135-
) -> None:
136-
"""Connect to an MCP server using StreamableHTTP transport.
137-
138-
Args:
139-
server_url: The URL of the StreamableHTTP server to connect to.
140-
api_key: Optional API key for authentication.
141-
conversation_id: Optional conversation ID for session tracking.
142-
timeout: Connection timeout in seconds. Default is 30 seconds.
143-
"""
56+
):
57+
"""Connect to MCP server using SHTTP or SSE transport"""
58+
server_url = server.url
59+
api_key = server.api_key
60+
14461
if not server_url:
14562
raise ValueError('Server URL is required.')
146-
if self.session:
147-
await self.disconnect()
14863

14964
try:
150-
# Use asyncio.wait_for to enforce the timeout
151-
async def connect_with_timeout():
152-
headers = (
153-
{
154-
'Authorization': f'Bearer {api_key}',
155-
's': api_key, # We need this for action execution server's MCP Router
156-
'X-Session-API-Key': api_key, # We need this for Remote Runtime
157-
}
158-
if api_key
159-
else {}
160-
)
161-
162-
if conversation_id:
163-
headers['X-OpenHands-Conversation-ID'] = conversation_id
65+
headers = (
66+
{
67+
'Authorization': f'Bearer {api_key}',
68+
's': api_key, # We need this for action execution server's MCP Router
69+
'X-Session-API-Key': api_key, # We need this for Remote Runtime
70+
}
71+
if api_key
72+
else {}
73+
)
16474

165-
# Convert float timeout to datetime.timedelta
166-
timeout_delta = datetime.timedelta(seconds=timeout)
167-
sse_read_timeout_delta = datetime.timedelta(
168-
seconds=timeout * 10
169-
) # 10x longer for read timeout
75+
if conversation_id:
76+
headers['X-OpenHands-Conversation-ID'] = conversation_id
17077

171-
streams_context = streamablehttp_client(
78+
# Instantiate custom transports due to custom headers
79+
if isinstance(server, MCPSHTTPServerConfig):
80+
transport = StreamableHttpTransport(
17281
url=server_url,
17382
headers=headers if headers else None,
174-
timeout=timeout_delta,
175-
sse_read_timeout=sse_read_timeout_delta,
17683
)
177-
streams = await self.exit_stack.enter_async_context(streams_context)
178-
# For StreamableHTTP client, we get read_stream, write_stream, and get_session_id (3 values)
179-
read_stream, write_stream, _ = streams
180-
self.session = await self.exit_stack.enter_async_context(
181-
ClientSession(
182-
read_stream, write_stream, read_timeout_seconds=timeout_delta
183-
)
84+
else:
85+
transport = SSETransport(
86+
url=server_url,
87+
headers=headers if headers else None,
18488
)
185-
await self._initialize_and_list_tools()
18689

187-
# Apply timeout to the entire connection process
188-
await asyncio.wait_for(connect_with_timeout(), timeout=timeout)
189-
except asyncio.TimeoutError:
190-
logger.error(
191-
f'Connection to {server_url} timed out after {timeout} seconds'
192-
)
193-
await self.disconnect() # Clean up resources
194-
raise # Re-raise the TimeoutError
90+
self.client = Client(transport, timeout=timeout)
91+
92+
await self._initialize_and_list_tools()
93+
except McpError as e:
94+
logger.error(f'McpError connecting to {server_url}: {e}')
95+
raise # Re-raise the error
96+
19597
except Exception as e:
196-
logger.error(f'Error connecting to {server_url}: {str(e)}')
197-
await self.disconnect() # Clean up resources
98+
logger.error(f'Error connecting to {server_url}: {e}')
19899
raise
199100

200-
async def disconnect(self) -> None:
201-
"""Disconnect from the MCP server and clean up resources."""
202-
if self.session:
203-
try:
204-
# Close the session first
205-
if hasattr(self.session, 'close'):
206-
await self.session.close()
207-
# Then close the exit stack
208-
await self.exit_stack.aclose()
209-
except Exception as e:
210-
logger.error(f'Error during disconnect: {str(e)}')
211-
finally:
212-
self.session = None
213-
self.tools = []
214-
logger.info('Disconnected from MCP server')
101+
async def call_tool(self, tool_name: str, args: dict) -> CallToolResult:
102+
"""Call a tool on the MCP server."""
103+
if tool_name not in self.tool_map:
104+
raise ValueError(f'Tool {tool_name} not found.')
105+
# The MCPClientTool is primarily for metadata; use the session to call the actual tool.
106+
if not self.client:
107+
raise RuntimeError('Client session is not available.')
108+
109+
async with self.client:
110+
return await self.client.call_tool_mcp(name=tool_name, arguments=args)

openhands/mcp/utils.py

Lines changed: 4 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -72,38 +72,22 @@ async def create_mcp_clients(
7272
mcp_clients = []
7373

7474
for server in servers:
75-
is_sse = isinstance(server, MCPSSEServerConfig)
76-
connection_type = 'SSE' if is_sse else 'SHTTP'
75+
is_shttp = isinstance(server, MCPSHTTPServerConfig)
76+
connection_type = 'SHTTP' if is_shttp else 'SSE'
7777
logger.info(
7878
f'Initializing MCP agent for {server} with {connection_type} connection...'
7979
)
8080
client = MCPClient()
8181

8282
try:
83-
if is_sse:
84-
await client.connect_sse(
85-
server.url,
86-
api_key=server.api_key,
87-
conversation_id=conversation_id,
88-
)
89-
else:
90-
await client.connect_shttp(
91-
server.url,
92-
api_key=server.api_key,
93-
conversation_id=conversation_id,
94-
)
83+
await client.connect_http(server, conversation_id=conversation_id)
9584

9685
# Only add the client to the list after a successful connection
9786
mcp_clients.append(client)
9887

9988
except Exception as e:
10089
logger.error(f'Failed to connect to {server}: {str(e)}', exc_info=True)
101-
try:
102-
await client.disconnect()
103-
except Exception as disconnect_error:
104-
logger.error(
105-
f'Error during disconnect after failed connection: {str(disconnect_error)}'
106-
)
90+
10791
return mcp_clients
10892

10993

@@ -143,13 +127,6 @@ async def fetch_mcp_tools_from_config(
143127
# Convert tools to the format expected by the agent
144128
mcp_tools = convert_mcp_clients_to_tools(mcp_clients)
145129

146-
# Always disconnect clients to clean up resources
147-
for mcp_client in mcp_clients:
148-
try:
149-
await mcp_client.disconnect()
150-
except Exception as disconnect_error:
151-
logger.error(f'Error disconnecting MCP client: {str(disconnect_error)}')
152-
153130
except Exception as e:
154131
logger.error(f'Error fetching MCP tools: {str(e)}')
155132
return []

openhands/runtime/impl/action_execution/action_execution_client.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -471,11 +471,6 @@ async def call_tool_mcp(self, action: MCPAction) -> Observation:
471471
# Call the tool and return the result
472472
# No need for try/finally since disconnect() is now just resetting state
473473
result = await call_tool_mcp_handler(mcp_clients, action)
474-
475-
# Reset client state (no active connections to worry about)
476-
for client in mcp_clients:
477-
await client.disconnect()
478-
479474
return result
480475

481476
def close(self) -> None:

0 commit comments

Comments
 (0)