diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs index b3b75ec76a..2317ae3dc6 100644 --- a/codex-rs/core/src/codex.rs +++ b/codex-rs/core/src/codex.rs @@ -3,8 +3,6 @@ use std::collections::HashMap; use std::collections::HashSet; use std::path::PathBuf; use std::sync::Arc; -use std::sync::Mutex; -use std::sync::MutexGuard; use std::sync::atomic::AtomicU64; use std::time::Duration; @@ -31,6 +29,7 @@ use mcp_types::CallToolResult; use serde::Deserialize; use serde::Serialize; use serde_json; +use tokio::sync::Mutex; use tokio::sync::oneshot; use tokio::task::AbortHandle; use tracing::debug; @@ -135,21 +134,6 @@ mod compact; use self::compact::build_compacted_history; use self::compact::collect_user_messages; -// A convenience extension trait for acquiring mutex locks where poisoning is -// unrecoverable and should abort the program. This avoids scattered `.unwrap()` -// calls on `lock()` while still surfacing a clear panic message when a lock is -// poisoned. -trait MutexExt { - fn lock_unchecked(&self) -> MutexGuard<'_, T>; -} - -impl MutexExt for Mutex { - fn lock_unchecked(&self) -> MutexGuard<'_, T> { - #[expect(clippy::expect_used)] - self.lock().expect("poisoned lock") - } -} - /// The high-level interface to the Codex system. /// It operates as a queue pair where you send submissions and receive events. pub struct Codex { @@ -272,7 +256,6 @@ struct State { pending_input: Vec, history: ConversationHistory, token_info: Option, - next_internal_sub_id: u64, } /// Context for an initialized model agent @@ -298,6 +281,7 @@ pub(crate) struct Session { codex_linux_sandbox_exe: Option, user_shell: shell::Shell, show_raw_agent_reasoning: bool, + next_internal_sub_id: AtomicU64, } /// The context needed for a single turn of the conversation. @@ -500,6 +484,7 @@ impl Session { codex_linux_sandbox_exe: config.codex_linux_sandbox_exe.clone(), user_shell: default_shell, show_raw_agent_reasoning: config.show_raw_agent_reasoning, + next_internal_sub_id: AtomicU64::new(0), }); // Dispatch the SessionConfiguredEvent first and then report any errors. @@ -528,16 +513,16 @@ impl Session { Ok((sess, turn_context)) } - pub fn set_task(&self, task: AgentTask) { - let mut state = self.state.lock_unchecked(); + pub async fn set_task(&self, task: AgentTask) { + let mut state = self.state.lock().await; if let Some(current_task) = state.current_task.take() { current_task.abort(TurnAbortReason::Replaced); } state.current_task = Some(task); } - pub fn remove_task(&self, sub_id: &str) { - let mut state = self.state.lock_unchecked(); + pub async fn remove_task(&self, sub_id: &str) { + let mut state = self.state.lock().await; if let Some(task) = &state.current_task && task.sub_id == sub_id { @@ -546,9 +531,9 @@ impl Session { } fn next_internal_sub_id(&self) -> String { - let mut state = self.state.lock_unchecked(); - let id = state.next_internal_sub_id; - state.next_internal_sub_id += 1; + let id = self + .next_internal_sub_id + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); format!("auto-compact-{id}") } @@ -571,7 +556,7 @@ impl Session { let reconstructed_history = self.reconstruct_history_from_rollout(turn_context, &rollout_items); if !reconstructed_history.is_empty() { - self.record_into_history(&reconstructed_history); + self.record_into_history(&reconstructed_history).await; } // If persisting, persist all rollout items as-is (recorder filters) @@ -604,7 +589,7 @@ impl Session { let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); let prev_entry = { - let mut state = self.state.lock_unchecked(); + let mut state = self.state.lock().await; state.pending_approvals.insert(sub_id, tx_approve) }; if prev_entry.is_some() { @@ -636,7 +621,7 @@ impl Session { let (tx_approve, rx_approve) = oneshot::channel(); let event_id = sub_id.clone(); let prev_entry = { - let mut state = self.state.lock_unchecked(); + let mut state = self.state.lock().await; state.pending_approvals.insert(sub_id, tx_approve) }; if prev_entry.is_some() { @@ -656,9 +641,9 @@ impl Session { rx_approve } - pub fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) { + pub async fn notify_approval(&self, sub_id: &str, decision: ReviewDecision) { let entry = { - let mut state = self.state.lock_unchecked(); + let mut state = self.state.lock().await; state.pending_approvals.remove(sub_id) }; match entry { @@ -671,15 +656,15 @@ impl Session { } } - pub fn add_approved_command(&self, cmd: Vec) { - let mut state = self.state.lock_unchecked(); + pub async fn add_approved_command(&self, cmd: Vec) { + let mut state = self.state.lock().await; state.approved_commands.insert(cmd); } /// Records input items: always append to conversation history and /// persist these response items to rollout. async fn record_conversation_items(&self, items: &[ResponseItem]) { - self.record_into_history(items); + self.record_into_history(items).await; self.persist_rollout_response_items(items).await; } @@ -711,11 +696,9 @@ impl Session { } /// Append ResponseItems to the in-memory conversation history only. - fn record_into_history(&self, items: &[ResponseItem]) { - self.state - .lock_unchecked() - .history - .record_items(items.iter()); + async fn record_into_history(&self, items: &[ResponseItem]) { + let mut state = self.state.lock().await; + state.history.record_items(items.iter()); } async fn persist_rollout_response_items(&self, items: &[ResponseItem]) { @@ -743,8 +726,8 @@ impl Session { async fn persist_rollout_items(&self, items: &[RolloutItem]) { let recorder = { - let guard = self.rollout.lock_unchecked(); - guard.as_ref().cloned() + let guard = self.rollout.lock().await; + guard.clone() }; if let Some(rec) = recorder && let Err(e) = rec.record_items(items).await @@ -753,12 +736,12 @@ impl Session { } } - fn update_token_usage_info( + async fn update_token_usage_info( &self, turn_context: &TurnContext, token_usage: &Option, ) -> Option { - let mut state = self.state.lock_unchecked(); + let mut state = self.state.lock().await; let info = TokenUsageInfo::new_or_append( &state.token_info, token_usage, @@ -973,13 +956,17 @@ impl Session { /// Build the full turn input by concatenating the current conversation /// history with additional items for this turn. - pub fn turn_input_with_history(&self, extra: Vec) -> Vec { - [self.state.lock_unchecked().history.contents(), extra].concat() + pub async fn turn_input_with_history(&self, extra: Vec) -> Vec { + let history = { + let state = self.state.lock().await; + state.history.contents() + }; + [history, extra].concat() } /// Returns the input if there was no task running to inject into - pub fn inject_input(&self, input: Vec) -> Result<(), Vec> { - let mut state = self.state.lock_unchecked(); + pub async fn inject_input(&self, input: Vec) -> Result<(), Vec> { + let mut state = self.state.lock().await; if state.current_task.is_some() { state.pending_input.push(input.into()); Ok(()) @@ -988,8 +975,8 @@ impl Session { } } - pub fn get_pending_input(&self) -> Vec { - let mut state = self.state.lock_unchecked(); + pub async fn get_pending_input(&self) -> Vec { + let mut state = self.state.lock().await; if state.pending_input.is_empty() { Vec::with_capacity(0) } else { @@ -1011,9 +998,9 @@ impl Session { .await } - fn interrupt_task(&self) { + pub async fn interrupt_task(&self) { info!("interrupt received: abort current task, if any"); - let mut state = self.state.lock_unchecked(); + let mut state = self.state.lock().await; state.pending_approvals.clear(); state.pending_input.clear(); if let Some(task) = state.current_task.take() { @@ -1021,6 +1008,16 @@ impl Session { } } + fn interrupt_task_sync(&self) { + if let Ok(mut state) = self.state.try_lock() { + state.pending_approvals.clear(); + state.pending_input.clear(); + if let Some(task) = state.current_task.take() { + task.abort(TurnAbortReason::Interrupted); + } + } + } + /// Spawn the configured notifier (if any) with the given JSON payload as /// the last argument. Failures are logged but otherwise ignored so that /// notification issues do not interfere with the main workflow. @@ -1053,7 +1050,7 @@ impl Session { impl Drop for Session { fn drop(&mut self) { - self.interrupt_task(); + self.interrupt_task_sync(); } } @@ -1184,7 +1181,7 @@ async fn submission_loop( debug!(?sub, "Submission"); match sub.op { Op::Interrupt => { - sess.interrupt_task(); + sess.interrupt_task().await; } Op::OverrideTurnContext { cwd, @@ -1277,11 +1274,11 @@ async fn submission_loop( } Op::UserInput { items } => { // attempt to inject input into current task - if let Err(items) = sess.inject_input(items) { + if let Err(items) = sess.inject_input(items).await { // no current task, spawn a new one let task = AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); - sess.set_task(task); + sess.set_task(task).await; } } Op::UserTurn { @@ -1294,7 +1291,7 @@ async fn submission_loop( summary, } => { // attempt to inject input into current task - if let Err(items) = sess.inject_input(items) { + if let Err(items) = sess.inject_input(items).await { // Derive a fresh TurnContext for this turn using the provided overrides. let provider = turn_context.client.get_provider(); let auth_manager = turn_context.client.get_auth_manager(); @@ -1360,20 +1357,20 @@ async fn submission_loop( // no current task, spawn a new one with the per‑turn context let task = AgentTask::spawn(sess.clone(), Arc::clone(&turn_context), sub.id, items); - sess.set_task(task); + sess.set_task(task).await; } } Op::ExecApproval { id, decision } => match decision { ReviewDecision::Abort => { - sess.interrupt_task(); + sess.interrupt_task().await; } - other => sess.notify_approval(&id, other), + other => sess.notify_approval(&id, other).await, }, Op::PatchApproval { id, decision } => match decision { ReviewDecision::Abort => { - sess.interrupt_task(); + sess.interrupt_task().await; } - other => sess.notify_approval(&id, other), + other => sess.notify_approval(&id, other).await, }, Op::AddToHistory { text } => { let id = sess.conversation_id; @@ -1452,15 +1449,19 @@ async fn submission_loop( } Op::Compact => { // Attempt to inject input into current task - if let Err(items) = sess.inject_input(vec![InputItem::Text { - text: compact::COMPACT_TRIGGER_TEXT.to_string(), - }]) { + if let Err(items) = sess + .inject_input(vec![InputItem::Text { + text: compact::COMPACT_TRIGGER_TEXT.to_string(), + }]) + .await + { compact::spawn_compact_task( sess.clone(), Arc::clone(&turn_context), sub.id, items, - ); + ) + .await; } } Op::Shutdown => { @@ -1468,7 +1469,10 @@ async fn submission_loop( // Gracefully flush and shutdown rollout recorder on session end so tests // that inspect the rollout file do not race with the background writer. - let recorder_opt = sess.rollout.lock_unchecked().take(); + let recorder_opt = { + let mut guard = sess.rollout.lock().await; + guard.take() + }; if let Some(rec) = recorder_opt && let Err(e) = rec.shutdown().await { @@ -1493,7 +1497,7 @@ async fn submission_loop( let sub_id = sub.id.clone(); // Flush rollout writes before returning the path so readers observe a consistent file. let (path, rec_opt) = { - let guard = sess.rollout.lock_unchecked(); + let guard = sess.rollout.lock().await; match guard.as_ref() { Some(rec) => (rec.get_rollout_path(), Some(rec.clone())), None => { @@ -1604,7 +1608,7 @@ async fn spawn_review_thread( // Clone sub_id for the upcoming announcement before moving it into the task. let sub_id_for_event = sub_id.clone(); let task = AgentTask::review(sess.clone(), tc.clone(), sub_id, input); - sess.set_task(task); + sess.set_task(task).await; // Announce entering review mode so UIs can switch modes. sess.send_event(Event { @@ -1675,6 +1679,7 @@ async fn run_task( // may support this, the model might not. let pending_input = sess .get_pending_input() + .await .into_iter() .map(ResponseItem::from) .collect::>(); @@ -1696,7 +1701,7 @@ async fn run_task( review_thread_history.clone() } else { sess.record_conversation_items(&pending_input).await; - sess.turn_input_with_history(pending_input) + sess.turn_input_with_history(pending_input).await }; let turn_input_messages: Vec = turn_input @@ -1908,7 +1913,7 @@ async fn run_task( .await; } - sess.remove_task(&sub_id); + sess.remove_task(&sub_id).await; let event = Event { id: sub_id, msg: EventMsg::TaskComplete(TaskCompleteEvent { last_agent_message }), @@ -2141,7 +2146,9 @@ async fn try_run_turn( response_id: _, token_usage, } => { - let info = sess.update_token_usage_info(turn_context, &token_usage); + let info = sess + .update_token_usage_info(turn_context, &token_usage) + .await; let _ = sess .send_event(Event { id: sub_id.to_string(), @@ -2475,7 +2482,10 @@ async fn handle_function_call( } }; let abs = turn_context.resolve_path(Some(args.path)); - let output = match sess.inject_input(vec![InputItem::LocalImage { path: abs }]) { + let output = match sess + .inject_input(vec![InputItem::LocalImage { path: abs }]) + .await + { Ok(()) => FunctionCallOutputPayload { content: "attached local image path".to_string(), success: Some(true), @@ -2789,7 +2799,7 @@ async fn handle_container_exec_with_params( } None => { let safety = { - let state = sess.state.lock_unchecked(); + let state = sess.state.lock().await; assess_command_safety( ¶ms.command, turn_context.approval_policy, @@ -2818,7 +2828,7 @@ async fn handle_container_exec_with_params( match rx_approve.await.unwrap_or_default() { ReviewDecision::Approved => (), ReviewDecision::ApprovedForSession => { - sess.add_approved_command(params.command.clone()); + sess.add_approved_command(params.command.clone()).await; } ReviewDecision::Denied | ReviewDecision::Abort => { return ResponseInputItem::FunctionCallOutput { @@ -2991,7 +3001,7 @@ async fn handle_sandbox_error( // remainder of the session so future // executions skip the sandbox directly. // TODO(ragona): Isn't this a bug? It always saves the command in an | fork? - sess.add_approved_command(params.command.clone()); + sess.add_approved_command(params.command.clone()).await; // Inform UI we are retrying without sandbox. sess.notify_background_event(&sub_id, "retrying command without sandbox") .await; @@ -3356,7 +3366,7 @@ mod tests { }), )); - let actual = session.state.lock_unchecked().history.contents(); + let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); assert_eq!(expected, actual); } @@ -3369,7 +3379,7 @@ mod tests { session.record_initial_history(&turn_context, InitialHistory::Forked(rollout_items)), ); - let actual = session.state.lock_unchecked().history.contents(); + let actual = tokio_test::block_on(async { session.state.lock().await.history.contents() }); assert_eq!(expected, actual); } @@ -3611,6 +3621,7 @@ mod tests { codex_linux_sandbox_exe: None, user_shell: shell::Shell::Unknown, show_raw_agent_reasoning: config.show_raw_agent_reasoning, + next_internal_sub_id: AtomicU64::new(0), }; (session, turn_context) } diff --git a/codex-rs/core/src/codex/compact.rs b/codex-rs/core/src/codex/compact.rs index a465f937d4..c198df3b2f 100644 --- a/codex-rs/core/src/codex/compact.rs +++ b/codex-rs/core/src/codex/compact.rs @@ -1,7 +1,6 @@ use std::sync::Arc; use super::AgentTask; -use super::MutexExt; use super::Session; use super::TurnContext; use super::get_last_assistant_message_from_turn; @@ -37,7 +36,7 @@ struct HistoryBridgeTemplate<'a> { summary_text: &'a str, } -pub(super) fn spawn_compact_task( +pub(super) async fn spawn_compact_task( sess: Arc, turn_context: Arc, sub_id: String, @@ -50,7 +49,7 @@ pub(super) fn spawn_compact_task( input, SUMMARIZATION_PROMPT.to_string(), ); - sess.set_task(task); + sess.set_task(task).await; } pub(super) async fn run_inline_auto_compact_task( @@ -109,7 +108,9 @@ async fn run_compact_task_inner( let initial_input_for_turn: ResponseInputItem = ResponseInputItem::from(input); let instructions_override = compact_instructions; - let turn_input = sess.turn_input_with_history(vec![initial_input_for_turn.clone().into()]); + let turn_input = sess + .turn_input_with_history(vec![initial_input_for_turn.clone().into()]) + .await; let prompt = Prompt { input: turn_input, @@ -168,10 +169,10 @@ async fn run_compact_task_inner( } if remove_task_on_completion { - sess.remove_task(&sub_id); + sess.remove_task(&sub_id).await; } let history_snapshot = { - let state = sess.state.lock_unchecked(); + let state = sess.state.lock().await; state.history.contents() }; let summary_text = get_last_assistant_message_from_turn(&history_snapshot).unwrap_or_default(); @@ -179,7 +180,7 @@ async fn run_compact_task_inner( let initial_context = sess.build_initial_context(turn_context.as_ref()); let new_history = build_compacted_history(initial_context, &user_messages, &summary_text); { - let mut state = sess.state.lock_unchecked(); + let mut state = sess.state.lock().await; state.history.replace(new_history); } @@ -290,7 +291,7 @@ async fn drain_to_completed( }; match event { Ok(ResponseEvent::OutputItemDone(item)) => { - let mut state = sess.state.lock_unchecked(); + let mut state = sess.state.lock().await; state.history.record_items(std::slice::from_ref(&item)); } Ok(ResponseEvent::Completed { .. }) => {