Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions codex-rs/cli/src/mcp_cmd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ fn run_add(config_overrides: &CliConfigOverrides, add_args: AddArgs) -> Result<(
command: command_bin,
args: command_args,
env: env_map,
startup_timeout_ms: None,
startup_timeout_sec: None,
tool_timeout_sec: None,
};

servers.insert(name.clone(), new_entry);
Expand Down Expand Up @@ -210,7 +211,12 @@ fn run_list(config_overrides: &CliConfigOverrides, list_args: ListArgs) -> Resul
"command": cfg.command,
"args": cfg.args,
"env": env,
"startup_timeout_ms": cfg.startup_timeout_ms,
"startup_timeout_sec": cfg
.startup_timeout_sec
.map(|timeout| timeout.as_secs_f64()),
"tool_timeout_sec": cfg
.tool_timeout_sec
.map(|timeout| timeout.as_secs_f64()),
})
})
.collect();
Expand Down Expand Up @@ -305,7 +311,12 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<(
"command": server.command,
"args": server.args,
"env": env,
"startup_timeout_ms": server.startup_timeout_ms,
"startup_timeout_sec": server
.startup_timeout_sec
.map(|timeout| timeout.as_secs_f64()),
"tool_timeout_sec": server
.tool_timeout_sec
.map(|timeout| timeout.as_secs_f64()),
}))?;
println!("{output}");
return Ok(());
Expand Down Expand Up @@ -333,8 +344,11 @@ fn run_get(config_overrides: &CliConfigOverrides, get_args: GetArgs) -> Result<(
}
};
println!(" env: {env_display}");
if let Some(timeout) = server.startup_timeout_ms {
println!(" startup_timeout_ms: {timeout}");
if let Some(timeout) = server.startup_timeout_sec {
println!(" startup_timeout_sec: {}", timeout.as_secs_f64());
}
if let Some(timeout) = server.tool_timeout_sec {
println!(" tool_timeout_sec: {}", timeout.as_secs_f64());
}
println!(" remove: codex mcp remove {}", get_args.name);

Expand Down
10 changes: 2 additions & 8 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,9 @@ impl Session {
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> anyhow::Result<CallToolResult> {
self.mcp_connection_manager
.call_tool(server, tool, arguments, timeout)
.call_tool(server, tool, arguments)
.await
}

Expand Down Expand Up @@ -2574,12 +2573,7 @@ async fn handle_function_call(
_ => {
match sess.mcp_connection_manager.parse_tool_name(&name) {
Some((server, tool_name)) => {
// TODO(mbolin): Determine appropriate timeout for tool call.
let timeout = None;
handle_mcp_tool_call(
sess, &sub_id, call_id, server, tool_name, arguments, timeout,
)
.await
handle_mcp_tool_call(sess, &sub_id, call_id, server, tool_name, arguments).await
}
None => {
// Unknown function: reply with structured failure so the model can adapt.
Expand Down
42 changes: 33 additions & 9 deletions codex-rs/core/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,14 +333,12 @@ pub fn write_global_mcp_servers(
entry["env"] = TomlItem::Table(env_table);
}

if let Some(timeout) = config.startup_timeout_ms {
let timeout = i64::try_from(timeout).map_err(|_| {
std::io::Error::new(
std::io::ErrorKind::InvalidData,
"startup_timeout_ms exceeds supported range",
)
})?;
entry["startup_timeout_ms"] = toml_edit::value(timeout);
if let Some(timeout) = config.startup_timeout_sec {
entry["startup_timeout_sec"] = toml_edit::value(timeout.as_secs_f64());
}

if let Some(timeout) = config.tool_timeout_sec {
entry["tool_timeout_sec"] = toml_edit::value(timeout.as_secs_f64());
}

doc["mcp_servers"][name.as_str()] = TomlItem::Table(entry);
Expand Down Expand Up @@ -1168,6 +1166,7 @@ mod tests {
use super::*;
use pretty_assertions::assert_eq;

use std::time::Duration;
use tempfile::TempDir;

#[test]
Expand Down Expand Up @@ -1292,7 +1291,8 @@ exclude_slash_tmp = true
command: "echo".to_string(),
args: vec!["hello".to_string()],
env: None,
startup_timeout_ms: None,
startup_timeout_sec: Some(Duration::from_secs(3)),
tool_timeout_sec: Some(Duration::from_secs(5)),
},
);

Expand All @@ -1303,6 +1303,8 @@ exclude_slash_tmp = true
let docs = loaded.get("docs").expect("docs entry");
assert_eq!(docs.command, "echo");
assert_eq!(docs.args, vec!["hello".to_string()]);
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_secs(3)));
assert_eq!(docs.tool_timeout_sec, Some(Duration::from_secs(5)));

let empty = BTreeMap::new();
write_global_mcp_servers(codex_home.path(), &empty)?;
Expand All @@ -1312,6 +1314,28 @@ exclude_slash_tmp = true
Ok(())
}

#[test]
fn load_global_mcp_servers_accepts_legacy_ms_field() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
let config_path = codex_home.path().join(CONFIG_TOML_FILE);

std::fs::write(
&config_path,
r#"
[mcp_servers]
[mcp_servers.docs]
command = "echo"
startup_timeout_ms = 2500
"#,
)?;

let servers = load_global_mcp_servers(codex_home.path())?;
let docs = servers.get("docs").expect("docs entry");
assert_eq!(docs.startup_timeout_sec, Some(Duration::from_millis(2500)));

Ok(())
}

#[tokio::test]
async fn persist_model_selection_updates_defaults() -> anyhow::Result<()> {
let codex_home = TempDir::new()?;
Expand Down
87 changes: 83 additions & 4 deletions codex-rs/core/src/config_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@

use std::collections::HashMap;
use std::path::PathBuf;
use std::time::Duration;
use wildmatch::WildMatchPattern;

use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde::de::Error as SerdeError;

#[derive(Deserialize, Debug, Clone, PartialEq)]
#[derive(Serialize, Debug, Clone, PartialEq)]
pub struct McpServerConfig {
pub command: String,

Expand All @@ -19,9 +23,84 @@ pub struct McpServerConfig {
#[serde(default)]
pub env: Option<HashMap<String, String>>,

/// Startup timeout in milliseconds for initializing MCP server & initially listing tools.
#[serde(default)]
pub startup_timeout_ms: Option<u64>,
/// Startup timeout in seconds for initializing MCP server & initially listing tools.
#[serde(
default,
with = "option_duration_secs",
skip_serializing_if = "Option::is_none"
)]
pub startup_timeout_sec: Option<Duration>,

/// Default timeout for MCP tool calls initiated via this server.
#[serde(default, with = "option_duration_secs")]
pub tool_timeout_sec: Option<Duration>,
}

impl<'de> Deserialize<'de> for McpServerConfig {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
struct RawMcpServerConfig {
command: String,
#[serde(default)]
args: Vec<String>,
#[serde(default)]
env: Option<HashMap<String, String>>,
#[serde(default)]
startup_timeout_sec: Option<f64>,
#[serde(default)]
startup_timeout_ms: Option<u64>,
#[serde(default, with = "option_duration_secs")]
tool_timeout_sec: Option<Duration>,
}

let raw = RawMcpServerConfig::deserialize(deserializer)?;

let startup_timeout_sec = match (raw.startup_timeout_sec, raw.startup_timeout_ms) {
(Some(sec), _) => {
let duration = Duration::try_from_secs_f64(sec).map_err(SerdeError::custom)?;
Some(duration)
}
(None, Some(ms)) => Some(Duration::from_millis(ms)),
(None, None) => None,
};

Ok(Self {
command: raw.command,
args: raw.args,
env: raw.env,
startup_timeout_sec,
tool_timeout_sec: raw.tool_timeout_sec,
})
}
}

mod option_duration_secs {
use serde::Deserialize;
use serde::Deserializer;
use serde::Serializer;
use std::time::Duration;

pub fn serialize<S>(value: &Option<Duration>, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match value {
Some(duration) => serializer.serialize_some(&duration.as_secs_f64()),
None => serializer.serialize_none(),
}
}

pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let secs = Option::<f64>::deserialize(deserializer)?;
secs.map(|secs| Duration::try_from_secs_f64(secs).map_err(serde::de::Error::custom))
.transpose()
}
}

