|
4 | 4 | import pytest
|
5 | 5 | import torch
|
6 | 6 | from tokenizers import Tokenizer
|
| 7 | +from transformers import AutoTokenizer |
7 | 8 |
|
8 | 9 | from model2vec.model import StaticModel
|
9 | 10 | from model2vec.train import StaticModelForClassification
|
@@ -154,6 +155,50 @@ def test_train_test_split(mock_trained_pipeline: StaticModelForClassification) -
|
154 | 155 | assert len(d) == len(b)
|
155 | 156 |
|
156 | 157 |
|
| 158 | +def test_y_val_none() -> None: |
| 159 | + """Test the y_val function.""" |
| 160 | + tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer |
| 161 | + torch.random.manual_seed(42) |
| 162 | + vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) |
| 163 | + model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") |
| 164 | + |
| 165 | + X = ["dog", "cat"] |
| 166 | + y = ["0", "1"] |
| 167 | + |
| 168 | + X_val = ["dog", "cat"] |
| 169 | + y_val = ["0", "1"] |
| 170 | + |
| 171 | + with pytest.raises(ValueError): |
| 172 | + model.fit(X, y, X_val=X_val, y_val=None) |
| 173 | + with pytest.raises(ValueError): |
| 174 | + model.fit(X, y, X_val=None, y_val=y_val) |
| 175 | + model.fit(X, y, X_val=None, y_val=None) |
| 176 | + |
| 177 | + |
| 178 | +@pytest.mark.parametrize( |
| 179 | + "y_multi,y_val_multi,should_crash", |
| 180 | + [[True, True, False], [False, False, False], [True, False, True], [False, True, True]], |
| 181 | +) |
| 182 | +def test_y_val(y_multi: bool, y_val_multi: bool, should_crash: bool) -> None: |
| 183 | + """Test the y_val function.""" |
| 184 | + tokenizer = AutoTokenizer.from_pretrained("tests/data/test_tokenizer").backend_tokenizer |
| 185 | + torch.random.manual_seed(42) |
| 186 | + vectors_torched = torch.randn(len(tokenizer.get_vocab()), 12) |
| 187 | + model = StaticModelForClassification(vectors=vectors_torched, tokenizer=tokenizer, hidden_dim=12).to("cpu") |
| 188 | + |
| 189 | + X = ["dog", "cat"] |
| 190 | + y = [["0", "1"], ["0"]] if y_multi else ["0", "1"] # type: ignore |
| 191 | + |
| 192 | + X_val = ["dog", "cat"] |
| 193 | + y_val = [["0", "1"], ["0"]] if y_val_multi else ["0", "1"] # type: ignore |
| 194 | + |
| 195 | + if should_crash: |
| 196 | + with pytest.raises(ValueError): |
| 197 | + model.fit(X, y, X_val=X_val, y_val=y_val) |
| 198 | + else: |
| 199 | + model.fit(X, y, X_val=X_val, y_val=y_val) |
| 200 | + |
| 201 | + |
157 | 202 | def test_evaluate(mock_trained_pipeline: StaticModelForClassification) -> None:
|
158 | 203 | """Test the evaluate function."""
|
159 | 204 | if mock_trained_pipeline.multilabel:
|
|
0 commit comments