1- import asyncio
2- import datetime
3- from contextlib import AsyncExitStack
41from typing import Optional
52
6- from mcp import ClientSession
7- from mcp .client .sse import sse_client
8- from mcp .client .streamable_http import streamablehttp_client
3+ from fastmcp import Client
4+ from fastmcp .client .transports import SSETransport , StreamableHttpTransport
5+ from mcp import McpError
6+ from mcp .types import CallToolResult
97from pydantic import BaseModel , Field
108
9+ from openhands .core .config .mcp_config import MCPSHTTPServerConfig , MCPSSEServerConfig
1110from openhands .core .logger import openhands_logger as logger
1211from openhands .mcp .tool import MCPClientTool
1312
@@ -17,198 +16,95 @@ class MCPClient(BaseModel):
1716 A collection of tools that connects to an MCP server and manages available tools through the Model Context Protocol.
1817 """
1918
20- session : Optional [ClientSession ] = None
21- exit_stack : AsyncExitStack = AsyncExitStack ()
19+ client : Optional [Client ] = None
2220 description : str = 'MCP client tools for server interaction'
2321 tools : list [MCPClientTool ] = Field (default_factory = list )
2422 tool_map : dict [str , MCPClientTool ] = Field (default_factory = dict )
2523
2624 class Config :
2725 arbitrary_types_allowed = True
2826
29- async def connect_sse (
30- self ,
31- server_url : str ,
32- api_key : str | None = None ,
33- conversation_id : str | None = None ,
34- timeout : float = 30.0 ,
35- ) -> None :
36- """Connect to an MCP server using SSE transport.
37-
38- Args:
39- server_url: The URL of the SSE server to connect to.
40- timeout: Connection timeout in seconds. Default is 30 seconds.
41- """
42- if not server_url :
43- raise ValueError ('Server URL is required.' )
44- if self .session :
45- await self .disconnect ()
46-
47- try :
48- # Use asyncio.wait_for to enforce the timeout
49- async def connect_with_timeout ():
50- headers = (
51- {
52- 'Authorization' : f'Bearer { api_key } ' ,
53- 's' : api_key , # We need this for action execution server's MCP Router
54- 'X-Session-API-Key' : api_key , # We need this for Remote Runtime
55- }
56- if api_key
57- else {}
58- )
59-
60- if conversation_id :
61- headers ['X-OpenHands-Conversation-ID' ] = conversation_id
62-
63- # Convert float timeout to datetime.timedelta for consistency
64- timeout_delta = datetime .timedelta (seconds = timeout )
65-
66- streams_context = sse_client (
67- url = server_url ,
68- headers = headers if headers else None ,
69- timeout = timeout ,
70- )
71- streams = await self .exit_stack .enter_async_context (streams_context )
72- # For SSE client, we only get read_stream and write_stream (2 values)
73- read_stream , write_stream = streams
74- self .session = await self .exit_stack .enter_async_context (
75- ClientSession (
76- read_stream , write_stream , read_timeout_seconds = timeout_delta
77- )
78- )
79- await self ._initialize_and_list_tools ()
80-
81- # Apply timeout to the entire connection process
82- await asyncio .wait_for (connect_with_timeout (), timeout = timeout )
83- except asyncio .TimeoutError :
84- logger .error (
85- f'Connection to { server_url } timed out after { timeout } seconds'
86- )
87- await self .disconnect () # Clean up resources
88- raise # Re-raise the TimeoutError
89- except Exception as e :
90- logger .error (f'Error connecting to { server_url } : { str (e )} ' )
91- await self .disconnect () # Clean up resources
92- raise
93-
9427 async def _initialize_and_list_tools (self ) -> None :
9528 """Initialize session and populate tool map."""
96- if not self .session :
29+ if not self .client :
9730 raise RuntimeError ('Session not initialized.' )
9831
99- await self .session . initialize ()
100- response = await self .session .list_tools ()
32+ async with self .client :
33+ tools = await self .client .list_tools ()
10134
10235 # Clear existing tools
10336 self .tools = []
10437
10538 # Create proper tool objects for each server tool
106- for tool in response . tools :
39+ for tool in tools :
10740 server_tool = MCPClientTool (
10841 name = tool .name ,
10942 description = tool .description ,
11043 inputSchema = tool .inputSchema ,
111- session = self .session ,
44+ session = self .client ,
11245 )
11346 self .tool_map [tool .name ] = server_tool
11447 self .tools .append (server_tool )
11548
116- logger .info (
117- f'Connected to server with tools: { [tool .name for tool in response .tools ]} '
118- )
49+ logger .info (f'Connected to server with tools: { [tool .name for tool in tools ]} ' )
11950
120- async def call_tool (self , tool_name : str , args : dict ):
121- """Call a tool on the MCP server."""
122- if tool_name not in self .tool_map :
123- raise ValueError (f'Tool { tool_name } not found.' )
124- # The MCPClientTool is primarily for metadata; use the session to call the actual tool.
125- if not self .session :
126- raise RuntimeError ('Client session is not available.' )
127- return await self .session .call_tool (name = tool_name , arguments = args )
128-
129- async def connect_shttp (
51+ async def connect_http (
13052 self ,
131- server_url : str ,
132- api_key : str | None = None ,
53+ server : MCPSSEServerConfig | MCPSHTTPServerConfig ,
13354 conversation_id : str | None = None ,
13455 timeout : float = 30.0 ,
135- ) -> None :
136- """Connect to an MCP server using StreamableHTTP transport.
137-
138- Args:
139- server_url: The URL of the StreamableHTTP server to connect to.
140- api_key: Optional API key for authentication.
141- conversation_id: Optional conversation ID for session tracking.
142- timeout: Connection timeout in seconds. Default is 30 seconds.
143- """
56+ ):
57+ """Connect to MCP server using SHTTP or SSE transport"""
58+ server_url = server .url
59+ api_key = server .api_key
60+
14461 if not server_url :
14562 raise ValueError ('Server URL is required.' )
146- if self .session :
147- await self .disconnect ()
14863
14964 try :
150- # Use asyncio.wait_for to enforce the timeout
151- async def connect_with_timeout ():
152- headers = (
153- {
154- 'Authorization' : f'Bearer { api_key } ' ,
155- 's' : api_key , # We need this for action execution server's MCP Router
156- 'X-Session-API-Key' : api_key , # We need this for Remote Runtime
157- }
158- if api_key
159- else {}
160- )
161-
162- if conversation_id :
163- headers ['X-OpenHands-Conversation-ID' ] = conversation_id
65+ headers = (
66+ {
67+ 'Authorization' : f'Bearer { api_key } ' ,
68+ 's' : api_key , # We need this for action execution server's MCP Router
69+ 'X-Session-API-Key' : api_key , # We need this for Remote Runtime
70+ }
71+ if api_key
72+ else {}
73+ )
16474
165- # Convert float timeout to datetime.timedelta
166- timeout_delta = datetime .timedelta (seconds = timeout )
167- sse_read_timeout_delta = datetime .timedelta (
168- seconds = timeout * 10
169- ) # 10x longer for read timeout
75+ if conversation_id :
76+ headers ['X-OpenHands-Conversation-ID' ] = conversation_id
17077
171- streams_context = streamablehttp_client (
78+ # Instantiate custom transports due to custom headers
79+ if isinstance (server , MCPSHTTPServerConfig ):
80+ transport = StreamableHttpTransport (
17281 url = server_url ,
17382 headers = headers if headers else None ,
174- timeout = timeout_delta ,
175- sse_read_timeout = sse_read_timeout_delta ,
17683 )
177- streams = await self .exit_stack .enter_async_context (streams_context )
178- # For StreamableHTTP client, we get read_stream, write_stream, and get_session_id (3 values)
179- read_stream , write_stream , _ = streams
180- self .session = await self .exit_stack .enter_async_context (
181- ClientSession (
182- read_stream , write_stream , read_timeout_seconds = timeout_delta
183- )
84+ else :
85+ transport = SSETransport (
86+ url = server_url ,
87+ headers = headers if headers else None ,
18488 )
185- await self ._initialize_and_list_tools ()
18689
187- # Apply timeout to the entire connection process
188- await asyncio .wait_for (connect_with_timeout (), timeout = timeout )
189- except asyncio .TimeoutError :
190- logger .error (
191- f'Connection to { server_url } timed out after { timeout } seconds'
192- )
193- await self .disconnect () # Clean up resources
194- raise # Re-raise the TimeoutError
90+ self .client = Client (transport , timeout = timeout )
91+
92+ await self ._initialize_and_list_tools ()
93+ except McpError as e :
94+ logger .error (f'McpError connecting to { server_url } : { e } ' )
95+ raise # Re-raise the error
96+
19597 except Exception as e :
196- logger .error (f'Error connecting to { server_url } : { str (e )} ' )
197- await self .disconnect () # Clean up resources
98+ logger .error (f'Error connecting to { server_url } : { e } ' )
19899 raise
199100
200- async def disconnect (self ) -> None :
201- """Disconnect from the MCP server and clean up resources."""
202- if self .session :
203- try :
204- # Close the session first
205- if hasattr (self .session , 'close' ):
206- await self .session .close ()
207- # Then close the exit stack
208- await self .exit_stack .aclose ()
209- except Exception as e :
210- logger .error (f'Error during disconnect: { str (e )} ' )
211- finally :
212- self .session = None
213- self .tools = []
214- logger .info ('Disconnected from MCP server' )
101+ async def call_tool (self , tool_name : str , args : dict ) -> CallToolResult :
102+ """Call a tool on the MCP server."""
103+ if tool_name not in self .tool_map :
104+ raise ValueError (f'Tool { tool_name } not found.' )
105+ # The MCPClientTool is primarily for metadata; use the session to call the actual tool.
106+ if not self .client :
107+ raise RuntimeError ('Client session is not available.' )
108+
109+ async with self .client :
110+ return await self .client .call_tool_mcp (name = tool_name , arguments = args )
0 commit comments