Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ install-integration-tests:
cd clients/python && pip install .

install-router:
cd router && cargo install --path .
cd router && RUSTFLAGS="-D warnings" cargo install --path .

install-launcher:
cd launcher && cargo install --path .
Expand Down
1 change: 0 additions & 1 deletion router/client/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use base64::{engine::general_purpose::STANDARD, Engine};
use grpc_metadata::InjectTelemetryContext;
use std::cmp::min;
use tonic::transport::{Channel, Uri};
use tonic::Response;
use tracing::instrument;

use self::input_chunk::Chunk;
Expand Down
7 changes: 1 addition & 6 deletions router/client/src/sharded_client.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
use crate::pb::generate::v1::{
ClassifyPredictionList, EmbedResponse, Embedding, Entity, EntityList,
};
use crate::pb::generate::v1::{ClassifyPredictionList, Embedding};
/// Multi shard Client
use crate::{
AdapterParameters, Batch, CachedBatch, Client, DownloadAdapterResponse, Generation,
HealthResponse, ShardInfo,
};
use crate::{ClientError, Result};
use futures::future::join_all;
use regex::Regex;
use std::sync::Arc;
use tokio::task;
use tonic::transport::Uri;
use tracing::instrument;

Expand Down
1 change: 1 addition & 0 deletions router/grpc-metadata/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use opentelemetry::propagation::{Extractor, Injector};
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Extract context metadata from a gRPC request's metadata
#[allow(dead_code)]
struct MetadataExtractor<'a>(pub &'a tonic::metadata::MetadataMap);

