diff --git a/Cargo.lock b/Cargo.lock
index 0c3eb81..c245bc4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1227,6 +1227,7 @@ dependencies = [
  "half",
  "hf-hub",
  "ndarray",
+ "rayon",
  "safetensors",
  "serde_json",
  "tokenizers",
diff --git a/Cargo.toml b/Cargo.toml
index 6fe1910..cd535a6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,6 @@ version = "0.1.0"
 edition = "2021"
 description = "Fast State-of-the-Art Static Embeddings in Rust"
 readme = "README.md"
-license = "MIT"
 license-file = "LICENSE"
 authors = ["Thomas van Dongen ", "Stéphan Tulkens "]
 homepage = "https://github.com/MinishLab/model2vec-rs"
@@ -21,6 +20,7 @@ clap = { version = "4.0", features = ["derive"] }
 anyhow = "1.0"
 serde_json = "1.0"
 half = "2.0"
+rayon = "1.7"
 
 [dev-dependencies]
 approx = "0.5"
diff --git a/README.md b/README.md
index f92e21a..dba00bf 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ use model2vec_rs::model::StaticModel;
 
 fn main() -> Result<()> {
     // Load a model from the Hugging Face Hub or a local path
-    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None)?;
+    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;
 
     // Prepare a list of sentences
     let texts = vec![
diff --git a/src/main.rs b/src/main.rs
index f4272e5..9b310e0 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -46,7 +46,7 @@ fn main() -> Result<()> {
             vec![input]
         };
 
-        let m = StaticModel::from_pretrained(&model, None)?;
+        let m = StaticModel::from_pretrained(&model, None, None, None)?;
         let embs = m.encode(&texts);
 
         if let Some(path) = output {
@@ -67,8 +67,9 @@ fn main() -> Result<()> {
             vec![input]
         };
 
-        let m = StaticModel::from_pretrained(&model, None)?;
-        let ids = m.tokenize(&texts);
+        let m = StaticModel::from_pretrained(&model, None, None, None)?;
+        // Pass None for max_length so no truncation is applied
+        let ids = m.tokenize(&texts, None);
         println!("Token ID sequences: {:#?}", ids);
     }
 }
diff --git a/src/model.rs b/src/model.rs
index fee5656..174cae6 100644
--- a/src/model.rs
+++ b/src/model.rs
@@ -3,8 +3,8 @@
 use tokenizers::Tokenizer;
 use safetensors::{SafeTensors, tensor::Dtype};
 use half::f16;
 use ndarray::Array2;
-use std::fs::read;
-use std::path::Path;
+use rayon::prelude::*;
+use std::{fs::read, path::Path, env};
 use anyhow::{Result, Context, anyhow};
 use serde_json::Value;
@@ -13,104 +13,168 @@
 pub struct StaticModel {
     tokenizer: Tokenizer,
     embeddings: Array2<f32>,
     normalize: bool,
+    median_token_length: usize,
 }
 
 impl StaticModel {
     /// Load a Model2Vec model from a local folder or the HF Hub.
-    ///
-    /// # Arguments
-    /// * `repo_or_path` - HF repo ID or local filesystem path
-    /// * `subfolder` - optional subdirectory inside the repo or folder
-    pub fn from_pretrained(repo_or_path: &str, subfolder: Option<&str>) -> Result<Self> {
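+    ///
+    /// # Arguments
+    /// * `repo_or_path` - HF repo ID or local filesystem path
+    /// * `token` - optional HF token, exported as `HF_HUB_TOKEN` for authenticated downloads
+    /// * `normalize` - optional override for the `normalize` flag in config.json
+    /// * `subfolder` - optional subdirectory inside the repo or folder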
+    pub fn from_pretrained(
+        repo_or_path: &str,
+        token: Option<&str>,
+        normalize: Option<bool>,
+        subfolder: Option<&str>,
+    ) -> Result<Self> {
+        // If provided, set HF token for authenticated downloads
+        if let Some(tok) = token {
+            env::set_var("HF_HUB_TOKEN", tok);
+        }
+        // Determine file paths
         let (tok_path, mdl_path, cfg_path) = {
             let base = Path::new(repo_or_path);
             if base.exists() {
-                // Local path
                 let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
                 let t = folder.join("tokenizer.json");
                 let m = folder.join("model.safetensors");
                 let c = folder.join("config.json");
                 if !t.exists() || !m.exists() || !c.exists() {
-                    return Err(anyhow!("Local path {:?} missing tokenizer/model/config files", folder));
+                    return Err(anyhow!("Local path {:?} missing files", folder));
                 }
                 (t, m, c)
             } else {
-                // HF Hub path
-                let api = Api::new().context("Failed to initialize HF Hub API")?;
+                let api = Api::new().context("HF Hub API init failed")?;
                 let repo = api.model(repo_or_path.to_string());
+                // note: token not used with sync Api
                 let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
-                let t = repo.get(&format!("{}tokenizer.json", prefix)).context("Failed to download tokenizer.json")?;
-                let m = repo.get(&format!("{}model.safetensors", prefix)).context("Failed to download model.safetensors")?;
-                let c = repo.get(&format!("{}config.json", prefix)).context("Failed to download config.json")?;
+                let t = repo.get(&format!("{}tokenizer.json", prefix))
+                    .context("Download tokenizer.json failed")?;
+                let m = repo.get(&format!("{}model.safetensors", prefix))
+                    .context("Download model.safetensors failed")?;
+                let c = repo.get(&format!("{}config.json", prefix))
+                    .context("Download config.json failed")?;
                 (t.into(), m.into(), c.into())
             }
         };
 
         // Load tokenizer
         let tokenizer = Tokenizer::from_file(&tok_path)
-            .map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
+            .map_err(|e| anyhow!("Tokenizer load error: {}", e))?;
+
+        // Median token length for char-level truncation
+        let mut lengths: Vec<usize> = tokenizer.get_vocab(false)
+            .keys().map(|tk| tk.len()).collect();
+        lengths.sort_unstable();
+        let median_token_length = *lengths.get(lengths.len() / 2).unwrap_or(&1);
+
+        // Read config.json for default normalize
+        let cfg: Value = serde_json::from_slice(&read(&cfg_path)?)
+ .context("Parse config.json failed")?; + let config_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true); + let normalize = normalize.unwrap_or(config_norm); - // Read safetensors file - let bytes = read(&mdl_path).context("Failed to read model.safetensors")?; - let safet = SafeTensors::deserialize(&bytes).context("Failed to parse safetensors")?; - let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0")).context("Embedding tensor not found")?; + // Read safetensors + let bytes = read(&mdl_path).context("Read safetensors failed")?; + let safet = SafeTensors::deserialize(&bytes).context("Parse safetensors failed")?; + let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0")) + .context("No 'embeddings' tensor")?; let shape = (tensor.shape()[0] as usize, tensor.shape()[1] as usize); let raw = tensor.data(); let dtype = tensor.dtype(); - // Read config.json for normalization flag - let cfg_bytes = read(&cfg_path).context("Failed to read config.json")?; - let cfg: Value = serde_json::from_slice(&cfg_bytes).context("Failed to parse config.json")?; - let normalize = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true); - - // Decode raw bytes into Vec based on dtype + // Decode raw data to f32 let floats: Vec = match dtype { Dtype::F32 => raw.chunks_exact(4) - .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) - .collect(), + .map(|b| f32::from_le_bytes([b[0],b[1],b[2],b[3]])).collect(), Dtype::F16 => raw.chunks_exact(2) - .map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32()) - .collect(), - Dtype::I8 => raw.iter() - .map(|&b| (b as i8) as f32) - .collect(), - other => return Err(anyhow!("Unsupported tensor dtype: {:?}", other)), + .map(|b| f16::from_le_bytes([b[0],b[1]]).to_f32()).collect(), + Dtype::I8 => raw.iter().map(|&b| b as i8 as f32).collect(), + other => return Err(anyhow!("Unsupported dtype: {:?}", other)), }; + let embeddings = Array2::from_shape_vec(shape, floats) + .context("Array shape error")?; - // Construct ndarray - let embeddings = Array2::from_shape_vec(shape, floats).context("Failed to create embeddings array")?; - - Ok(Self { tokenizer, embeddings, normalize }) + Ok(Self { tokenizer, embeddings, normalize, median_token_length }) } - /// Tokenize input texts into token ID sequences - pub fn tokenize(&self, texts: &[String]) -> Vec> { - texts.iter().map(|text| { - let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed"); - enc.get_ids().to_vec() + /// Tokenize input texts into token ID sequences with optional truncation. + pub fn tokenize(&self, texts: &[String], max_length: Option) -> Vec> { + let prepared: Vec = texts.iter().map(|t| { + if let Some(max) = max_length { + t.chars().take(max.saturating_mul(self.median_token_length)).collect() + } else { t.clone() } + }).collect(); + let encs = self.tokenizer.encode_batch(prepared, false).expect("Tokenization failed"); + encs.into_iter().map(|enc| { + let mut ids = enc.get_ids().to_vec(); if let Some(max) = max_length { ids.truncate(max); } ids }).collect() } - /// Encode texts into embeddings via mean-pooling and optional L2-normalization - pub fn encode(&self, texts: &[String]) -> Vec> { - texts.iter().map(|text| { - let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed"); - let ids = enc.get_ids(); - let mut sum = vec![0.0f32; self.embeddings.ncols()]; - for &id in ids { - let row = self.embeddings.row(id as usize); - for (i, &v) in row.iter().enumerate() { - sum[i] += v; + /// Encode texts into embeddings. 
+    ///
+    /// # Arguments
+    /// * `texts` - slice of input strings
+    /// * `show_progress` - whether to print batch progress
+    /// * `max_length` - max tokens per text (truncation)
+    /// * `batch_size` - number of texts per batch
+    /// * `use_multiprocessing` - use Rayon parallelism when the input is large enough
+    /// * `multiprocessing_threshold` - minimum number of texts to enable parallelism
+    pub fn encode_with_args(
+        &self,
+        texts: &[String],
+        show_progress: bool,
+        max_length: Option<usize>,
+        batch_size: usize,
+        use_multiprocessing: bool,
+        multiprocessing_threshold: usize,
+    ) -> Vec<Vec<f32>> {
+        let total = texts.len();
+        let num_batches = (total + batch_size - 1) / batch_size;
+        let iter = texts.chunks(batch_size);
+
+        if use_multiprocessing && total > multiprocessing_threshold {
+            // Disable the tokenizer's internal parallelism so it does not
+            // contend with Rayon's thread pool
+            env::set_var("TOKENIZERS_PARALLELISM", "false");
+            iter
+                .enumerate()
+                .flat_map(|(b, chunk)| {
+                    if show_progress { eprintln!("Batch {}/{}", b + 1, num_batches); }
+                    self.tokenize(chunk, max_length)
+                        .into_par_iter()
+                        .map(|ids| self.pool_ids(ids))
+                        .collect::<Vec<_>>()
+                })
+                .collect()
+        } else {
+            let mut out = Vec::with_capacity(total);
+            for (b, chunk) in iter.enumerate() {
+                if show_progress { eprintln!("Batch {}/{}", b + 1, num_batches); }
+                for ids in self.tokenize(chunk, max_length) {
+                    out.push(self.pool_ids(ids));
+                }
+            }
+            out
+        }
+    }
+
+    /// Default encode: no progress, max_length=512, batch_size=1024,
+    /// parallel encoding for inputs larger than 10_000 texts.
+    pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
+        self.encode_with_args(texts, false, Some(512), 1024, true, 10_000)
+    }
+
+    /// Mean-pool one token ID list into a single embedding vector
+    fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
+        let mut sum = vec![0.0; self.embeddings.ncols()];
+        for &id in &ids {
+            let row = self.embeddings.row(id as usize);
+            for (i, &v) in row.iter().enumerate() { sum[i] += v; }
+        }
+        let cnt = ids.len().max(1) as f32;
+        sum.iter_mut().for_each(|v| *v /= cnt);
+        if self.normalize {
+            let norm = sum.iter().map(|&x| x * x).sum::<f32>().sqrt().max(1e-12);
+            sum.iter_mut().for_each(|v| *v /= norm);
+        }
+        sum
+    }
 }
+
diff --git a/tests/common.rs b/tests/common.rs
new file mode 100644
index 0000000..de727f4
--- /dev/null
+++ b/tests/common.rs
@@ -0,0 +1,12 @@
+use model2vec_rs::model::StaticModel;
+
+/// Load the small float32 test model from fixtures
+pub fn load_test_model() -> StaticModel {
+    StaticModel::from_pretrained(
+        "tests/fixtures/test-model-float32",
+        None, // token
+        None, // normalize
+        None, // subfolder
+    )
+    .expect("Failed to load test model")
+}
\ No newline at end of file
diff --git a/tests/test_encode.rs b/tests/test_encode.rs
deleted file mode 100644
index e9d86cf..0000000
--- a/tests/test_encode.rs
+++ /dev/null
@@ -1,29 +0,0 @@
-use approx::assert_relative_eq;
-use std::fs;
-use serde_json::Value;
-use model2vec_rs::model::StaticModel;
-
-#[test]
-fn test_encode_hello_against_fixture() {
-    // Load the embeddings generated by the Python Model2Vec library
-    let fixture = fs::read_to_string("tests/fixtures/embeddings.json")
-        .expect("Fixture not found");
-    let expected: Vec<Vec<f32>> = serde_json::from_str(&fixture)
-        .expect("Failed to parse fixture");
-
-    // Load the model
-    let model = StaticModel::from_pretrained("tests/fixtures/test-model-float32", None)
-        .expect("Failed to load model");
-
-    // Encode the same input used to generate the fixture
model.encode(&["hello world".to_string()]); - - // Verify dimensions - assert_eq!(output.len(), expected.len()); - assert_eq!(output[0].len(), expected[0].len()); - - // Compare element-wise within tolerance - for (o, e) in output[0].iter().zip(expected[0].iter()) { - assert_relative_eq!(o, e, max_relative = 1e-5); - } -} diff --git a/tests/test_load.rs b/tests/test_load.rs index 538a54e..e93d4a6 100644 --- a/tests/test_load.rs +++ b/tests/test_load.rs @@ -1,10 +1,15 @@ use approx::assert_relative_eq; use model2vec_rs::model::StaticModel; -fn encode_hello(path: &str) -> Vec { +fn encode_with_model(path: &str) -> Vec { // Helper function to load the model and encode "hello world" - let model = StaticModel::from_pretrained(path, None) - .expect(&format!("Failed to load model at {}", path)); + let model = StaticModel::from_pretrained( + path, + None, + None, + None, + ).expect(&format!("Failed to load model at {}", path)); + let out = model.encode(&["hello world".to_string()]); assert_eq!(out.len(), 1); out.into_iter().next().unwrap() @@ -14,11 +19,11 @@ fn encode_hello(path: &str) -> Vec { fn quantized_models_match_float32() { // Compare quantized models against the float32 model let base = "tests/fixtures/test-model-float32"; - let ref_emb = encode_hello(base); + let ref_emb = encode_with_model(base); for quant in &["float16", "int8"] { let path = format!("tests/fixtures/test-model-{}", quant); - let emb = encode_hello(&path); + let emb = encode_with_model(&path); assert_eq!(emb.len(), ref_emb.len()); diff --git a/tests/test_model.rs b/tests/test_model.rs new file mode 100644 index 0000000..01353e9 --- /dev/null +++ b/tests/test_model.rs @@ -0,0 +1,74 @@ +mod common; +use common::load_test_model; +use approx::assert_relative_eq; +use std::fs; +use model2vec_rs::model::StaticModel; + +/// Test that encoding "hello world" matches the Python-generated fixture +#[test] +fn test_encode_matches_python_model2vec() { + let fixture = fs::read_to_string("tests/fixtures/embeddings.json") + .expect("Fixture not found"); + let expected: Vec> = serde_json::from_str(&fixture) + .expect("Failed to parse fixture"); + let model = load_test_model(); + let output = model.encode(&["hello world".to_string()]); + assert_eq!(output.len(), expected.len()); + assert_eq!(output[0].len(), expected[0].len()); + for (o, e) in output[0].iter().zip(expected[0].iter()) { + assert_relative_eq!(o, e, max_relative = 1e-5); + } +} + +/// Test that encoding an empty input slice yields an empty Vec +#[test] +fn test_encode_empty_input() { + let model = load_test_model(); + let embs: Vec> = model.encode(&[]); + assert!(embs.is_empty(), "Expected no embeddings for empty input"); +} + +/// Test encoding a single empty sentence produces a zero vector with no NaNs +#[test] +fn test_encode_empty_sentence() { + let model = load_test_model(); + let embs = model.encode(&["".to_string()]); + assert_eq!(embs.len(), 1); + let vec = &embs[0]; + assert!(vec.iter().all(|&x| x == 0.0), "All entries should be zero"); +} + +/// Test parallel vs sequential encoding consistency using encode_with_args +#[test] +fn test_encode_parallel_vs_sequential() { + let model = load_test_model(); + let texts: Vec = (0..1000).map(|_| "hello world".to_string()).collect(); + let seq = model.encode_with_args(&texts, false, Some(512), 100, false, 500); + let par = model.encode_with_args(&texts, false, Some(512), 100, true, 500); + assert_eq!(seq.len(), par.len()); + for (s, p) in seq.iter().zip(par.iter()) { + assert_relative_eq!(s.as_slice(), p.as_slice(), 
+
+/// Test override of the `normalize` flag in from_pretrained
+#[test]
+fn test_normalization_flag_override() {
+    // First load with normalize = true (the default in config)
+    let model_norm = StaticModel::from_pretrained(
+        "tests/fixtures/test-model-float32", None, None, None
+    ).unwrap();
+    let emb_norm = model_norm.encode(&["test sentence".to_string()])[0].clone();
+    let norm_norm = emb_norm.iter().map(|&x| x * x).sum::<f32>().sqrt();
+
+    // Now load with a normalize = false override
+    let model_no_norm = StaticModel::from_pretrained(
+        "tests/fixtures/test-model-float32", None, Some(false), None
+    ).unwrap();
+    let emb_no = model_no_norm.encode(&["test sentence".to_string()])[0].clone();
+    let norm_no = emb_no.iter().map(|&x| x * x).sum::<f32>().sqrt();
+
+    // The normalized vector should have unit length; the override should give a larger norm
+    assert!((norm_norm - 1.0).abs() < 1e-5, "Normalized vector should have unit norm");
+    assert!(norm_no > norm_norm, "Norm should be larger without normalization");
+}
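For reviewers, a minimal end-to-end sketch of the API after this change. This snippet is not part of the diff; the crate path and call signatures are taken from the tests and from src/model.rs above:

use anyhow::Result;
use model2vec_rs::model::StaticModel;

fn main() -> Result<()> {
    // New four-argument loader: (repo_or_path, token, normalize, subfolder)
    let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;

    // Default encode: max_length = 512, batch_size = 1024,
    // Rayon parallelism kicks in above 10_000 texts
    let embeddings = model.encode(&["hello world".to_string()]);

    // Full control over batching and parallelism:
    // (texts, show_progress, max_length, batch_size, use_multiprocessing, multiprocessing_threshold)
    let tuned = model.encode_with_args(
        &["hello world".to_string()],
        true,
        Some(256),
        512,
        true,
        1_000,
    );

    println!("{} and {} embeddings", embeddings.len(), tuned.len());
    Ok(())
}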