Skip to content

Commit af6f67c

Browse files
JarbasAlstephantul
andauthored
feat: allow passing validation set explicitly (MinishLab#245)
* feat: allow passing validation set explicitly * fix: typing * fix: input validation * add tests, extra check, pre-commit --------- Co-authored-by: stephantul <[email protected]>
1 parent eb3ac19 commit af6f67c

File tree

2 files changed

+71
-5
lines changed

2 files changed

+71
-5
lines changed

model2vec/train/classifier.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@ def fit(
135135
early_stopping_patience: int | None = 5,
136136
test_size: float = 0.1,
137137
device: str = "auto",
138+
X_val: list[str] | None = None,
139+
y_val: LabelType | None = None,
138140
) -> StaticModelForClassification:
139141
"""
140142
Fit a model.
@@ -146,6 +148,9 @@ def fit(
146148
This function seeds everything with a seed of 42, so the results are reproducible.
147149
It also splits the data into a train and validation set, again with a random seed.
148150
151+
If `X_val` and `y_val` are not provided, the function will automatically
152+
split the training data into a train and validation set using `test_size`.
153+
149154
:param X: The texts to train on.
150155
:param y: The labels to train on. If the first element is a list, multi-label classification is assumed.
151156
:param learning_rate: The learning rate.
@@ -157,7 +162,10 @@ def fit(
157162
If this is None, early stopping is disabled.
158163
:param test_size: The test size for the train-test split.
159164
:param device: The device to train on. If this is "auto", the device is chosen automatically.
165+
:param X_val: The texts to be used for validation.
166+
:param y_val: The labels to be used for validation.
160167
:return: The fitted model.
168+
:raises ValueError: If either X_val or y_val are provided, but not both.
161169
"""
162170
pl.seed_everything(_RANDOM_SEED)
163171
logger.info("Re-initializing model.")
@@ -166,11 +174,24 @@ def fit(
166174

167175
self._initialize(y)
168176

169-
train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
170-
X,
171-
y,
172-
test_size=test_size,
173-
)
177+
if (X_val is not None) != (y_val is not None):
178+
raise ValueError("Both X_val and y_val must be provided together, or neither.")
179+
180+
if X_val is not None and y_val is not None:
181+
# Additional check to ensure y_val is of the same type as y
182+
if type(y_val[0]) != type(y[0]):
183+
raise ValueError("X_val and y_val must be of the same type as X and y.")
184+
185+
train_texts = X
186+
train_labels = y
187+
validation_texts = X_val
188+
validation_labels = y_val
189+
else:
190+
train_texts, validation_texts, train_labels, validation_labels = self._train_test_split(
191+
X,
192+
y,
193+
test_size=test_size,
194+
)
174195

175196
if batch_size is None:
176197
# Set to a multiple of 32

tests/test_trainable.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
import torch
66
from tokenizers import Tokenizer
7+
from transformers import AutoTokenizer
78

89
from model2vec.model import StaticModel
910
from model2vec.train import StaticModelForClassification
@@ -154,6 +155,50 @@ def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -
154155
assert len(d) == len(b)
155156

156157

158+
def test_y_val_none() -> None:
159+
"""Test the y_val function."""
160+
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
161+
torch.random.manual_seed(42)
162+
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
163+
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
164+
165+
X = ["dog", "cat"]
166+
y = ["0", "1"]
167+
168+
X_val = ["dog", "cat"]
169+
y_val = ["0", "1"]
170+
171+
with pytest.raises(ValueError):
172+
model.fit(X, y, X_val=X_val, y_val=None)
173+
with pytest.raises(ValueError):
174+
model.fit(X, y, X_val=None, y_val=y_val)
175+
model.fit(X, y, X_val=None, y_val=None)
176+
177+
178+
@pytest.mark.parametrize(
179+
"y_multi,y_val_multi,should_crash",
180+
[[True, True, False], [False, False, False], [True, False, True], [False, True, True]],
181+
)
182+
def test_y_val(y_multi: bool, y_val_multi: bool, should_crash: bool) -> None:
183+
"""Test the y_val function."""
184+
tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer
185+
torch.random.manual_seed(42)
186+
vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12)
187+
model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu")
188+
189+
X = ["dog", "cat"]
190+
y = [["0", "1"], ["0"]] if y_multi else ["0", "1"] # type: ignore
191+
192+
X_val = ["dog", "cat"]
193+
y_val = [["0", "1"], ["0"]] if y_val_multi else ["0", "1"] # type: ignore
194+
195+
if should_crash:
196+
with pytest.raises(ValueError):
197+
model.fit(X, y, X_val=X_val, y_val=y_val)
198+
else:
199+
model.fit(X, y, X_val=X_val, y_val=y_val)
200+
201+
157202
def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
158203
"""Test the evaluate function."""
159204
if mock_trained_pipeline.multilabel:

0 commit comments

Comments
 (0)