Commit 77f16df

fix: typing issues, bug in inference (MinishLab#224)
1 parent 39f02f6 commit 77f16df

2 files changed: 42 additions & 13 deletions


model2vec/inference/model.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -3,7 +3,7 @@
 import re
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import TypeVar
+from typing import Sequence, TypeVar
 
 import huggingface_hub
 import numpy as np
@@ -65,11 +65,12 @@ def save_pretrained(self, path: str) -> None:
         """Save the model to a folder."""
         save_pipeline(self, path)
 
-    def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = False) -> None:
+    def push_to_hub(self, repo_id: str, subfolder: str, token: str | None = None, private: bool = False) -> None:
         """
         Save a model to a folder, and then push that folder to the hf hub.
 
         :param repo_id: The id of the repository to push to.
+        :param subfolder: The subfolder to push to.
         :param token: The token to use to push to the hub.
         :param private: Whether the repository should be private.
         """
@@ -78,11 +79,11 @@ def push_to_hub(self, repo_id: str, token: str | None = None, private: bool = Fa
         with TemporaryDirectory() as temp_dir:
             save_pipeline(self, temp_dir)
             self.model.save_pretrained(temp_dir)
-            push_folder_to_hub(Path(temp_dir), repo_id, private, token)
+            push_folder_to_hub(Path(temp_dir), subfolder, repo_id, private, token)
 
     def _encode_and_coerce_to_2d(
         self,
-        X: list[str] | str,
+        X: Sequence[str],
         show_progress_bar: bool,
         max_length: int | None,
         batch_size: int,
```
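The inference bug named in the commit title is the `push_folder_to_hub` call above, which had fallen out of sync with that helper's signature and omitted the `subfolder` argument; `push_to_hub` now accepts `subfolder` and forwards it. A minimal usage sketch, assuming `StaticModelPipeline` is the pipeline class defined in this file; the repo id, subfolder name, and token are placeholders:

```python
from model2vec.inference import StaticModelPipeline  # assumed import path

pipeline = StaticModelPipeline.from_pretrained("my-org/my-classifier")  # placeholder repo

pipeline.push_to_hub(
    repo_id="my-org/my-classifier",  # placeholder repo id
    subfolder="pipeline",            # hypothetical subfolder within the repo
    token="hf_...",                  # placeholder HF token
    private=False,
)
```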
```diff
@@ -105,7 +106,7 @@ def _encode_and_coerce_to_2d(
 
     def predict(
         self,
-        X: list[str] | str,
+        X: Sequence[str],
         show_progress_bar: bool = False,
         max_length: int | None = 512,
         batch_size: int = 1024,
@@ -145,7 +146,7 @@ def predict(
 
     def predict_proba(
         self,
-        X: list[str] | str,
+        X: Sequence[str],
         show_progress_bar: bool = False,
         max_length: int | None = 512,
         batch_size: int = 1024,
@@ -175,7 +176,7 @@ def predict_proba(
         return self.head.predict_proba(encoded)
 
     def evaluate(
-        self, X: list[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
+        self, X: Sequence[str], y: LabelType, batch_size: int = 1024, threshold: float = 0.5, output_dict: bool = False
     ) -> str | dict[str, dict[str, float]]:
         """
         Evaluate the classifier on a given dataset using scikit-learn's classification report.
```
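The `list[str] | str` to `Sequence[str]` change in `_encode_and_coerce_to_2d`, `predict`, `predict_proba`, and `evaluate` lets type checkers accept any string sequence, not just lists. A short sketch, reusing the hypothetical `pipeline` from the previous example:

```python
texts = ("great product", "terrible service")  # a tuple now satisfies Sequence[str]

labels = pipeline.predict(texts)        # array of predicted labels
probas = pipeline.predict_proba(texts)  # array of class probabilities

# evaluate compares predictions against gold labels; single-label string
# classes are assumed here for illustration.
print(pipeline.evaluate(texts, ["pos", "neg"]))
```

One caveat of the looser type: a bare `str` is itself a `Sequence[str]`, which is the underspecification the docstring NOTEs added in model2vec/model.py below call out.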

model2vec/model.py

Lines changed: 34 additions & 6 deletions

```diff
@@ -5,7 +5,7 @@
 from logging import getLogger
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Any, Iterator, Union
+from typing import Any, Iterator, Sequence, Union, overload
 
 import numpy as np
 from joblib import delayed
@@ -117,7 +117,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
             subfolder=subfolder,
         )
 
-    def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[list[int]]:
+    def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> list[list[int]]:
         """
         Tokenize a list of sentences.
 
@@ -245,9 +245,31 @@ def from_sentence_transformers(
             language=metadata.get("language"),
         )
 
+    @overload
     def encode_as_sequence(
         self,
-        sentences: list[str] | str,
+        sentences: str,
+        max_length: int | None = None,
+        batch_size: int = 1024,
+        show_progress_bar: bool = False,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
+    ) -> np.ndarray: ...
+
+    @overload
+    def encode_as_sequence(
+        self,
+        sentences: list[str],
+        max_length: int | None = None,
+        batch_size: int = 1024,
+        show_progress_bar: bool = False,
+        use_multiprocessing: bool = True,
+        multiprocessing_threshold: int = 10_000,
+    ) -> list[np.ndarray]: ...
+
+    def encode_as_sequence(
+        self,
+        sentences: str | list[str],
         max_length: int | None = None,
         batch_size: int = 1024,
         show_progress_bar: bool = False,
@@ -263,6 +285,9 @@ def encode_as_sequence(
         This is about twice as slow.
         Sentences that do not contain any tokens will be turned into an empty array.
 
+        NOTE: the input type is currently underspecified. The actual input type is `Sequence[str] | str`, but this
+        is not possible to implement in python typing currently.
+
         :param sentences: The list of sentences to encode.
         :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
             If this is None, no truncation is done.
```
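The `@overload` pair above gives `encode_as_sequence` a precise per-input return type even though a single runtime implementation handles both cases. A minimal sketch of how the overloads resolve, assuming `StaticModel` is the class being modified here; the model id is a placeholder:

```python
import numpy as np

from model2vec import StaticModel  # assumed import path for the class in this file

model = StaticModel.from_pretrained("minishlab/potion-base-8M")  # placeholder model id

# First overload: a single str resolves to one np.ndarray of token embeddings.
single: np.ndarray = model.encode_as_sequence("hello world")

# Second overload: list[str] resolves to a list with one array per sentence.
many: list[np.ndarray] = model.encode_as_sequence(["hello there", "general kenobi"])
```

Overloading on `Sequence[str]` instead of `list[str]` would be ambiguous, since `str` is itself a `Sequence[str]` and would match both branches; that is the limitation the NOTE in the docstring refers to.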
```diff
@@ -320,7 +345,7 @@ def _encode_batch_as_sequence(self, sentences: list[str], max_length: int | None
 
     def encode(
         self,
-        sentences: list[str] | str,
+        sentences: Sequence[str],
         show_progress_bar: bool = False,
         max_length: int | None = 512,
         batch_size: int = 1024,
@@ -334,6 +359,9 @@ def encode(
         This function encodes a list of sentences by averaging the word embeddings of the tokens in the sentence.
         For ease of use, we don't batch sentences together.
 
+        NOTE: the return type is currently underspecified. In the case of a single string, this returns a 1D array,
+        but in the case of a list of strings, this returns a 2D array. Not possible to implement in numpy currently.
+
         :param sentences: The list of sentences to encode. You can also pass a single sentence.
         :param show_progress_bar: Whether to show the progress bar.
         :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
@@ -378,7 +406,7 @@ def encode(
             return out_array[0]
         return out_array
 
-    def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
+    def _encode_batch(self, sentences: Sequence[str], max_length: int | None) -> np.ndarray:
         """Encode a batch of sentences."""
         ids = self.tokenize(sentences=sentences, max_length=max_length)
         out: list[np.ndarray] = []
@@ -396,7 +424,7 @@ def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndar
         return out_array
 
     @staticmethod
-    def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:
+    def _batch(sentences: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
         """Batch the sentences into equal-sized."""
         return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))
 
```
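`_batch` is unchanged apart from its annotations; it lazily yields consecutive `batch_size`-sized slices of the input sequence. A standalone sketch of the same contract (a module-level `batch` is used here so it runs outside the class):

```python
from typing import Iterator, Sequence


def batch(sentences: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most batch_size sentences each."""
    return (sentences[i : i + batch_size] for i in range(0, len(sentences), batch_size))


print([list(b) for b in batch(("a", "b", "c", "d", "e"), 2)])
# -> [['a', 'b'], ['c', 'd'], ['e']]
```

Slicing preserves the input's concrete type, so tuples come back as tuples. A bare `str` would also be accepted and sliced into substrings, the same `str`-is-a-`Sequence[str]` caveat the docstring NOTEs describe.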
402430

0 commit comments