@@ -5,7 +5,7 @@
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, Sequence, Union, overload

import numpy as np
from joblib import delayed
@@ -117,7 +117,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
            subfolder=subfolder,
        )

-    def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[list[int]]:
+    def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
        """
        Tokenize a list of sentences.

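Widening the public signatures from list[str] to Sequence[str] lets callers pass any read-only sequence of strings, not just lists. A minimal, hedged illustration of what this buys at the call site (the `model` instance is a placeholder for a loaded model):

    sentences = ("first sentence", "second sentence")  # a tuple, not a list
    ids = model.tokenize(sentences)  # accepted by type checkers after this change; returns list[list[int]]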
@@ -245,9 +245,31 @@ def from_sentence_transformers(
            language=metadata.get("language"),
        )

+    @overload
    def encode_as_sequence(
        self,
-        sentences: list[str] | str,
+        sentences: str,
+        max_length: int | None = None,
+        batch_size: int = 1024,
+        show_progress_bar: bool = False,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
+    ) -> np.ndarray: ...
+
+    @overload
+    def encode_as_sequence(
+        self,
+        sentences: list[str],
+        max_length: int | None = None,
+        batch_size: int = 1024,
+        show_progress_bar: bool = False,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
+    ) -> list[np.ndarray]: ...
+
+    def encode_as_sequence(
+        self,
+        sentences: str | list[str],
        max_length: int | None = None,
        batch_size: int = 1024,
        show_progress_bar: bool = False,
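The `@overload` stubs added above exist only for static type checkers: a plain `str` argument resolves to `np.ndarray`, a `list[str]` to `list[np.ndarray]`, while the single non-overloaded definition is what actually runs. A self-contained sketch of the same pattern, with a hypothetical `_encode_one` helper standing in for the real per-sentence work:

    from typing import overload

    import numpy as np

    class Sketch:
        @overload
        def encode_as_sequence(self, sentences: str) -> np.ndarray: ...

        @overload
        def encode_as_sequence(self, sentences: list[str]) -> list[np.ndarray]: ...

        def encode_as_sequence(self, sentences: str | list[str]) -> np.ndarray | list[np.ndarray]:
            # Only this implementation exists at runtime; the stubs above are erased.
            if isinstance(sentences, str):
                return self._encode_one(sentences)
            return [self._encode_one(s) for s in sentences]

        def _encode_one(self, sentence: str) -> np.ndarray:
            # Hypothetical stand-in: one vector per token, here just zeros.
            return np.zeros((len(sentence.split()), 8))

A checker such as mypy then infers np.ndarray for a str argument and list[np.ndarray] for a list argument, matching the overloads added in this diff.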
@@ -263,6 +285,9 @@ def encode_as_sequence(
        This is about twice as slow.
        Sentences that do not contain any tokens will be turned into an empty array.

+        NOTE: the input type is currently underspecified. The actual input type is `Sequence[str] | str`, but this
+        is not possible to express in Python typing currently.
+
        :param sentences: The list of sentences to encode.
        :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
            If this is None, no truncation is done.
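Why `Sequence[str] | str` cannot be expressed cleanly: a `str` is itself a `Sequence[str]`, so a type checker cannot tell "one sentence" apart from "a sequence of sentences" through that union. A small standalone illustration of the ambiguity (standard library only):

    from typing import Sequence

    def first(sentences: Sequence[str]) -> str:
        # A bare string also satisfies Sequence[str]: indexing it yields
        # single characters rather than whole sentences.
        return sentences[0]

    print(first(["hello world", "second sentence"]))  # 'hello world'
    print(first("hello world"))                       # 'h' -- type-checks, but is per-character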
@@ -320,7 +345,7 @@ def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None

    def encode(
        self,
-        sentences: list[str] | str,
+        sentences: Sequence[str],
        show_progress_bar: bool = False,
        max_length: int | None = 512,
        batch_size: int = 1024,
@@ -334,6 +359,9 @@ def encode(
        This function encodes a list of sentences by averaging the word embeddings of the tokens in the sentence.
        For ease of use, we don't batch sentences together.

+        NOTE: the return type is currently underspecified. In the case of a single string, this returns a 1D array,
+        but in the case of a list of strings, this returns a 2D array. This is not possible to express in numpy typing currently.
+
        :param sentences: The list of sentences to encode. You can also pass a single sentence.
        :param show_progress_bar: Whether to show the progress bar.
        :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
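What that return-shape caveat means at the call site: a single string yields a 1D vector, while a list yields a 2D matrix with one row per sentence (this is what the `return out_array[0]` branch further down handles). A hedged usage sketch, assuming a loaded model instance named `model`; the embedding dimension shown is only an example:

    single = model.encode("a single sentence")
    batch = model.encode(["first sentence", "second sentence"])

    print(single.shape)  # e.g. (256,)    -- 1D: one vector
    print(batch.shape)   # e.g. (2, 256)  -- 2D: one row per sentence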
@@ -378,7 +406,7 @@ def encode(
            return out_array[0]
        return out_array

-    def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
+    def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
        """Encode a batch of sentences."""
        ids = self.tokenize(sentences=sentences, max_length=max_length)
        out: list[np.ndarray] = []
@@ -396,7 +424,7 @@ def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndar
        return out_array

    @staticmethod
-    def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:
+    def _batch(sentences: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
        """Batch the sentences into equal-sized batches."""
        return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))

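`_batch` only needs `len()` and slicing, which is exactly what the widened `Sequence[str]` annotation promises, so tuples and other sequences batch just as well as lists. A standalone sketch of the same slicing pattern outside the class:

    from typing import Iterator, Sequence

    def batch(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
        # Lazily yield consecutive slices; the final slice may be shorter,
        # since slicing past the end of a sequence is safe.
        return (items[i : i + batch_size] for i in range(0, len(items), batch_size))

    print(list(batch(("a", "b", "c", "d", "e"), 2)))  # [('a', 'b'), ('c', 'd'), ('e',)]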