Skip to content

feat: Added optional arguments for encode and from_pretrained #3

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ version = "0.1.0"
edition = "2021"
description = "Fast State-of-the-Art Static Embeddings in Rust"
readme = "README.md"
license = "MIT"
license-file = "LICENSE"
authors = ["Thomas van Dongen <[email protected]>", "Stéphan Tulkens <[email protected]>"]
homepage = "https://github.com/MinishLab/model2vec-rs"
Expand All @@ -21,6 +20,7 @@ clap = { version = "4.0", features = ["derive"] }
anyhow = "1.0"
serde_json = "1.0"
half = "2.0"
rayon = "1.7"

[dev-dependencies]
approx = "0.5"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use model2vec_rust::inference::StaticModel;

fn main() -> Result<()> {
// Load a model from the Hugging Face Hub or a local path
let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None)?;
let model = StaticModel::from_pretrained("minishlab/potion-base-8M", None, None, None)?;

// Prepare a list of sentences
let texts = vec![
Expand Down
7 changes: 4 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn main() -> Result<()> {
vec![input]
};

let m = StaticModel::from_pretrained(&model, None)?;
let m = StaticModel::from_pretrained(&model, None, None, None)?;
let embs = m.encode(&texts);

if let Some(path) = output {
Expand All @@ -67,8 +67,9 @@ fn main() -> Result<()> {
vec![input]
};

let m = StaticModel::from_pretrained(&model, None)?;
let ids = m.tokenize(&texts);
let m = StaticModel::from_pretrained(&model, None, None, None)?;
// Provide default None for max_length to include all tokens
let ids = m.tokenize(&texts, None);
println!("Token ID sequences: {:#?}", ids);
}
}
Expand Down
184 changes: 124 additions & 60 deletions src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ use tokenizers::Tokenizer;
use safetensors::{SafeTensors, tensor::Dtype};
use half::f16;
use ndarray::Array2;
use std::fs::read;
use std::path::Path;
use rayon::prelude::*;
use std::{fs::read, path::Path, env};
use anyhow::{Result, Context, anyhow};
use serde_json::Value;

Expand All @@ -13,104 +13,168 @@ pub struct StaticModel {
tokenizer: Tokenizer,
embeddings: Array2<f32>,
normalize: bool,
median_token_length: usize,
}