#[derive(Deserialize, Debug, Copy, Clone, PartialEq)]
Expand Down
39 changes: 21 additions & 18 deletions codex-rs/core/src/mcp_connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ const MAX_TOOL_NAME_LENGTH: usize = 64;
/// Default timeout for initializing MCP server & initially listing tools.
const DEFAULT_STARTUP_TIMEOUT: Duration = Duration::from_secs(10);

/// Default timeout for individual tool calls.
const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_secs(60);

/// Map that holds a startup error for every MCP server that could **not** be
/// spawned successfully.
pub type ClientStartErrors = HashMap<String, anyhow::Error>;
Expand Down Expand Up @@ -85,6 +88,7 @@ struct ToolInfo {
struct ManagedClient {
client: Arc<McpClient>,
startup_timeout: Duration,
tool_timeout: Option<Duration>,
}

/// A thin wrapper around a set of running [`McpClient`] instances.
Expand Down Expand Up @@ -132,10 +136,9 @@ impl McpConnectionManager {
continue;
}

let startup_timeout = cfg
.startup_timeout_ms
.map(Duration::from_millis)
.unwrap_or(DEFAULT_STARTUP_TIMEOUT);
let startup_timeout = cfg.startup_timeout_sec.unwrap_or(DEFAULT_STARTUP_TIMEOUT);

let tool_timeout = cfg.tool_timeout_sec.unwrap_or(DEFAULT_TOOL_TIMEOUT);

join_set.spawn(async move {
let McpServerConfig {
Expand Down Expand Up @@ -171,28 +174,28 @@ impl McpConnectionManager {
protocol_version: mcp_types::MCP_SCHEMA_VERSION.to_owned(),
};
let initialize_notification_params = None;
match client
let init_result = client
.initialize(
params,
initialize_notification_params,
Some(startup_timeout),
)
.await
{
Ok(_response) => (server_name, Ok((client, startup_timeout))),
Err(e) => (server_name, Err(e)),
}
.await;
(
(server_name, tool_timeout),
init_result.map(|_| (client, startup_timeout)),
)
}
Err(e) => (server_name, Err(e.into())),
Err(e) => ((server_name, tool_timeout), Err(e.into())),
}
});
}