impl<'a> Extractor for MetadataExtractor<'a> {
Expand Down
20 changes: 10 additions & 10 deletions router/src/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@ use lorax_client::{
StoppingCriteriaParameters, TokenizedInputs,
};
use nohash_hasher::{BuildNoHashHasher, IntMap};
use tokenizers::Token;
use tokio::time::Instant;
use tracing::{info_span, span, Instrument, Span};
use tracing::{Instrument, Span};

use crate::{
adapter::Adapter,
Expand Down Expand Up @@ -267,6 +266,7 @@ pub(crate) trait BatchEntries: Sync + Send + Debug {
fn adapters_in_use(&self) -> HashSet<Adapter>;
fn is_empty(&self) -> bool;
fn len(&self) -> usize;
#[allow(dead_code)]
fn state(&self) -> &BatchEntriesState;
fn mut_state(&mut self) -> &mut BatchEntriesState;

Expand Down Expand Up @@ -529,10 +529,10 @@ impl BatchEntries for EmbedBatchEntries {

async fn process_next(
&mut self,
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
_client: &mut ShardedClient,
_batches: Vec<CachedBatch>,
_span: Span,
_generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
// TODO(travis): send error (programming error) if we get here
None
Expand Down Expand Up @@ -652,10 +652,10 @@ impl BatchEntries for ClassifyBatchEntries {

async fn process_next(
&mut self,
client: &mut ShardedClient,
batches: Vec<CachedBatch>,
span: Span,
generation_health: &Arc<AtomicBool>,
_client: &mut ShardedClient,
_batches: Vec<CachedBatch>,
_span: Span,
_generation_health: &Arc<AtomicBool>,
) -> Option<CachedBatch> {
// TODO(magdy): send error (programming error) if we get here
None
Expand Down
4 changes: 4 additions & 0 deletions router/src/health.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

// Note: Request ids and batch ids cannot collide.
#[allow(dead_code)]
const LIVENESS_ID: u64 = u64::MAX;
const BATCH_ID: u64 = u64::MAX;

Expand All @@ -23,7 +24,9 @@ impl Health {
shard_info: ShardInfo,
) -> Self {
Self {
#[allow(dead_code)]
client,
#[allow(dead_code)]
generation_health,
shard_info,
}
Expand All @@ -33,6 +36,7 @@ impl Health {
&self.shard_info
}

#[allow(dead_code)]
pub(crate) async fn check(&mut self) -> bool {
if self.generation_health.load(Ordering::SeqCst) {
// Generation is healthy, we only check that the shards are answering gRPC calls
Expand Down
37 changes: 23 additions & 14 deletions router/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::queue::AdapterEvent;
use crate::scheduler::AdapterScheduler;
use crate::validation::{Validation, ValidationError};
use crate::{
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchClassifyResponse, ChatTemplate,
AdapterParameters, AlternativeToken, BatchClassifyRequest, BatchClassifyResponse,
ChatTemplateVersions, ClassifyRequest, ClassifyResponse, EmbedRequest, EmbedResponse, Entity,
Entry, HubTokenizerConfig, Message, TextMessage, Token, TokenizerConfigToken,
};
Expand All @@ -16,8 +16,8 @@ use futures::future::try_join_all;
use futures::stream::StreamExt;
use itertools::multizip;
use lorax_client::{
Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, EntityList, GeneratedText,
Generation, PrefillTokens, PreloadedAdapter, ShardedClient,
Batch, CachedBatch, ClassifyPredictionList, ClientError, Embedding, GeneratedText, Generation,
PrefillTokens, PreloadedAdapter, ShardedClient,
};
use minijinja::{Environment, ErrorKind, Template};
use minijinja_contrib::pycompat;
Expand Down Expand Up @@ -56,6 +56,7 @@ struct ChatTemplateRenderer {
template: Template<'static, 'static>,
bos_token: Option<String>,
eos_token: Option<String>,
#[allow(dead_code)] // For now allow this field even though it is unused
use_default_tool_template: bool,
}

Expand Down Expand Up @@ -92,7 +93,7 @@ impl ChatTemplateRenderer {

fn apply(
&self,
mut messages: Vec<Message>,
messages: Vec<Message>,
// grammar_with_prompt: Option<(GrammarType, String)>,
) -> Result<String, InferError> {
// TODO(travis): revisit when we add tool usage
Expand Down Expand Up @@ -452,7 +453,7 @@ impl Infer {
#[instrument(skip(self))]
pub(crate) async fn embed(&self, request: EmbedRequest) -> Result<EmbedResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
let _permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
Expand Down Expand Up @@ -559,8 +560,8 @@ impl Infer {
}
InferStreamResponse::Embed {
embedding,
start,
queued,
start: _,
queued: _,
} => {
return_embeddings = Some(embedding.values);
}
Expand All @@ -585,7 +586,7 @@ impl Infer {
request: ClassifyRequest,
) -> Result<ClassifyResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
let _permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
Expand Down Expand Up @@ -660,9 +661,9 @@ impl Infer {
}
InferStreamResponse::Classify {
predictions,
start,
queued,
id,
start: _,
queued: _,
id: _,
} => {
let entities = format_ner_output(predictions, self.tokenizer.clone().unwrap());
return_entities = Some(entities);
Expand All @@ -688,7 +689,7 @@ impl Infer {
request: BatchClassifyRequest,
) -> Result<BatchClassifyResponse, InferError> {
// Limit concurrent requests by acquiring a permit from the semaphore
let permit = self
let _permit = self
.clone()
.limit_concurrent_requests
.try_acquire_owned()
Expand Down Expand Up @@ -752,8 +753,8 @@ impl Infer {
// Add prefill tokens
InferStreamResponse::Classify {
predictions,
start,
queued,
start: _,
queued: _,
id,
} => {
let entities =
Expand Down Expand Up @@ -1409,12 +1410,20 @@ pub(crate) enum InferStreamResponse {
// Embeddings
Embed {
embedding: Embedding,
// For now allow this field even though it is unused.
// TODO(magdy): enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
},
Classify {
predictions: ClassifyPredictionList,
// For now allow this field even though it is unused.
// TODO(magdy): enable tracing for these requests
#[allow(dead_code)]
start: Instant,
#[allow(dead_code)]
queued: Instant,
id: Option<u64>, // to support batching
},
Expand Down
12 changes: 12 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ pub(crate) struct GenerateParameters {
pub return_k_alternatives: Option<i32>,
#[serde(default)]
#[schema(default = "false")]
#[allow(dead_code)] // For now allow this field even though it is unused
pub apply_chat_template: bool,
#[serde(default)]
#[schema(
Expand Down Expand Up @@ -483,6 +484,7 @@ enum ResponseFormatType {

#[derive(Clone, Debug, Deserialize, ToSchema)]
struct ResponseFormat {
#[allow(dead_code)] // For now allow this field even though it is unused
r#type: ResponseFormatType,
schema: serde_json::Value, // TODO: make this optional once arbitrary JSON object is supported in Outlines
}
Expand Down Expand Up @@ -571,9 +573,13 @@ struct ChatCompletionRequest {
#[serde(default)]
stop: Vec<String>,
stream: Option<bool>,
#[allow(dead_code)] // For now allow this field even though it is unused
presence_penalty: Option<f32>,
#[allow(dead_code)] // For now allow this field even though it is unused
frequency_penalty: Option<f32>,
#[allow(dead_code)] // For now allow this field even though it is unused
logit_bias: Option<std::collections::HashMap<String, f32>>,
#[allow(dead_code)] // For now allow this field even though it is unused
user: Option<String>,
seed: Option<u64>,
// Additional parameters
Expand All @@ -590,6 +596,7 @@ struct ChatCompletionRequest {
struct CompletionRequest {
model: String,
prompt: String,
#[allow(dead_code)] // For now allow this field even though it is unused
suffix: Option<String>,
max_tokens: Option<i32>,
temperature: Option<f32>,
Expand All @@ -600,10 +607,14 @@ struct CompletionRequest {
echo: Option<bool>,
#[serde(default)]
stop: Vec<String>,
#[allow(dead_code)] // For now allow this field even though it is unused
presence_penalty: Option<f32>,
#[allow(dead_code)] // For now allow this field even though it is unused
frequency_penalty: Option<f32>,
best_of: Option<i32>,
#[allow(dead_code)] // For now allow this field even though it is unused
logit_bias: Option<std::collections::HashMap<String, f32>>,
#[allow(dead_code)] // For now allow this field even though it is unused
user: Option<String>,
seed: Option<u64>,
// Additional parameters
Expand Down Expand Up @@ -712,6 +723,7 @@ pub(crate) enum CompletionFinishReason {
#[schema(rename = "content_filter")]
ContentFilter,
#[schema(rename = "tool_calls")]
#[allow(dead_code)] // For now allow this field even though it is unused
ToolCalls,
}

Expand Down
21 changes: 15 additions & 6 deletions router/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ impl AdapterLoader {
// response_receiver.await.unwrap()
}

#[allow(dead_code)] // currently unused
pub(crate) async fn is_errored(&self, adapter: Adapter) -> bool {
// Create response channel
let (response_sender, response_receiver) = oneshot::channel();
Expand Down Expand Up @@ -122,7 +123,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
AdapterLoaderCommand::DownloadAdapter {
adapter,
queues_state,
response_sender,
response_sender: _,
span: _, // TODO(geoffrey): not sure how to use 'span' with async fn
} => {
if err_msgs.contains_key(&adapter) {
Expand All @@ -133,7 +134,6 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
// time of request and the time of adapter download
locked_state.set_status(&adapter, AdapterStatus::Errored);
}
// response_sender.send(()).unwrap();
continue;
}

Expand Down Expand Up @@ -172,7 +172,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
AdapterLoaderCommand::LoadAdapter {
adapter,
queues_state,
response_sender,
response_sender: _,
span: _, // TODO(geoffrey): not sure how to use 'span' with async fn
} => {
if err_msgs.contains_key(&adapter) {
Expand Down Expand Up @@ -217,7 +217,7 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
AdapterLoaderCommand::OffloadAdapter {
adapter,
queues_state,
response_sender,
response_sender: _,
span: _, // TODO(geoffrey): not sure how to use 'span' with async fn
} => {
if err_msgs.contains_key(&adapter) {
Expand Down Expand Up @@ -270,8 +270,8 @@ async fn loader_task(mut client: ShardedClient, receiver: flume::Receiver<Adapte
AdapterLoaderCommand::Terminate {
adapter,
queues_state,
response_sender,
span,
response_sender: _,
span: _,
} => {
tracing::info!("terminating adapter {} loader", adapter.as_string());

Expand Down Expand Up @@ -302,21 +302,28 @@ enum AdapterLoaderCommand {
DownloadAdapter {
adapter: Adapter,
queues_state: Arc<Mutex<AdapterQueuesState>>,
#[allow(dead_code)] // currently unused
response_sender: oneshot::Sender<()>,
#[allow(dead_code)] // currently unused
span: Span,
},
LoadAdapter {
adapter: Adapter,
queues_state: Arc<Mutex<AdapterQueuesState>>,
#[allow(dead_code)] // currently unused
response_sender: oneshot::Sender<()>,
#[allow(dead_code)] // currently unused
span: Span,
},
OffloadAdapter {
adapter: Adapter,
queues_state: Arc<Mutex<AdapterQueuesState>>,
#[allow(dead_code)] // currently unused
response_sender: oneshot::Sender<()>,
#[allow(dead_code)] // currently unused
span: Span,
},
#[allow(dead_code)]
IsErrored {
adapter: Adapter,
response_sender: oneshot::Sender<bool>,
Expand All @@ -325,7 +332,9 @@ enum AdapterLoaderCommand {
Terminate {
adapter: Adapter,
queues_state: Arc<Mutex<AdapterQueuesState>>,
#[allow(dead_code)] // currently unused
response_sender: oneshot::Sender<()>,
#[allow(dead_code)] // currently unused
span: Span,
},
}
Expand Down
1 change: 0 additions & 1 deletion router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ use std::fs::File;
use std::io::BufReader;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use thiserror::Error;
use tokenizers::processors::template::TemplateProcessing;
use tokenizers::tokenizer::Tokenizer;
Expand Down
Loading