Skip to content

Commit f86dee7

Browse files
authored
fix: Fixed local distillation (MinishLab#149)
1 parent f06a775 commit f86dee7

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

model2vec/distill/distillation.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import logging
4+
import os
45
import re
56
from typing import Literal, Union
67

@@ -141,13 +142,19 @@ def distill_from_model(
141142
"hidden_dim": embeddings.shape[1],
142143
"seq_length": 1000000, # Set this to a high value since we don't have a sequence length limit.
143144
}
144-
# Get the language from the model card
145-
try:
146-
info = model_info(model_name)
147-
language = info.cardData.get("language")
148-
except RepositoryNotFoundError:
149-
logger.info("No model info found for model. Setting `language` to None.")
145+
146+
if os.path.exists(model_name):
147+
# Using a local model. Get the model name from the path.
148+
model_name = os.path.basename(model_name)
150149
language = None
150+
else:
151+
# Get the language from the model card.
152+
try:
153+
info = model_info(model_name)
154+
language = info.cardData.get("language", None)
155+
except RepositoryNotFoundError:
156+
logger.info("No model info found for the model. Setting language to None.")
157+
language = None
151158

152159
return StaticModel(
153160
vectors=embeddings, tokenizer=new_tokenizer, config=config, base_model_name=model_name, language=language

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)