Skip to content

Change cloud language model provider JSON protocol to surface errors and usage information #29830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 13 commits, May 4, 2025
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Previous commit
Next commit
Tweak naming and JSON structure of cloud completion status messages
  • Loading branch information
maxbrunsfeld committed May 3, 2025
commit 7012389f8f352e8ea3e37238834525ecbe38469a
2 changes: 1 addition & 1 deletion crates/agent/src/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ impl Thread {
language_model::CompletionRequestStatus::Started => {
QueueState::Started
}
language_model::CompletionRequestStatus::Error {
language_model::CompletionRequestStatus::Failed {
code, message
} => {
return Err(anyhow!("completion request failed. code: {code}, message: {message}"));
Expand Down
2 changes: 1 addition & 1 deletion crates/language_model/src/language_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ pub struct LanguageModelCacheConfiguration {
pub enum CompletionRequestStatus {
Queued { position: usize },
Started,
Error { code: String, message: String },
Failed { code: String, message: String },
}

/// A completion event from a language model.
Expand Down
27 changes: 15 additions & 12 deletions crates/language_models/src/provider/cloud.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ use crate::provider::open_ai::{OpenAiEventMapper, count_open_ai_tokens, into_ope

pub const PROVIDER_NAME: &str = "Zed";

const CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER: &str = "x-zed-client-supports-status-messages";
const SERVER_SUPPORTS_STATUS_MESSAGES_HEADER: &str = "x-zed-server-supports-queueing";

const ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON: Option<&str> =
option_env!("ZED_CLOUD_PROVIDER_ADDITIONAL_MODELS_JSON");

Expand Down Expand Up @@ -537,18 +540,18 @@ impl CloudLanguageModel {
let request = request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {token}"))
.header("x-zed-client-supports-queueing", "true")
.header(CLIENT_SUPPORTS_STATUS_MESSAGES_HEADER, "true")
.body(serde_json::to_string(&body)?.into())?;
let mut response = http_client.send(request).await?;
let status = response.status();
if status.is_success() {
let includes_queue_events = response
let includes_status_messages = response
.headers()
.get("x-zed-server-supports-queueing")
.get(SERVER_SUPPORTS_STATUS_MESSAGES_HEADER)
.is_some();
let usage = RequestUsage::from_headers(response.headers()).ok();

return Ok((response, usage, includes_queue_events));
return Ok((response, usage, includes_status_messages));
} else if response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
Expand Down Expand Up @@ -788,7 +791,7 @@ impl LanguageModel for CloudLanguageModel {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage, includes_queue_events) = Self::perform_llm_completion(
let (response, usage, includes_status_messages) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
Expand Down Expand Up @@ -820,7 +823,7 @@ impl LanguageModel for CloudLanguageModel {
let mut mapper = AnthropicEventMapper::new();
Ok((
map_cloud_completion_events(
Box::pin(response_lines(response, includes_queue_events)),
Box::pin(response_lines(response, includes_status_messages)),
move |event| mapper.map_event(event),
),
usage,
Expand All @@ -837,7 +840,7 @@ impl LanguageModel for CloudLanguageModel {
let request = into_open_ai(request, model, model.max_output_tokens());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage, includes_queue_events) = Self::perform_llm_completion(
let (response, usage, includes_status_messages) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
Expand All @@ -854,7 +857,7 @@ impl LanguageModel for CloudLanguageModel {
let mut mapper = OpenAiEventMapper::new();
Ok((
map_cloud_completion_events(
Box::pin(response_lines(response, includes_queue_events)),
Box::pin(response_lines(response, includes_status_messages)),
move |event| mapper.map_event(event),
),
usage,
Expand All @@ -871,7 +874,7 @@ impl LanguageModel for CloudLanguageModel {
let request = into_google(request, model.id().into());
let llm_api_token = self.llm_api_token.clone();
let future = self.request_limiter.stream_with_usage(async move {
let (response, usage, includes_queue_events) = Self::perform_llm_completion(
let (response, usage, includes_status_messages) = Self::perform_llm_completion(
client.clone(),
llm_api_token,
CompletionBody {
Expand All @@ -887,7 +890,7 @@ impl LanguageModel for CloudLanguageModel {
let mut mapper = GoogleEventMapper::new();
Ok((
map_cloud_completion_events(
Box::pin(response_lines(response, includes_queue_events)),
Box::pin(response_lines(response, includes_status_messages)),
move |event| mapper.map_event(event),
),
usage,
Expand All @@ -906,7 +909,7 @@ impl LanguageModel for CloudLanguageModel {
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CloudCompletionEvent<T> {
System(CompletionRequestStatus),
Status(CompletionRequestStatus),
Event(T),
}

Expand All @@ -926,7 +929,7 @@ where
Err(error) => {
vec![Err(LanguageModelCompletionError::Other(error))]
}
Ok(CloudCompletionEvent::System(event)) => {
Ok(CloudCompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::QueueUpdate(event))]
}
Ok(CloudCompletionEvent::Event(event)) => map_callback(event),
Expand Down