let mut clients: HashMap<String, ManagedClient> = HashMap::with_capacity(join_set.len());

while let Some(res) = join_set.join_next().await {
let (server_name, client_res) = match res {
Ok((server_name, client_res)) => (server_name, client_res),
let ((server_name, tool_timeout), client_res) = match res {
Ok(result) => result,
Err(e) => {
warn!("Task panic when starting MCP server: {e:#}");
continue;
Expand All @@ -206,6 +209,7 @@ impl McpConnectionManager {
ManagedClient {
client: Arc::new(client),
startup_timeout,
tool_timeout: Some(tool_timeout),
},
);
}
Expand Down Expand Up @@ -243,14 +247,13 @@ impl McpConnectionManager {
server: &str,
tool: &str,
arguments: Option<serde_json::Value>,
timeout: Option<Duration>,
) -> Result<mcp_types::CallToolResult> {
let client = self
let managed = self
.clients
.get(server)
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?
.client
.clone();
.ok_or_else(|| anyhow!("unknown MCP server '{server}'"))?;
let client = managed.client.clone();
let timeout = managed.tool_timeout;

client
.call_tool(tool.to_string(), arguments, timeout)
Expand Down
4 changes: 1 addition & 3 deletions codex-rs/core/src/mcp_tool_call.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::time::Duration;
use std::time::Instant;

use tracing::error;
Expand All @@ -21,7 +20,6 @@ pub(crate) async fn handle_mcp_tool_call(
server: String,
tool_name: String,
arguments: String,
timeout: Option<Duration>,
) -> ResponseInputItem {
// Parse the `arguments` as JSON. An empty string is OK, but invalid JSON
// is not.
Expand Down Expand Up @@ -58,7 +56,7 @@ pub(crate) async fn handle_mcp_tool_call(
let start = Instant::now();
// Perform the tool call.
let result = sess
.call_tool(&server, &tool_name, arguments_value.clone(), timeout)
.call_tool(&server, &tool_name, arguments_value.clone())
.await
.map_err(|e| format!("tool call error: {e}"));
let tool_call_end_event = EventMsg::McpToolCallEnd(McpToolCallEndEvent {
Expand Down
Loading
Loading