Skip to content

Commit 8074182

Browse files
authored
feat: Added multiprocessing threshold parameter (MinishLab#142)
1 parent ecf022f commit 8074182

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

model2vec/model.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
logger = getLogger(__name__)
2020

21-
_MULTIPROCESSING_THRESHOLD = 10_000 # Minimum number of sentences to use multiprocessing
22-
2321

2422
class StaticModel:
2523
def __init__(
@@ -175,6 +173,7 @@ def encode_as_sequence(
175173
batch_size: int = 1024,
176174
show_progress_bar: bool = False,
177175
use_multiprocessing: bool = True,
176+
multiprocessing_threshold: int = 10_000,
178177
) -> list[np.ndarray] | np.ndarray:
179178
"""
180179
Encode a list of sentences as a list of numpy arrays of tokens.
@@ -191,7 +190,8 @@ def encode_as_sequence(
191190
:param batch_size: The batch size to use.
192191
:param show_progress_bar: Whether to show the progress bar.
193192
:param use_multiprocessing: Whether to use multiprocessing.
194-
By default, this is enabled for inputs > 10k sentences and disabled otherwise.
193+
By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
194+
:param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
195195
:return: The encoded sentences with an embedding per token.
196196
"""
197197
was_single = False
@@ -204,7 +204,7 @@ def encode_as_sequence(
204204
total_batches = math.ceil(len(sentences) / batch_size)
205205

206206
# Use joblib for multiprocessing if requested, and if we have enough sentences
207-
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
207+
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
208208
# Disable parallelism for tokenizers
209209
os.environ["TOKENIZERS_PARALLELISM"] = "false"
210210

@@ -246,6 +246,7 @@ def encode(
246246
max_length: int | None = 512,
247247
batch_size: int = 1024,
248248
use_multiprocessing: bool = True,
249+
multiprocessing_threshold: int = 10_000,
249250
**kwargs: Any,
250251
) -> np.ndarray:
251252
"""
@@ -260,7 +261,8 @@ def encode(
260261
If this is None, no truncation is done.
261262
:param batch_size: The batch size to use.
262263
:param use_multiprocessing: Whether to use multiprocessing.
263-
By default, this is enabled for inputs > 10k sentences and disabled otherwise.
264+
By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
265+
:param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
264266
:param **kwargs: Any additional arguments. These are ignored.
265267
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
266268
"""
@@ -276,7 +278,7 @@ def encode(
276278
ids = self.tokenize(sentences=sentences, max_length=max_length)
277279

278280
# Use joblib for multiprocessing if requested, and if we have enough sentences
279-
if use_multiprocessing and len(sentences) > _MULTIPROCESSING_THRESHOLD:
281+
if use_multiprocessing and len(sentences) > multiprocessing_threshold:
280282
# Disable parallelism for tokenizers
281283
os.environ["TOKENIZERS_PARALLELISM"] = "false"
282284

0 commit comments

Comments (0)