Skip to content

Commit f37b62d

Browse files
committed
Added support for models with external data - #10
1 parent 8319075 commit f37b62d

File tree

4 files changed

+26
-6
lines changed

4 files changed

+26
-6
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 1.2.0 (unreleased)
2+
3+
- Added support for models with external data
4+
15
## 1.1.1 (2024-10-14)
26

37
- Added `audio-classification` pipeline

lib/informers/models.rb

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,15 @@ def self.construct_session(pretrained_model_name_or_path, file_name, **options)
178178
model_file_name = "#{prefix}#{file_name}#{options[:quantized] ? "_quantized" : ""}.onnx"
179179
path = Utils::Hub.get_model_file(pretrained_model_name_or_path, model_file_name, true, **options)
180180

181-
OnnxRuntime::InferenceSession.new(path)
181+
begin
182+
OnnxRuntime::InferenceSession.new(path)
183+
rescue OnnxRuntime::Error => e
184+
raise e unless e.message.include?(".onnx_data")
185+
186+
Utils::Hub.get_model_file(pretrained_model_name_or_path, "#{model_file_name}_data", true, **options)
187+
188+
OnnxRuntime::InferenceSession.new(path)
189+
end
182190
end
183191

184192
def call(model_inputs, **kwargs)

lib/informers/utils/hub.rb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,14 @@ def match(request)
8181
file if file.exists
8282
end
8383

84-
def put(request, buffer)
84+
def put(request, response)
8585
output_path = resolve_path(request)
8686

8787
begin
8888
FileUtils.mkdir_p(File.dirname(output_path))
89-
File.binwrite(output_path, buffer)
89+
File.open(output_path, "wb") do |f|
90+
f.write(response.read(1024 * 1024)) until response.eof?
91+
end
9092
rescue => e
9193
warn "An error occurred while writing the file to cache: #{e}"
9294
end
@@ -189,10 +191,8 @@ def self.get_model_file(path_or_repo_id, filename, fatal = true, **options)
189191
to_cache_response = cache && !response.is_a?(FileResponse) && response.status[0] == "200"
190192
end
191193

192-
buffer = response.read
193-
194194
if to_cache_response && cache_key && cache.match(cache_key).nil?
195-
cache.put(cache_key, buffer)
195+
cache.put(cache_key, response)
196196
end
197197

198198
Utils.dispatch_callback(options[:progress_callback], {

test/model_test.rb

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,14 @@ def test_all_mpnet
174174
assert_elements_in_delta [0.04170236, 0.00109747, -0.01553415], embeddings[1][..2]
175175
end
176176

177+
# https://huggingface.co/BAAI/bge-m3
178+
def test_bge_m3
179+
sentences = ["This is an example sentence", "Each sentence is converted"]
180+
181+
model = Informers.pipeline("embedding", "BAAI/bge-m3")
182+
model.(sentences, model_output: "token_embeddings")
183+
end
184+
177185
# https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1
178186
def test_mxbai_rerank
179187
query = "How many people live in London?"

0 commit comments

Comments
 (0)