18
18
19
19
logger = getLogger (__name__ )
20
20
21
- _MULTIPROCESSING_THRESHOLD = 10_000 # Minimum number of sentences to use multiprocessing
22
-
23
21
24
22
class StaticModel :
25
23
def __init__ (
@@ -175,6 +173,7 @@ def encode_as_sequence(
175
173
batch_size : int = 1024 ,
176
174
show_progress_bar : bool = False ,
177
175
use_multiprocessing : bool = True ,
176
+ multiprocessing_threshold : int = 10_000 ,
178
177
) -> list [np .ndarray ] | np .ndarray :
179
178
"""
180
179
Encode a list of sentences as a list of numpy arrays of tokens.
@@ -191,7 +190,8 @@ def encode_as_sequence(
191
190
:param batch_size: The batch size to use.
192
191
:param show_progress_bar: Whether to show the progress bar.
193
192
:param use_multiprocessing: Whether to use multiprocessing.
194
- By default, this is enabled for inputs > 10k sentences and disabled otherwise.
193
+ By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
194
+ :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
195
195
:return: The encoded sentences with an embedding per token.
196
196
"""
197
197
was_single = False
@@ -204,7 +204,7 @@ def encode_as_sequence(
204
204
total_batches = math .ceil (len (sentences ) / batch_size )
205
205
206
206
# Use joblib for multiprocessing if requested, and if we have enough sentences
207
- if use_multiprocessing and len (sentences ) > _MULTIPROCESSING_THRESHOLD :
207
+ if use_multiprocessing and len (sentences ) > multiprocessing_threshold :
208
208
# Disable parallelism for tokenizers
209
209
os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
210
210
@@ -246,6 +246,7 @@ def encode(
246
246
max_length : int | None = 512 ,
247
247
batch_size : int = 1024 ,
248
248
use_multiprocessing : bool = True ,
249
+ multiprocessing_threshold : int = 10_000 ,
249
250
** kwargs : Any ,
250
251
) -> np .ndarray :
251
252
"""
@@ -260,7 +261,8 @@ def encode(
260
261
If this is None, no truncation is done.
261
262
:param batch_size: The batch size to use.
262
263
:param use_multiprocessing: Whether to use multiprocessing.
263
- By default, this is enabled for inputs > 10k sentences and disabled otherwise.
264
+ By default, this is enabled for inputs > multiprocessing_threshold sentences and disabled otherwise.
265
+ :param multiprocessing_threshold: The threshold in number of sentences for using multiprocessing.
264
266
:param **kwargs: Any additional arguments. These are ignored.
265
267
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
266
268
"""
@@ -276,7 +278,7 @@ def encode(
276
278
ids = self .tokenize (sentences = sentences , max_length = max_length )
277
279
278
280
# Use joblib for multiprocessing if requested, and if we have enough sentences
279
- if use_multiprocessing and len (sentences ) > _MULTIPROCESSING_THRESHOLD :
281
+ if use_multiprocessing and len (sentences ) > multiprocessing_threshold :
280
282
# Disable parallelism for tokenizers
281
283
os .environ ["TOKENIZERS_PARALLELISM" ] = "false"
282
284
0 commit comments