impl StaticModel {
/// Load a Model2Vec model from a local folder or the HF Hub.
///
/// # Arguments
/// * `repo_or_path` - HF repo ID or local filesystem path
/// * `subfolder` - optional subdirectory inside the repo or folder
pub fn from_pretrained(repo_or_path: &str, subfolder: Option<&str>) -> Result<Self> {
pub fn from_pretrained(
repo_or_path: &str,
token: Option<&str>,
normalize: Option<bool>,
subfolder: Option<&str>,
) -> Result<Self> {
// If provided, set HF token for authenticated downloads
if let Some(tok) = token {
env::set_var("HF_HUB_TOKEN", tok);
}

// Determine file paths
let (tok_path, mdl_path, cfg_path) = {
let base = Path::new(repo_or_path);
if base.exists() {
// Local path
let folder = subfolder.map(|s| base.join(s)).unwrap_or_else(|| base.to_path_buf());
let t = folder.join("tokenizer.json");
let m = folder.join("model.safetensors");
let c = folder.join("config.json");
if !t.exists() || !m.exists() || !c.exists() {
return Err(anyhow!("Local path {:?} missing tokenizer/model/config files", folder));
return Err(anyhow!("Local path {:?} missing files", folder));
}
(t, m, c)
} else {
// HF Hub path
let api = Api::new().context("Failed to initialize HF Hub API")?;
let api = Api::new().context("HF Hub API init failed")?;
let repo = api.model(repo_or_path.to_string());
// note: token not used with sync Api
let prefix = subfolder.map(|s| format!("{}/", s)).unwrap_or_default();
let t = repo.get(&format!("{}tokenizer.json", prefix)).context("Failed to download tokenizer.json")?;
let m = repo.get(&format!("{}model.safetensors", prefix)).context("Failed to download model.safetensors")?;
let c = repo.get(&format!("{}config.json", prefix)).context("Failed to download config.json")?;
let t = repo.get(&format!("{}tokenizer.json", prefix))
.context("Download tokenizer.json failed")?;
let m = repo.get(&format!("{}model.safetensors", prefix))
.context("Download model.safetensors failed")?;
let c = repo.get(&format!("{}config.json", prefix))
.context("Download config.json failed")?;
(t.into(), m.into(), c.into())
}
};

// Load tokenizer
let tokenizer = Tokenizer::from_file(&tok_path)
.map_err(|e| anyhow!("Failed to load tokenizer: {}", e))?;
.map_err(|e| anyhow!("Tokenizer load error: {}", e))?;

// Median token length for char-level truncation
let mut lengths: Vec<usize> = tokenizer.get_vocab(false)
.keys().map(|tk| tk.len()).collect();
lengths.sort_unstable();
let median_token_length = *lengths.get(lengths.len() / 2).unwrap_or(&1);

// Read config.json for default normalize
let cfg: Value = serde_json::from_slice(&read(&cfg_path)?)
.context("Parse config.json failed")?;
let config_norm = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);
let normalize = normalize.unwrap_or(config_norm);

// Read safetensors file
let bytes = read(&mdl_path).context("Failed to read model.safetensors")?;
let safet = SafeTensors::deserialize(&bytes).context("Failed to parse safetensors")?;
let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0")).context("Embedding tensor not found")?;
// Read safetensors
let bytes = read(&mdl_path).context("Read safetensors failed")?;
let safet = SafeTensors::deserialize(&bytes).context("Parse safetensors failed")?;
let tensor = safet.tensor("embeddings").or_else(|_| safet.tensor("0"))
.context("No 'embeddings' tensor")?;
let shape = (tensor.shape()[0] as usize, tensor.shape()[1] as usize);
let raw = tensor.data();
let dtype = tensor.dtype();

// Read config.json for normalization flag
let cfg_bytes = read(&cfg_path).context("Failed to read config.json")?;
let cfg: Value = serde_json::from_slice(&cfg_bytes).context("Failed to parse config.json")?;
let normalize = cfg.get("normalize").and_then(Value::as_bool).unwrap_or(true);

// Decode raw bytes into Vec<f32> based on dtype
// Decode raw data to f32
let floats: Vec<f32> = match dtype {
Dtype::F32 => raw.chunks_exact(4)
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
.collect(),
.map(|b| f32::from_le_bytes([b[0],b[1],b[2],b[3]])).collect(),
Dtype::F16 => raw.chunks_exact(2)
.map(|b| f16::from_le_bytes([b[0], b[1]]).to_f32())
.collect(),
Dtype::I8 => raw.iter()
.map(|&b| (b as i8) as f32)
.collect(),
other => return Err(anyhow!("Unsupported tensor dtype: {:?}", other)),
.map(|b| f16::from_le_bytes([b[0],b[1]]).to_f32()).collect(),
Dtype::I8 => raw.iter().map(|&b| b as i8 as f32).collect(),
other => return Err(anyhow!("Unsupported dtype: {:?}", other)),
};
let embeddings = Array2::from_shape_vec(shape, floats)
.context("Array shape error")?;

// Construct ndarray
let embeddings = Array2::from_shape_vec(shape, floats).context("Failed to create embeddings array")?;

Ok(Self { tokenizer, embeddings, normalize })
Ok(Self { tokenizer, embeddings, normalize, median_token_length })
}

/// Tokenize input texts into token ID sequences
pub fn tokenize(&self, texts: &[String]) -> Vec<Vec<u32>> {
texts.iter().map(|text| {
let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed");
enc.get_ids().to_vec()
/// Tokenize input texts into token ID sequences with optional truncation.
pub fn tokenize(&self, texts: &[String], max_length: Option<usize>) -> Vec<Vec<u32>> {
    // Pre-truncate at the character level so the tokenizer never processes
    // pathologically long inputs; the budget is max tokens * median token length.
    let mut prepared: Vec<String> = Vec::with_capacity(texts.len());
    for text in texts {
        match max_length {
            Some(max) => {
                let char_budget = max.saturating_mul(self.median_token_length);
                prepared.push(text.chars().take(char_budget).collect());
            }
            None => prepared.push(text.clone()),
        }
    }
    let encodings = self.tokenizer.encode_batch(prepared, false).expect("Tokenization failed");
    encodings
        .into_iter()
        .map(|encoding| {
            let mut ids = encoding.get_ids().to_vec();
            // Char-level truncation is only approximate, so cap the token count too.
            if let Some(max) = max_length {
                ids.truncate(max);
            }
            ids
        })
        .collect()
}

/// Encode texts into embeddings via mean-pooling and optional L2-normalization
pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
texts.iter().map(|text| {
let enc = self.tokenizer.encode(text.as_str(), false).expect("Tokenization failed");
let ids = enc.get_ids();
let mut sum = vec![0.0f32; self.embeddings.ncols()];
for &id in ids {
let row = self.embeddings.row(id as usize);
for (i, &v) in row.iter().enumerate() {
sum[i] += v;
/// Encode texts into embeddings.
///
/// # Arguments
/// * `texts` - slice of input strings
/// * `show_progress` - whether to print batch progress
/// * `max_length` - max tokens per text (truncation)
/// * `batch_size` - number of texts per batch
/// * `use_multiprocessing` - use Rayon parallelism
/// * `multiprocessing_threshold` - minimum number of texts to enable parallelism
pub fn encode_with_args(
&self,
texts: &[String],
show_progress: bool,
max_length: Option<usize>,
batch_size: usize,
use_multiprocessing: bool,
multiprocessing_threshold: usize,
) -> Vec<Vec<f32>> {
let total = texts.len();
let num_batches = (total + batch_size - 1) / batch_size;
let iter = texts.chunks(batch_size);

if use_multiprocessing && total > multiprocessing_threshold {
// disable tokenizer internal parallel
env::set_var("TOKENIZERS_PARALLELISM", "false");
iter
.enumerate()
.flat_map(|(b, chunk)| {
if show_progress { eprintln!("Batch {}/{}", b+1, num_batches); }
self.tokenize(chunk, max_length)
.into_par_iter()
.map(|ids| self.pool_ids(ids))
.collect::<Vec<_>>()
})
.collect()
} else {
let mut out = Vec::with_capacity(total);
for (b, chunk) in iter.enumerate() {
if show_progress { eprintln!("Batch {}/{}", b+1, num_batches); }
for ids in self.tokenize(chunk, max_length) {
out.push(self.pool_ids(ids));
}
}
let count = ids.len().max(1) as f32;
sum.iter_mut().for_each(|v| *v /= count);
if self.normalize {
let norm = sum.iter().map(|&x| x * x).sum::<f32>().sqrt().max(1e-12);
sum.iter_mut().for_each(|v| *v /= norm);
}
sum
}).collect()
out
}
}

/// Default encode: no progress output, max_length=512, batch_size=1024,
/// with Rayon parallelism enabled once the input exceeds 10_000 texts.
pub fn encode(&self, texts: &[String]) -> Vec<Vec<f32>> {
    self.encode_with_args(texts, false, Some(512), 1024, true, 10_000)
}

/// Mean-pool one ID list to embedding
fn pool_ids(&self, ids: Vec<u32>) -> Vec<f32> {
    let dim = self.embeddings.ncols();
    let mut pooled = vec![0.0f32; dim];
    // Sum the embedding row of every token id.
    for &token_id in &ids {
        let row = self.embeddings.row(token_id as usize);
        for (acc, &value) in pooled.iter_mut().zip(row.iter()) {
            *acc += value;
        }
    }
    // Mean over tokens; max(1) guards against division by zero on empty input.
    let denom = ids.len().max(1) as f32;
    for value in pooled.iter_mut() {
        *value /= denom;
    }
    if self.normalize {
        // L2-normalize, clamping the norm away from zero for stability.
        let norm = pooled.iter().map(|&x| x * x).sum::<f32>().sqrt().max(1e-12);
        for value in pooled.iter_mut() {
            *value /= norm;
        }
    }
    pooled
}
}

12 changes: 12 additions & 0 deletions tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
use model2vec_rs::model::StaticModel;

/// Load the small float32 test model from fixtures
pub fn load_test_model() -> StaticModel {
    let fixture_path = "tests/fixtures/test-model-float32";
    // token / normalize / subfolder are all left at their defaults.
    StaticModel::from_pretrained(fixture_path, None, None, None)
        .expect("Failed to load test model")
}
29 changes: 0 additions & 29 deletions tests/test_encode.rs

This file was deleted.

15 changes: 10 additions & 5 deletions tests/test_load.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
use approx::assert_relative_eq;
use model2vec_rs::model::StaticModel;

fn encode_hello(path: &str) -> Vec<f32> {
fn encode_with_model(path: &str) -> Vec<f32> {
// Helper function to load the model and encode "hello world"
let model = StaticModel::from_pretrained(path, None)
.expect(&format!("Failed to load model at {}", path));
let model = StaticModel::from_pretrained(
path,
None,
None,
None,
).expect(&format!("Failed to load model at {}", path));

let out = model.encode(&["hello world".to_string()]);
assert_eq!(out.len(), 1);
out.into_iter().next().unwrap()
Expand All @@ -14,11 +19,11 @@ fn encode_hello(path: &str) -> Vec<f32> {
fn quantized_models_match_float32() {
// Compare quantized models against the float32 model
let base = "tests/fixtures/test-model-float32";
let ref_emb = encode_hello(base);
let ref_emb = encode_with_model(base);

for quant in &["float16", "int8"] {
let path = format!("tests/fixtures/test-model-{}", quant);
let emb = encode_hello(&path);
let emb = encode_with_model(&path);

assert_eq!(emb.len(), ref_emb.len());

Expand Down
Loading