diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index a50a82b438..ac792eaf3b 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -688,9 +688,11 @@ dependencies = [ "codex-file-search", "codex-mcp-client", "codex-protocol", + "codex-rmcp-client", "core_test_support", "dirs", "env-flags", + "escargot", "eventsource-stream", "futures", "indexmap 2.10.0", @@ -949,6 +951,20 @@ dependencies = [ "zeroize", ] +[[package]] +name = "codex-rmcp-client" +version = "0.0.0" +dependencies = [ + "anyhow", + "mcp-types", + "pretty_assertions", + "rmcp", + "serde", + "serde_json", + "tokio", + "tracing", +] + [[package]] name = "codex-tui" version = "0.0.0" @@ -1266,8 +1282,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", ] [[package]] @@ -1284,13 +1310,38 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.104", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn 2.0.104", +] + +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", "quote", "syn 2.0.104", ] @@ -1667,6 +1718,17 @@ version = "3.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dea2df4cf52843e0452895c455a1a2cfbb842a1e7329671acf418fdc53ed4c59" +[[package]] +name = "escargot" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11c3aea32bc97b500c9ca6a72b768a26e558264303d101d3409cf6d57a9ed0cf" +dependencies = [ + "log", + "serde", + "serde_json", +] + [[package]] name = "event-listener" version = "5.4.0" @@ -2496,7 +2558,7 @@ version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "435d80800b936787d62688c927b6490e887c7ef5ff9ce922c6c6050fca75eb9a" dependencies = [ - "darling", + "darling 0.20.11", "indoc", "proc-macro2", "quote", @@ -2976,6 +3038,18 @@ dependencies = [ "libc", ] +[[package]] +name = "nix" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74523f3a35e05aba87a1d978330aef40f67b0304ac79c1c00b294c9830543db6" +dependencies = [ + "bitflags 2.9.1", + "cfg-if", + "cfg_aliases 0.2.1", + "libc", +] + [[package]] name = "nom" version = "7.1.3" @@ -3398,7 +3472,7 @@ dependencies = [ "lazy_static", "libc", "log", - "nix", + "nix 0.28.0", "serial2", "shared_library", "shell-words", @@ -3486,6 +3560,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "8.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3ef4f2f0422f23a82ec9f628ea2acd12871c81a9362b02c43c1aa86acfc3ba1" +dependencies = [ + "futures", + "indexmap 2.10.0", + "nix 0.30.1", + "tokio", + "tracing", + "windows", +] + [[package]] name = "pulldown-cmark" version = "0.10.3" @@ -3812,6 +3900,42 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af" +dependencies = [ + "base64", + "chrono", + "futures", + "paste", + "pin-project-lite", + "process-wrap", + "rmcp-macros", + "schemars 1.0.4", + "serde", + "serde_json", + "thiserror 2.0.16", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ba777eb0e5f53a757e36f0e287441da0ab766564ba7201600eeb92a4753022e" +dependencies = [ + "darling 0.21.3", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.104", +] + [[package]] name = "rustc-demangle" version = "0.1.25" @@ -3905,7 +4029,7 @@ dependencies = [ "libc", "log", "memchr", - "nix", + "nix 0.28.0", "radix_trie", "unicode-segmentation", "unicode-width 0.1.14", @@ -3986,7 +4110,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fbf2ae1b8bc8e02df939598064d22402220cd5bbcca1c76f7d6a310974d5615" dependencies = [ "dyn-clone", - "schemars_derive", + "schemars_derive 0.8.22", "serde", "serde_json", ] @@ -4009,8 +4133,10 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "82d20c4491bc164fa2f6c5d44565947a52ad80b9505d8e36f8d54c27c739fcd0" dependencies = [ + "chrono", "dyn-clone", "ref-cast", + "schemars_derive 1.0.4", "serde", "serde_json", ] @@ -4027,6 +4153,18 @@ dependencies = [ "syn 2.0.104", ] +[[package]] +name = "schemars_derive" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33d020396d1d138dc19f1165df7545479dcd58d93810dc5d646a16e55abefa80" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.104", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -4178,7 +4316,7 @@ version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de90945e6565ce0d9a25098082ed4ee4002e047cb59892c318d66821e14bb30f" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn 2.0.104", @@ -5502,6 +5640,28 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-link 0.1.3", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" version = "0.61.2" @@ -5515,6 +5675,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core", + "windows-link 0.1.3", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -5549,6 +5720,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core", + "windows-link 0.1.3", +] + [[package]] name = "windows-registry" version = "0.5.3" @@ -5676,6 +5857,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.0", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link 0.1.3", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" diff --git a/codex-rs/Cargo.toml b/codex-rs/Cargo.toml index 33b1b303f0..af111c6f19 100644 --- a/codex-rs/Cargo.toml +++ b/codex-rs/Cargo.toml @@ -18,6 +18,7 @@ members = [ "ollama", "protocol", "protocol-ts", + "rmcp-client", "responses-api-proxy", "tui", "utils/readiness", @@ -49,6 +50,7 @@ codex-mcp-client = { path = "mcp-client" } codex-mcp-server = { path = "mcp-server" } codex-ollama = { path = "ollama" } codex-protocol = { path = "protocol" } +codex-rmcp-client = { path = "rmcp-client" } codex-protocol-ts = { path = "protocol-ts" } codex-responses-api-proxy = { path = "responses-api-proxy" } codex-tui = { path = "tui" } @@ -82,6 +84,7 @@ dotenvy = "0.15.7" env-flags = "0.1.1" env_logger = "0.11.5" eventsource-stream = "0.2.3" +escargot = "0.5" futures = "0.3" icu_decimal = "2.0.0" icu_locale_core = "2.0.0" diff --git a/codex-rs/core/Cargo.toml b/codex-rs/core/Cargo.toml index a1e7876a26..9b1c4888a4 100644 --- a/codex-rs/core/Cargo.toml +++ b/codex-rs/core/Cargo.toml @@ -22,6 +22,7 @@ chrono = { workspace = true, features = ["serde"] } codex-apply-patch = { workspace = true } codex-file-search = { workspace = true } codex-mcp-client = { workspace = true } +codex-rmcp-client = { workspace = true } codex-protocol = { workspace = true } dirs = { workspace = true } env-flags = { workspace = true } @@ -82,6 +83,7 @@ openssl-sys = { workspace = true, features = ["vendored"] } [dev-dependencies] assert_cmd = { workspace = true } core_test_support = { workspace = true } +escargot = { workspace = true } maplit = { workspace = true } predicates = { workspace = true } pretty_assertions = { workspace = true } diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index 93850a6583..8b792887c0 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -377,7 +377,10 @@ impl Session { // - load history metadata let rollout_fut = RolloutRecorder::new(&config, rollout_params); - let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone()); + let mcp_fut = McpConnectionManager::new( + config.mcp_servers.clone(), + config.use_experimental_use_rmcp_client, + ); let default_shell_fut = shell::default_user_shell(); let history_meta_fut = crate::message_history::history_metadata(&config); diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs index 508a3dc36f..5b5b60f8df 100644 --- a/codex-rs/core/src/config.rs +++ b/codex-rs/core/src/config.rs @@ -184,6 +184,10 @@ pub struct Config { /// If set to `true`, used only the experimental unified exec tool. pub use_experimental_unified_exec_tool: bool, + /// If set to `true`, use the experimental official Rust MCP client. + /// https://github.com/modelcontextprotocol/rust-sdk + pub use_experimental_use_rmcp_client: bool, + /// Include the `view_image` tool that lets the agent attach a local image path to context. pub include_view_image_tool: bool, @@ -693,6 +697,7 @@ pub struct ConfigToml { pub experimental_use_exec_command_tool: Option, pub experimental_use_unified_exec_tool: Option, + pub experimental_use_rmcp_client: Option, pub projects: Option>, @@ -1043,6 +1048,7 @@ impl Config { use_experimental_unified_exec_tool: cfg .experimental_use_unified_exec_tool .unwrap_or(false), + use_experimental_use_rmcp_client: cfg.experimental_use_rmcp_client.unwrap_or(false), include_view_image_tool, active_profile: active_profile_name, disable_paste_burst: cfg.disable_paste_burst.unwrap_or(false), @@ -1651,6 +1657,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("o3".to_string()), disable_paste_burst: false, @@ -1709,6 +1716,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt3".to_string()), disable_paste_burst: false, @@ -1782,6 +1790,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("zdr".to_string()), disable_paste_burst: false, @@ -1841,6 +1850,7 @@ model_verbosity = "high" tools_web_search_request: false, use_experimental_streamable_shell_tool: false, use_experimental_unified_exec_tool: false, + use_experimental_use_rmcp_client: false, include_view_image_tool: true, active_profile: Some("gpt5".to_string()), disable_paste_burst: false, diff --git a/codex-rs/core/src/mcp_connection_manager.rs b/codex-rs/core/src/mcp_connection_manager.rs index e9c95fc80b..5648e20b3b 100644 --- a/codex-rs/core/src/mcp_connection_manager.rs +++ b/codex-rs/core/src/mcp_connection_manager.rs @@ -16,6 +16,7 @@ use anyhow::Context; use anyhow::Result; use anyhow::anyhow; use codex_mcp_client::McpClient; +use codex_rmcp_client::RmcpClient; use mcp_types::ClientCapabilities; use mcp_types::Implementation; use mcp_types::Tool; @@ -86,11 +87,64 @@ struct ToolInfo { } struct ManagedClient { - client: Arc, + client: McpClientAdapter, startup_timeout: Duration, tool_timeout: Option, } +#[derive(Clone)] +enum McpClientAdapter { + Legacy(Arc), + Rmcp(Arc), +} + +impl McpClientAdapter { + async fn new_stdio_client( + use_rmcp_client: bool, + program: OsString, + args: Vec, + env: Option>, + params: mcp_types::InitializeRequestParams, + startup_timeout: Duration, + ) -> Result { + tracing::error!( + "new_stdio_client use_rmcp_client: {use_rmcp_client} program: {program:?} args: {args:?} env: {env:?} params: {params:?} startup_timeout: {startup_timeout:?}" + ); + if use_rmcp_client { + let client = Arc::new(RmcpClient::new_stdio_client(program, args, env).await?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Rmcp(client)) + } else { + let client = Arc::new(McpClient::new_stdio_client(program, args, env).await?); + client.initialize(params, Some(startup_timeout)).await?; + Ok(McpClientAdapter::Legacy(client)) + } + } + + async fn list_tools( + &self, + params: Option, + timeout: Option, + ) -> Result { + match self { + McpClientAdapter::Legacy(client) => client.list_tools(params, timeout).await, + McpClientAdapter::Rmcp(client) => client.list_tools(params, timeout).await, + } + } + + async fn call_tool( + &self, + name: String, + arguments: Option, + timeout: Option, + ) -> Result { + match self { + McpClientAdapter::Legacy(client) => client.call_tool(name, arguments, timeout).await, + McpClientAdapter::Rmcp(client) => client.call_tool(name, arguments, timeout).await, + } + } +} + /// A thin wrapper around a set of running [`McpClient`] instances. #[derive(Default)] pub(crate) struct McpConnectionManager { @@ -115,12 +169,15 @@ impl McpConnectionManager { /// user should be informed about these errors. pub async fn new( mcp_servers: HashMap, + use_rmcp_client: bool, ) -> Result<(Self, ClientStartErrors)> { // Early exit if no servers are configured. if mcp_servers.is_empty() { return Ok((Self::default(), ClientStartErrors::default())); } + tracing::error!("new mcp_servers: {mcp_servers:?} use_rmcp_client: {use_rmcp_client}"); + // Launch all configured servers concurrently. let mut join_set = JoinSet::new(); let mut errors = ClientStartErrors::new(); @@ -137,57 +194,48 @@ impl McpConnectionManager { } let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT); - let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT); + let use_rmcp_client_flag = use_rmcp_client; join_set.spawn(async move { let McpServerConfig { command, args, env, .. } = cfg; - let client_res = McpClient::new_stdio_client( - command.into(), - args.into_iter().map(OsString::from).collect(), + let command_os: OsString = command.into(); + let args_os: Vec = args.into_iter().map(Into::into).collect(); + let params = mcp_types::InitializeRequestParams { + capabilities: ClientCapabilities { + experimental: None, + roots: None, + sampling: None, + // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities + // indicates this should be an empty object. + elicitation: Some(json!({})), + }, + client_info: Implementation { + name: "codex-mcp-client".to_owned(), + version: env!("CARGO_PKG_VERSION").to_owned(), + title: Some("Codex".into()), + // This field is used by Codex when it is an MCP + // server: it should not be used when Codex is + // an MCP client. + user_agent: None, + }, + protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), + }; + + let client = McpClientAdapter::new_stdio_client( + use_rmcp_client_flag, + command_os, + args_os, env, + params, + startup_timeout, ) - .await; - match client_res { - Ok(client) => { - // Initialize the client. - let params = mcp_types::InitializeRequestParams { - capabilities: ClientCapabilities { - experimental: None, - roots: None, - sampling: None, - // https://modelcontextprotocol.io/specification/2025-06-18/client/elicitation#capabilities - // indicates this should be an empty object. - elicitation: Some(json!({})), - }, - client_info: Implementation { - name: "codex-mcp-client".to_owned(), - version: env!("CARGO_PKG_VERSION").to_owned(), - title: Some("Codex".into()), - // This field is used by Codex when it is an MCP - // server: it should not be used when Codex is - // an MCP client. - user_agent: None, - }, - protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(), - }; - let initialize_notification_params = None; - let init_result = client - .initialize( - params, - initialize_notification_params, - Some(startup_timeout), - ) - .await; - ( - (server_name, tool_timeout), - init_result.map(|_| (client, startup_timeout)), - ) - } - Err(e) => ((server_name, tool_timeout), Err(e.into())), - } + .await + .map(|c| (c, startup_timeout)); + + ((server_name, tool_timeout), client) }); } @@ -207,7 +255,7 @@ impl McpConnectionManager { clients.insert( server_name, ManagedClient { - client: Arc::new(client), + client, startup_timeout, tool_timeout: Some(tool_timeout), }, diff --git a/codex-rs/core/tests/suite/mod.rs b/codex-rs/core/tests/suite/mod.rs index 0e4e725c36..5dd2cb673e 100644 --- a/codex-rs/core/tests/suite/mod.rs +++ b/codex-rs/core/tests/suite/mod.rs @@ -14,6 +14,7 @@ mod live_cli; mod model_overrides; mod prompt_caching; mod review; +mod rmcp_client; mod rollout_list_find; mod seatbelt; mod stream_error_allows_next_turn; diff --git a/codex-rs/core/tests/suite/rmcp_client.rs b/codex-rs/core/tests/suite/rmcp_client.rs new file mode 100644 index 0000000000..2ebe9f011c --- /dev/null +++ b/codex-rs/core/tests/suite/rmcp_client.rs @@ -0,0 +1,162 @@ +use std::collections::HashMap; +use std::time::Duration; + +use codex_core::config_types::McpServerConfig; +use codex_core::protocol::AskForApproval; +use codex_core::protocol::EventMsg; +use codex_core::protocol::InputItem; +use codex_core::protocol::Op; +use codex_core::protocol::SandboxPolicy; +use codex_protocol::config_types::ReasoningSummary; +use core_test_support::responses; +use core_test_support::responses::mount_sse_once; +use core_test_support::skip_if_no_network; +use core_test_support::test_codex::test_codex; +use core_test_support::wait_for_event; +use core_test_support::wait_for_event_with_timeout; +use escargot::CargoBuild; +use serde_json::Value; +use wiremock::matchers::any; + +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn rmcp_tool_call_round_trip() -> anyhow::Result<()> { + skip_if_no_network!(Ok(())); + + let server = responses::start_mock_server().await; + + let call_id = "call-123"; + let server_name = "rmcp"; + let tool_name = format!("{server_name}__echo"); + + mount_sse_once( + &server, + any(), + responses::sse(vec![ + serde_json::json!({ + "type": "response.created", + "response": {"id": "resp-1"} + }), + responses::ev_function_call(call_id, &tool_name, "{\"message\":\"ping\"}"), + responses::ev_completed("resp-1"), + ]), + ) + .await; + mount_sse_once( + &server, + any(), + responses::sse(vec![ + responses::ev_assistant_message("msg-1", "rmcp echo tool completed successfully."), + responses::ev_completed("resp-2"), + ]), + ) + .await; + + let expected_env_value = "propagated-env"; + let rmcp_test_server_bin = CargoBuild::new() + .package("codex-rmcp-client") + .bin("rmcp_test_server") + .run()? + .path() + .to_string_lossy() + .into_owned(); + + let fixture = test_codex() + .with_config(move |config| { + config.use_experimental_use_rmcp_client = true; + config.mcp_servers.insert( + server_name.to_string(), + McpServerConfig { + command: rmcp_test_server_bin.clone(), + args: Vec::new(), + env: Some(HashMap::from([( + "MCP_TEST_VALUE".to_string(), + expected_env_value.to_string(), + )])), + startup_timeout_sec: Some(Duration::from_secs(10)), + tool_timeout_sec: None, + }, + ); + }) + .build(&server) + .await?; + let session_model = fixture.session_configured.model.clone(); + + fixture + .codex + .submit(Op::UserTurn { + items: vec![InputItem::Text { + text: "call the rmcp echo tool".into(), + }], + final_output_json_schema: None, + cwd: fixture.cwd.path().to_path_buf(), + approval_policy: AskForApproval::Never, + sandbox_policy: SandboxPolicy::DangerFullAccess, + model: session_model, + effort: None, + summary: ReasoningSummary::Auto, + }) + .await?; + + eprintln!("waiting for mcp tool call begin event"); + let begin_event = wait_for_event_with_timeout( + &fixture.codex, + |ev| { + eprintln!("ev: {ev:?}"); + matches!(ev, EventMsg::McpToolCallBegin(_)) + }, + Duration::from_secs(10), + ) + .await; + + eprintln!("mcp tool call begin event: {begin_event:?}"); + let EventMsg::McpToolCallBegin(begin) = begin_event else { + unreachable!("event guard guarantees McpToolCallBegin"); + }; + assert_eq!(begin.invocation.server, server_name); + assert_eq!(begin.invocation.tool, "echo"); + + let end_event = wait_for_event(&fixture.codex, |ev| { + matches!(ev, EventMsg::McpToolCallEnd(_)) + }) + .await; + eprintln!("end_event: {end_event:?}"); + let EventMsg::McpToolCallEnd(end) = end_event else { + unreachable!("event guard guarantees McpToolCallEnd"); + }; + + let result = end + .result + .as_ref() + .expect("rmcp echo tool should return success"); + assert_eq!(result.is_error, Some(false)); + assert!( + result.content.is_empty(), + "content should default to an empty array" + ); + + let structured = result + .structured_content + .as_ref() + .expect("structured content"); + let Value::Object(map) = structured else { + panic!("structured content should be an object: {structured:?}"); + }; + let echo_value = map + .get("echo") + .and_then(Value::as_str) + .expect("echo payload present"); + assert_eq!(echo_value, "ping"); + let env_value = map + .get("env") + .and_then(Value::as_str) + .expect("env snapshot inserted"); + assert_eq!(env_value, expected_env_value); + + let task_complete_event = + wait_for_event(&fixture.codex, |ev| matches!(ev, EventMsg::TaskComplete(_))).await; + eprintln!("task_complete_event: {task_complete_event:?}"); + + server.verify().await; + + Ok(()) +} diff --git a/codex-rs/mcp-client/src/main.rs b/codex-rs/mcp-client/src/main.rs index d25bca4ba3..f46058b99e 100644 --- a/codex-rs/mcp-client/src/main.rs +++ b/codex-rs/mcp-client/src/main.rs @@ -70,11 +70,8 @@ async fn main() -> Result<()> { }, protocol_version: MCP_SCHEMA_VERSION.to_owned(), }; - let initialize_notification_params = None; let timeout = Some(Duration::from_secs(10)); - let response = client - .initialize(params, initialize_notification_params, timeout) - .await?; + let response = client.initialize(params, timeout).await?; eprintln!("initialize response: {response:?}"); // Issue `tools/list` request (no params). diff --git a/codex-rs/mcp-client/src/mcp_client.rs b/codex-rs/mcp-client/src/mcp_client.rs index 505df6bd4e..087335e66b 100644 --- a/codex-rs/mcp-client/src/mcp_client.rs +++ b/codex-rs/mcp-client/src/mcp_client.rs @@ -315,13 +315,12 @@ impl McpClient { pub async fn initialize( &self, initialize_params: InitializeRequestParams, - initialize_notification_params: Option, timeout: Option, ) -> Result { let response = self .send_request::(initialize_params, timeout) .await?; - self.send_notification::(initialize_notification_params) + self.send_notification::(None) .await?; Ok(response) } diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml new file mode 100644 index 0000000000..da9989e531 --- /dev/null +++ b/codex-rs/rmcp-client/Cargo.toml @@ -0,0 +1,34 @@ +[package] +edition = "2024" +name = "codex-rmcp-client" +version = { workspace = true } + +[lints] +workspace = true + +[dependencies] +anyhow = "1" +mcp-types = { path = "../mcp-types" } +rmcp = { version = "0.7.0", default-features = false, features = [ + "base64", + "client", + "macros", + "schemars", + "server", + "transport-child-process", +] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tokio = { version = "1", features = [ + "io-util", + "macros", + "process", + "rt-multi-thread", + "sync", + "io-std", + "time", +] } +tracing = { version = "0.1.41", features = ["log"] } + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs new file mode 100644 index 0000000000..23b2f93b38 --- /dev/null +++ b/codex-rs/rmcp-client/src/bin/rmcp_test_server.rs @@ -0,0 +1,142 @@ +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +use rmcp::ErrorData as McpError; +use rmcp::ServiceExt; +use rmcp::handler::server::ServerHandler; +use rmcp::model::CallToolRequestParam; +use rmcp::model::CallToolResult; +use rmcp::model::JsonObject; +use rmcp::model::ListToolsResult; +use rmcp::model::PaginatedRequestParam; +use rmcp::model::ServerCapabilities; +use rmcp::model::ServerInfo; +use rmcp::model::Tool; +use serde::Deserialize; +use serde_json::json; +use tokio::task; + +#[derive(Clone)] +struct TestToolServer { + tools: Arc>, +} +pub fn stdio() -> (tokio::io::Stdin, tokio::io::Stdout) { + (tokio::io::stdin(), tokio::io::stdout()) +} +impl TestToolServer { + fn new() -> Self { + let tools = vec![Self::echo_tool()]; + Self { + tools: Arc::new(tools), + } + } + + fn echo_tool() -> Tool { + #[expect(clippy::expect_used)] + let schema: JsonObject = serde_json::from_value(json!({ + "type": "object", + "properties": { + "message": { "type": "string" }, + "env_var": { "type": "string" } + }, + "required": ["message"], + "additionalProperties": false + })) + .expect("echo tool schema should deserialize"); + + Tool::new( + Cow::Borrowed("echo"), + Cow::Borrowed("Echo back the provided message and include environment data."), + Arc::new(schema), + ) + } +} + +#[derive(Deserialize)] +struct EchoArgs { + message: String, + #[allow(dead_code)] + env_var: Option, +} + +impl ServerHandler for TestToolServer { + fn get_info(&self) -> ServerInfo { + ServerInfo { + capabilities: ServerCapabilities::builder() + .enable_tools() + .enable_tool_list_changed() + .build(), + ..ServerInfo::default() + } + } + + fn list_tools( + &self, + _request: Option, + _context: rmcp::service::RequestContext, + ) -> impl std::future::Future> + Send + '_ { + let tools = self.tools.clone(); + async move { + Ok(ListToolsResult { + tools: (*tools).clone(), + next_cursor: None, + }) + } + } + + async fn call_tool( + &self, + request: CallToolRequestParam, + _context: rmcp::service::RequestContext, + ) -> Result { + match request.name.as_ref() { + "echo" => { + let args: EchoArgs = match request.arguments { + Some(arguments) => serde_json::from_value(serde_json::Value::Object( + arguments.into_iter().collect(), + )) + .map_err(|err| McpError::invalid_params(err.to_string(), None))?, + None => { + return Err(McpError::invalid_params( + "missing arguments for echo tool", + None, + )); + } + }; + + let env_snapshot: HashMap = std::env::vars().collect(); + let structured_content = json!({ + "echo": args.message, + "env": env_snapshot.get("MCP_TEST_VALUE"), + }); + + Ok(CallToolResult { + content: Vec::new(), + structured_content: Some(structured_content), + is_error: Some(false), + meta: None, + }) + } + other => Err(McpError::invalid_params( + format!("unknown tool: {other}"), + None, + )), + } + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + eprintln!("starting rmcp test server"); + // Run the server with STDIO transport. If the client disconnects we simply + // bubble up the error so the process exits. + let service = TestToolServer::new(); + let running = service.serve(stdio()).await?; + + // Wait for the client to finish interacting with the server. + running.waiting().await?; + // Drain background tasks to ensure clean shutdown. + task::yield_now().await; + Ok(()) +} diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs new file mode 100644 index 0000000000..ef5088406c --- /dev/null +++ b/codex-rs/rmcp-client/src/lib.rs @@ -0,0 +1,5 @@ +mod logging_client_handler; +mod rmcp_client; +mod utils; + +pub use rmcp_client::RmcpClient; diff --git a/codex-rs/rmcp-client/src/logging_client_handler.rs b/codex-rs/rmcp-client/src/logging_client_handler.rs new file mode 100644 index 0000000000..85d237b0e9 --- /dev/null +++ b/codex-rs/rmcp-client/src/logging_client_handler.rs @@ -0,0 +1,134 @@ +use rmcp::ClientHandler; +use rmcp::RoleClient; +use rmcp::model::CancelledNotificationParam; +use rmcp::model::ClientInfo; +use rmcp::model::CreateElicitationRequestParam; +use rmcp::model::CreateElicitationResult; +use rmcp::model::ElicitationAction; +use rmcp::model::LoggingLevel; +use rmcp::model::LoggingMessageNotificationParam; +use rmcp::model::ProgressNotificationParam; +use rmcp::model::ResourceUpdatedNotificationParam; +use rmcp::service::NotificationContext; +use rmcp::service::RequestContext; +use tracing::debug; +use tracing::error; +use tracing::info; +use tracing::warn; + +#[derive(Debug, Clone)] +pub(crate) struct LoggingClientHandler { + client_info: ClientInfo, +} + +impl LoggingClientHandler { + pub(crate) fn new(client_info: ClientInfo) -> Self { + Self { client_info } + } +} + +impl ClientHandler for LoggingClientHandler { + // TODO (CODEX-3571): support elicitations. + async fn create_elicitation( + &self, + request: CreateElicitationRequestParam, + _context: RequestContext, + ) -> Result { + info!( + "MCP server requested elicitation ({}). Elicitations are not supported yet. Declining.", + request.message + ); + Ok(CreateElicitationResult { + action: ElicitationAction::Decline, + content: None, + }) + } + + async fn on_cancelled( + &self, + params: CancelledNotificationParam, + _context: NotificationContext, + ) { + info!( + "MCP server cancelled request (request_id: {}, reason: {:?})", + params.request_id, params.reason + ); + } + + async fn on_progress( + &self, + params: ProgressNotificationParam, + _context: NotificationContext, + ) { + info!( + "MCP server progress notification (token: {:?}, progress: {}, total: {:?}, message: {:?})", + params.progress_token, params.progress, params.total, params.message + ); + } + + async fn on_resource_updated( + &self, + params: ResourceUpdatedNotificationParam, + _context: NotificationContext, + ) { + info!("MCP server resource updated (uri: {})", params.uri); + } + + async fn on_resource_list_changed(&self, _context: NotificationContext) { + info!("MCP server resource list changed"); + } + + async fn on_tool_list_changed(&self, _context: NotificationContext) { + info!("MCP server tool list changed"); + } + + async fn on_prompt_list_changed(&self, _context: NotificationContext) { + info!("MCP server prompt list changed"); + } + + fn get_info(&self) -> ClientInfo { + self.client_info.clone() + } + + async fn on_logging_message( + &self, + params: LoggingMessageNotificationParam, + _context: NotificationContext, + ) { + let LoggingMessageNotificationParam { + level, + logger, + data, + } = params; + let logger = logger.as_deref(); + match level { + LoggingLevel::Emergency + | LoggingLevel::Alert + | LoggingLevel::Critical + | LoggingLevel::Error => { + error!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Warning => { + warn!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Notice | LoggingLevel::Info => { + info!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + LoggingLevel::Debug => { + debug!( + "MCP server log message (level: {:?}, logger: {:?}, data: {})", + level, logger, data + ); + } + } + } +} diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs new file mode 100644 index 0000000000..c7ac1ecc9a --- /dev/null +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -0,0 +1,183 @@ +use std::collections::HashMap; +use std::ffi::OsString; +use std::io; +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::Result; +use anyhow::anyhow; +use mcp_types::CallToolRequestParams; +use mcp_types::CallToolResult; +use mcp_types::InitializeRequestParams; +use mcp_types::InitializeResult; +use mcp_types::ListToolsRequestParams; +use mcp_types::ListToolsResult; +use rmcp::model::CallToolRequestParam; +use rmcp::model::InitializeRequestParam; +use rmcp::model::PaginatedRequestParam; +use rmcp::service::RoleClient; +use rmcp::service::RunningService; +use rmcp::service::{self}; +use rmcp::transport::child_process::TokioChildProcess; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::process::Command; +use tokio::sync::Mutex; +use tokio::time; +use tracing::info; +use tracing::warn; + +use crate::logging_client_handler::LoggingClientHandler; +use crate::utils::convert_call_tool_result; +use crate::utils::convert_to_mcp; +use crate::utils::convert_to_rmcp; +use crate::utils::create_env_for_mcp_server; +use crate::utils::run_with_timeout; + +enum ClientState { + Connecting { + transport: Option, + }, + Ready { + service: Arc>, + }, +} + +/// MCP client implemented on top of the official `rmcp` SDK. +/// https://github.com/modelcontextprotocol/rust-sdk +pub struct RmcpClient { + state: Mutex, +} + +impl RmcpClient { + pub async fn new_stdio_client( + program: OsString, + args: Vec, + env: Option>, + ) -> io::Result { + let program_name = program.to_string_lossy().into_owned(); + let mut command = Command::new(&program); + command + .kill_on_drop(true) + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .env_clear() + .envs(create_env_for_mcp_server(env)) + .args(&args); + + let (transport, stderr) = TokioChildProcess::builder(command) + .stderr(Stdio::piped()) + .spawn()?; + + if let Some(stderr) = stderr { + tokio::spawn(async move { + let mut reader = BufReader::new(stderr).lines(); + loop { + match reader.next_line().await { + Ok(Some(line)) => { + info!("MCP server stderr ({program_name}): {line}"); + } + Ok(None) => break, + Err(error) => { + warn!("Failed to read MCP server stderr ({program_name}): {error}"); + break; + } + } + } + }); + } + + Ok(Self { + state: Mutex::new(ClientState::Connecting { + transport: Some(transport), + }), + }) + } + + /// Perform the initialization handshake with the MCP server. + /// https://modelcontextprotocol.io/specification/2025-06-18/basic/lifecycle#initialization + pub async fn initialize( + &self, + params: InitializeRequestParams, + timeout: Option, + ) -> Result { + let transport = { + let mut guard = self.state.lock().await; + match &mut *guard { + ClientState::Connecting { transport } => transport + .take() + .ok_or_else(|| anyhow!("client already initializing"))?, + ClientState::Ready { .. } => { + return Err(anyhow!("client already initialized")); + } + } + }; + + let client_info = convert_to_rmcp::<_, InitializeRequestParam>(params.clone())?; + let client_handler = LoggingClientHandler::new(client_info); + let service_future = service::serve_client(client_handler, transport); + + let service = match timeout { + Some(duration) => time::timeout(duration, service_future) + .await + .map_err(|_| anyhow!("timed out handshaking with MCP server after {duration:?}"))? + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + None => service_future + .await + .map_err(|err| anyhow!("handshaking with MCP server failed: {err}"))?, + }; + + let initialize_result_rmcp = service + .peer() + .peer_info() + .ok_or_else(|| anyhow!("handshake succeeded but server info was missing"))?; + let initialize_result = convert_to_mcp(initialize_result_rmcp)?; + + { + let mut guard = self.state.lock().await; + *guard = ClientState::Ready { + service: Arc::new(service), + }; + } + + Ok(initialize_result) + } + + pub async fn list_tools( + &self, + params: Option, + timeout: Option, + ) -> Result { + let service = self.service().await?; + let rmcp_params = params + .map(convert_to_rmcp::<_, PaginatedRequestParam>) + .transpose()?; + + let fut = service.list_tools(rmcp_params); + let result = run_with_timeout(fut, timeout, "tools/list").await?; + convert_to_mcp(result) + } + + pub async fn call_tool( + &self, + name: String, + arguments: Option, + timeout: Option, + ) -> Result { + let service = self.service().await?; + let params = CallToolRequestParams { arguments, name }; + let rmcp_params: CallToolRequestParam = convert_to_rmcp(params)?; + let fut = service.call_tool(rmcp_params); + let rmcp_result = run_with_timeout(fut, timeout, "tools/call").await?; + convert_call_tool_result(rmcp_result) + } + + async fn service(&self) -> Result>> { + let guard = self.state.lock().await; + match &*guard { + ClientState::Ready { service } => Ok(Arc::clone(service)), + ClientState::Connecting { .. } => Err(anyhow!("MCP client not initialized")), + } + } +} diff --git a/codex-rs/rmcp-client/src/utils.rs b/codex-rs/rmcp-client/src/utils.rs new file mode 100644 index 0000000000..6b7bd89424 --- /dev/null +++ b/codex-rs/rmcp-client/src/utils.rs @@ -0,0 +1,160 @@ +use std::collections::HashMap; +use std::env; +use std::time::Duration; + +use anyhow::Context; +use anyhow::Result; +use anyhow::anyhow; +use mcp_types::CallToolResult; +use rmcp::model::CallToolResult as RmcpCallToolResult; +use rmcp::service::ServiceError; +use serde_json::Value; +use tokio::time; + +pub(crate) async fn run_with_timeout( + fut: F, + timeout: Option, + label: &str, +) -> Result +where + F: std::future::Future>, +{ + if let Some(duration) = timeout { + let result = time::timeout(duration, fut) + .await + .with_context(|| anyhow!("timed out awaiting {label} after {duration:?}"))?; + result.map_err(|err| anyhow!("{label} failed: {err}")) + } else { + fut.await.map_err(|err| anyhow!("{label} failed: {err}")) + } +} + +pub(crate) fn convert_call_tool_result(result: RmcpCallToolResult) -> Result { + let mut value = serde_json::to_value(result)?; + if let Some(obj) = value.as_object_mut() + && (obj.get("content").is_none() + || obj.get("content").is_some_and(serde_json::Value::is_null)) + { + obj.insert("content".to_string(), Value::Array(Vec::new())); + } + serde_json::from_value(value).context("failed to convert call tool result") +} + +/// Convert from mcp-types to Rust SDK types. +/// +/// The Rust SDK types are the same as our mcp-types crate because they are both +/// derived from the same MCP specification. +/// As a result, it should be safe to convert directly from one to the other. +pub(crate) fn convert_to_rmcp(value: T) -> Result +where + T: serde::Serialize, + U: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(value)?; + serde_json::from_value(json).map_err(|err| anyhow!(err)) +} + +/// Convert from Rust SDK types to mcp-types. +/// +/// The Rust SDK types are the same as our mcp-types crate because they are both +/// derived from the same MCP specification. +/// As a result, it should be safe to convert directly from one to the other. +pub(crate) fn convert_to_mcp(value: T) -> Result +where + T: serde::Serialize, + U: serde::de::DeserializeOwned, +{ + let json = serde_json::to_value(value)?; + serde_json::from_value(json).map_err(|err| anyhow!(err)) +} + +pub(crate) fn create_env_for_mcp_server( + extra_env: Option>, +) -> HashMap { + DEFAULT_ENV_VARS + .iter() + .filter_map(|var| env::var(var).ok().map(|value| (var.to_string(), value))) + .chain(extra_env.unwrap_or_default()) + .collect() +} + +#[cfg(unix)] +pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ + "HOME", + "LOGNAME", + "PATH", + "SHELL", + "USER", + "__CF_USER_TEXT_ENCODING", + "LANG", + "LC_ALL", + "TERM", + "TMPDIR", + "TZ", +]; + +#[cfg(windows)] +pub(crate) const DEFAULT_ENV_VARS: &[&str] = &[ + "PATH", + "PATHEXT", + "USERNAME", + "USERDOMAIN", + "USERPROFILE", + "TEMP", + "TMP", +]; + +#[cfg(test)] +mod tests { + use super::*; + use mcp_types::ContentBlock; + use pretty_assertions::assert_eq; + use rmcp::model::CallToolResult as RmcpCallToolResult; + use serde_json::json; + + #[tokio::test] + async fn create_env_honors_overrides() { + let value = "custom".to_string(); + let env = create_env_for_mcp_server(Some(HashMap::from([("TZ".into(), value.clone())]))); + assert_eq!(env.get("TZ"), Some(&value)); + } + + #[test] + fn convert_call_tool_result_defaults_missing_content() -> Result<()> { + let structured_content = json!({ "key": "value" }); + let rmcp_result = RmcpCallToolResult { + content: vec![], + structured_content: Some(structured_content.clone()), + is_error: Some(true), + meta: None, + }; + + let result = convert_call_tool_result(rmcp_result)?; + + assert!(result.content.is_empty()); + assert_eq!(result.structured_content, Some(structured_content)); + assert_eq!(result.is_error, Some(true)); + + Ok(()) + } + + #[test] + fn convert_call_tool_result_preserves_existing_content() -> Result<()> { + let rmcp_result = RmcpCallToolResult::success(vec![rmcp::model::Content::text("hello")]); + + let result = convert_call_tool_result(rmcp_result)?; + + assert_eq!(result.content.len(), 1); + match &result.content[0] { + ContentBlock::TextContent(text_content) => { + assert_eq!(text_content.text, "hello"); + assert_eq!(text_content.r#type, "text"); + } + other => panic!("expected text content got {other:?}"), + } + assert_eq!(result.structured_content, None); + assert_eq!(result.is_error, Some(false)); + + Ok(()) + } +} diff --git a/codex-rs/tui/src/lib.rs b/codex-rs/tui/src/lib.rs index 998ecb7434..25eec9d86e 100644 --- a/codex-rs/tui/src/lib.rs +++ b/codex-rs/tui/src/lib.rs @@ -223,8 +223,9 @@ pub async fn run_main( // use RUST_LOG env var, default to info for codex crates. let env_filter = || { - EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("codex_core=info,codex_tui=info")) + EnvFilter::try_from_default_env().unwrap_or_else(|_| { + EnvFilter::new("codex_core=info,codex_tui=info,codex_rmcp_client=info") + }) }; // Build layered subscriber: