Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 35 additions & 9 deletions codex-rs/core/src/codex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ use tracing::error;
use tracing::info;
use tracing::trace;
use tracing::warn;
use uuid::Uuid;

use crate::ModelProviderInfo;
use crate::apply_patch;
Expand Down Expand Up @@ -104,6 +103,7 @@ use crate::protocol::TokenUsageInfo;
use crate::protocol::TurnDiffEvent;
use crate::protocol::WebSearchBeginEvent;
use crate::rollout::RolloutRecorder;
use crate::rollout::RolloutRecorderParams;
use crate::safety::SafetyCheck;
use crate::safety::assess_command_safety;
use crate::safety::assess_safety_for_untrusted_command;
Expand Down Expand Up @@ -362,7 +362,6 @@ impl Session {
tx_event: Sender<Event>,
initial_history: InitialHistory,
) -> anyhow::Result<(Arc<Self>, TurnContext)> {
let conversation_id = ConversationId::from(Uuid::new_v4());
let ConfigureSession {
provider,
model,
Expand All @@ -380,6 +379,20 @@ impl Session {
return Err(anyhow::anyhow!("cwd is not absolute: {cwd:?}"));
}

let (conversation_id, rollout_params) = match &initial_history {
InitialHistory::New | InitialHistory::Forked(_) => {
let conversation_id = ConversationId::default();
(
conversation_id,
RolloutRecorderParams::new(conversation_id, user_instructions.clone()),
)
}
InitialHistory::Resumed(resumed_history) => (
resumed_history.conversation_id,
RolloutRecorderParams::resume(resumed_history.rollout_path.clone()),
),
};

// Error messages to dispatch after SessionConfigured is sent.
let mut post_session_configured_error_events = Vec::<Event>::new();

Expand All @@ -389,7 +402,7 @@ impl Session {
// - spin up MCP connection manager
// - perform default shell discovery
// - load history metadata
let rollout_fut = RolloutRecorder::new(&config, conversation_id, user_instructions.clone());
let rollout_fut = RolloutRecorder::new(&config, rollout_params);

let mcp_fut = McpConnectionManager::new(config.mcp_servers.clone());
let default_shell_fut = shell::default_user_shell();
Expand Down Expand Up @@ -481,7 +494,10 @@ impl Session {
// If resuming, include converted initial messages in the payload so UIs can render them immediately.
let initial_messages = match &initial_history {
InitialHistory::New => None,
InitialHistory::Resumed(items) => Some(sess.build_initial_messages(items)),
InitialHistory::Forked(items) => Some(sess.build_initial_messages(items)),
InitialHistory::Resumed(resumed_history) => {
Some(sess.build_initial_messages(&resumed_history.history))
}
};

let events = std::iter::once(Event {
Expand Down Expand Up @@ -530,8 +546,12 @@ impl Session {
InitialHistory::New => {
self.record_initial_history_new(turn_context).await;
}
InitialHistory::Resumed(items) => {
self.record_initial_history_resumed(items).await;
InitialHistory::Forked(items) => {
self.record_initial_history_from_items(items).await;
}
InitialHistory::Resumed(resumed_history) => {
self.record_initial_history_from_items(resumed_history.history)
.await;
}
}
}
Expand All @@ -553,8 +573,8 @@ impl Session {
self.record_conversation_items(&conversation_items).await;
}

async fn record_initial_history_resumed(&self, items: Vec<ResponseItem>) {
self.record_conversation_items(&items).await;
async fn record_initial_history_from_items(&self, items: Vec<ResponseItem>) {
self.record_conversation_items_internal(&items, false).await;
}

/// build the initial messages vector for SessionConfigured by converting
Expand Down Expand Up @@ -663,8 +683,14 @@ impl Session {
/// Records items to both the rollout and the chat completions/ZDR
/// transcript, if enabled.
async fn record_conversation_items(&self, items: &[ResponseItem]) {
self.record_conversation_items_internal(items, true).await;
}

async fn record_conversation_items_internal(&self, items: &[ResponseItem], persist: bool) {
debug!("Recording items for conversation: {items:?}");
self.record_state_snapshot(items).await;
if persist {
self.record_state_snapshot(items).await;
}

self.state.lock_unchecked().history.record_items(items);
}
Expand Down
30 changes: 18 additions & 12 deletions codex-rs/core/src/conversation_manager.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,5 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;

use crate::AuthManager;
use crate::CodexAuth;
use codex_protocol::mcp_protocol::ConversationId;
use tokio::sync::RwLock;

use crate::codex::Codex;
use crate::codex::CodexSpawnOk;
use crate::codex::INITIAL_SUBMIT_ID;
Expand All @@ -18,12 +11,25 @@ use crate::protocol::Event;
use crate::protocol::EventMsg;
use crate::protocol::SessionConfiguredEvent;
use crate::rollout::RolloutRecorder;
use codex_protocol::mcp_protocol::ConversationId;
use codex_protocol::models::ResponseItem;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;

#[derive(Debug, Clone, PartialEq)]
pub struct ResumedHistory {
pub conversation_id: ConversationId,
pub history: Vec<ResponseItem>,
pub rollout_path: PathBuf,
}

#[derive(Debug, Clone, PartialEq)]
pub enum InitialHistory {
New,
Resumed(Vec<ResponseItem>),
Resumed(ResumedHistory),
Forked(Vec<ResponseItem>),
}

/// Represents a newly created Codex conversation, including the first event
Expand Down Expand Up @@ -77,7 +83,7 @@ impl ConversationManager {
let CodexSpawnOk {
codex,
conversation_id,
} = { Codex::spawn(config, auth_manager, InitialHistory::New).await? };
} = Codex::spawn(config, auth_manager, InitialHistory::New).await?;
self.finalize_spawn(codex, conversation_id).await
}
}
Expand Down Expand Up @@ -172,7 +178,7 @@ impl ConversationManager {
/// and all items that follow them.
fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) -> InitialHistory {
if n == 0 {
return InitialHistory::Resumed(items);
return InitialHistory::Forked(items);
}

// Walk backwards counting only `user` Message items, find cut index.
Expand All @@ -194,7 +200,7 @@ fn truncate_after_dropping_last_messages(items: Vec<ResponseItem>, n: usize) ->
// No prefix remains after dropping; start a new conversation.
InitialHistory::New
} else {
InitialHistory::Resumed(items.into_iter().take(cut_index).collect())
InitialHistory::Forked(items.into_iter().take(cut_index).collect())
}
}

Expand Down Expand Up @@ -252,7 +258,7 @@ mod tests {
let truncated = truncate_after_dropping_last_messages(items.clone(), 1);
assert_eq!(
truncated,
InitialHistory::Resumed(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
InitialHistory::Forked(vec![items[0].clone(), items[1].clone(), items[2].clone(),])
);

let truncated2 = truncate_after_dropping_last_messages(items, 2);
Expand Down
1 change: 1 addition & 0 deletions codex-rs/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub mod terminal;
mod tool_apply_patch;
pub mod turn_diff_tracker;
pub use rollout::RolloutRecorder;
pub use rollout::SessionMeta;
pub use rollout::list::ConversationItem;
pub use rollout::list::ConversationsPage;
pub use rollout::list::Cursor;
Expand Down
2 changes: 1 addition & 1 deletion codex-rs/core/src/message_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ use std::path::PathBuf;
use serde::Deserialize;
use serde::Serialize;

use codex_protocol::mcp_protocol::ConversationId;
use std::time::Duration;
use tokio::fs;
use tokio::io::AsyncReadExt;

use crate::config::Config;
use crate::config_types::HistoryPersistence;

use codex_protocol::mcp_protocol::ConversationId;
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
#[cfg(unix)]
Expand Down
2 changes: 2 additions & 0 deletions codex-rs/core/src/rollout/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pub(crate) mod policy;
pub mod recorder;

pub use recorder::RolloutRecorder;
pub use recorder::RolloutRecorderParams;
pub use recorder::SessionMeta;
pub use recorder::SessionStateSnapshot;

#[cfg(test)]
Expand Down
Loading
Loading