Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 15 additions & 29 deletions codex-rs/core/src/exec_command/exec_command_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ pub(crate) struct ExecCommandSession {
/// Broadcast stream of output chunks read from the PTY. New subscribers
/// receive only chunks emitted after they subscribe.
output_tx: broadcast::Sender<Vec<u8>>,
/// Receiver subscribed before the child process starts emitting output so
/// the first caller can consume any early data without races.
initial_output_rx: StdMutex<Option<broadcast::Receiver<Vec<u8>>>>,

/// Child killer handle for termination on drop (can signal independently
/// of a thread blocked in `.wait()`).
Expand Down Expand Up @@ -41,39 +38,28 @@ impl ExecCommandSession {
writer_handle: JoinHandle<()>,
wait_handle: JoinHandle<()>,
exit_status: std::sync::Arc<std::sync::atomic::AtomicBool>,
) -> Self {
Self {
writer_tx,
output_tx,
initial_output_rx: StdMutex::new(None),
killer: StdMutex::new(Some(killer)),
reader_handle: StdMutex::new(Some(reader_handle)),
writer_handle: StdMutex::new(Some(writer_handle)),
wait_handle: StdMutex::new(Some(wait_handle)),
exit_status,
}
}

pub(crate) fn set_initial_output_receiver(&self, receiver: broadcast::Receiver<Vec<u8>>) {
if let Ok(mut guard) = self.initial_output_rx.lock()
&& guard.is_none()
{
*guard = Some(receiver);
}
) -> (Self, broadcast::Receiver<Vec<u8>>) {
let initial_output_rx = output_tx.subscribe();
(
Self {
writer_tx,
output_tx,
killer: StdMutex::new(Some(killer)),
reader_handle: StdMutex::new(Some(reader_handle)),
writer_handle: StdMutex::new(Some(writer_handle)),
wait_handle: StdMutex::new(Some(wait_handle)),
exit_status,
},
initial_output_rx,
)
}

pub(crate) fn writer_sender(&self) -> mpsc::Sender<Vec<u8>> {
self.writer_tx.clone()
}

pub(crate) fn output_receiver(&self) -> broadcast::Receiver<Vec<u8>> {
if let Ok(mut guard) = self.initial_output_rx.lock()
&& let Some(receiver) = guard.take()
{
receiver
} else {
self.output_tx.subscribe()
}
self.output_tx.subscribe()
}

pub(crate) fn has_exited(&self) -> bool {
Expand Down
31 changes: 15 additions & 16 deletions codex-rs/core/src/exec_command/session_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,16 @@ impl SessionManager {
.fetch_add(1, std::sync::atomic::Ordering::SeqCst),
);

let (session, mut exit_rx) =
create_exec_command_session(params.clone())
.await
.map_err(|err| {
format!(
"failed to create exec command session for session id {}: {err}",
session_id.0
)
})?;
let (session, mut output_rx, mut exit_rx) = create_exec_command_session(params.clone())
.await
.map_err(|err| {
format!(
"failed to create exec command session for session id {}: {err}",
session_id.0
)
})?;

// Insert into session map.
let mut output_rx = session.output_receiver();
self.sessions.lock().await.insert(session_id, session);

// Collect output until either timeout expires or process exits.
Expand Down Expand Up @@ -245,7 +243,11 @@ impl SessionManager {
/// Spawn PTY and child process per spawn_exec_command_session logic.
async fn create_exec_command_session(
params: ExecCommandParams,
) -> anyhow::Result<(ExecCommandSession, oneshot::Receiver<i32>)> {
) -> anyhow::Result<(
ExecCommandSession,
tokio::sync::broadcast::Receiver<Vec<u8>>,
oneshot::Receiver<i32>,
)> {
let ExecCommandParams {
cmd,
yield_time_ms: _,
Expand Down Expand Up @@ -279,8 +281,6 @@ async fn create_exec_command_session(
let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
// Broadcast for streaming PTY output to readers: subscribers receive from subscription time.
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
let initial_output_rx = output_tx.subscribe();

// Reader task: drain PTY and forward chunks to output channel.
let mut reader = pair.master.try_clone_reader()?;
let output_tx_clone = output_tx.clone();
Expand Down Expand Up @@ -342,7 +342,7 @@ async fn create_exec_command_session(
});

// Create and store the session with channels.
let session = ExecCommandSession::new(
let (session, initial_output_rx) = ExecCommandSession::new(
writer_tx,
output_tx,
killer,
Expand All @@ -351,8 +351,7 @@ async fn create_exec_command_session(
wait_handle,
exit_status,
);
session.set_initial_output_receiver(initial_output_rx);
Ok((session, exit_rx))
Ok((session, initial_output_rx, exit_rx))
}

#[cfg(test)]
Expand Down
26 changes: 16 additions & 10 deletions codex-rs/core/src/unified_exec/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,13 @@ type OutputBuffer = Arc<Mutex<OutputBufferState>>;
type OutputHandles = (OutputBuffer, Arc<Notify>);

impl ManagedUnifiedExecSession {
fn new(session: ExecCommandSession) -> Self {
fn new(
session: ExecCommandSession,
initial_output_rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
) -> Self {
let output_buffer = Arc::new(Mutex::new(OutputBufferState::default()));
let output_notify = Arc::new(Notify::new());
let mut receiver = session.output_receiver();
let mut receiver = initial_output_rx;
let buffer_clone = Arc::clone(&output_buffer);
let notify_clone = Arc::clone(&output_notify);
let output_task = tokio::spawn(async move {
Expand Down Expand Up @@ -193,8 +196,8 @@ impl UnifiedExecSessionManager {
} else {
let command = request.input_chunks.to_vec();
let new_id = self.next_session_id.fetch_add(1, Ordering::SeqCst);
let session = create_unified_exec_session(&command).await?;
let managed_session = ManagedUnifiedExecSession::new(session);
let (session, initial_output_rx) = create_unified_exec_session(&command).await?;
let managed_session = ManagedUnifiedExecSession::new(session, initial_output_rx);
let (buffer, notify) = managed_session.output_handles();
writer_tx = managed_session.writer_sender();
output_buffer = buffer;
Expand Down Expand Up @@ -297,7 +300,13 @@ impl UnifiedExecSessionManager {

async fn create_unified_exec_session(
command: &[String],
) -> Result<ExecCommandSession, UnifiedExecError> {
) -> Result<
(
ExecCommandSession,
tokio::sync::broadcast::Receiver<Vec<u8>>,
),
UnifiedExecError,
> {
if command.is_empty() {
return Err(UnifiedExecError::MissingCommandLine);
}
Expand Down Expand Up @@ -327,7 +336,6 @@ async fn create_unified_exec_session(

let (writer_tx, mut writer_rx) = mpsc::channel::<Vec<u8>>(128);
let (output_tx, _) = tokio::sync::broadcast::channel::<Vec<u8>>(256);
let initial_output_rx = output_tx.subscribe();

let mut reader = pair
.master
Expand Down Expand Up @@ -381,7 +389,7 @@ async fn create_unified_exec_session(
wait_exit_status.store(true, Ordering::SeqCst);
});

let session = ExecCommandSession::new(
let (session, initial_output_rx) = ExecCommandSession::new(
writer_tx,
output_tx,
killer,
Expand All @@ -390,9 +398,7 @@ async fn create_unified_exec_session(
wait_handle,
exit_status,
);
session.set_initial_output_receiver(initial_output_rx);

Ok(session)
Ok((session, initial_output_rx))
}

#[cfg(test)]
Expand Down
Loading