
Commit 4e32661

stes and CeliaBenquet authored
Add improved goodness of fit implementation (#190)
* Started implementing improved goodness of fit implementation
* add tests and improve implementation
* Fix examples
* Fix docstring error
* Handle batch size = None for goodness of fit computation
* adapt GoF implementation
* Fix docstring tests
* Update docstring for goodness_of_fit_score
  Co-authored-by: Célia Benquet <[email protected]>
* add annotations to goodness_of_fit_history
  Co-authored-by: Célia Benquet <[email protected]>
* fix typo
  Co-authored-by: Célia Benquet <[email protected]>
* improve err message
  Co-authored-by: Célia Benquet <[email protected]>
* make numerical test less conservative
* Add tests for exception handling
* fix tests

---------

Co-authored-by: Célia Benquet <[email protected]>
1 parent 7e74eda commit 4e32661

2 files changed: 272 additions, 0 deletions

cebra/integrations/sklearn/metrics.py

Lines changed: 143 additions & 0 deletions
@@ -108,6 +108,149 @@ def infonce_loss(
     return avg_loss


+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
+    """Compute the goodness of fit score on a *single session* dataset on the model.
+
+    This function uses the :func:`infonce_loss` function to compute the InfoNCE loss
+    for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function
+    to derive the goodness of fit from the InfoNCE loss.
+
+    Args:
+        cebra_model: The model to use to compute the InfoNCE loss on the samples.
+        X: A 2D data matrix, corresponding to a *single session* recording.
+        y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one
+            discrete index passed as a 1D array. Each index has to match the length of ``X``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
+            for multisession, set to ``None`` for single session.
+        num_batches: The number of iterations to consider to evaluate the model on the new data.
+            Higher values will give a more accurate estimate. Set it to at least 500 iterations.
+
+    Returns:
+        The average GoF score estimated over ``num_batches`` batches from the data distribution.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
+    """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=False)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    """Return the history of the goodness of fit score.
+
+    Args:
+        model: A trained CEBRA model.
+
+    Returns:
+        A numpy array containing the goodness of fit values, measured in bits.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
+    """
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
+
+
+def infonce_to_goodness_of_fit(
+        infonce: Union[float, np.ndarray],
+        model: Optional[cebra_sklearn_cebra.CEBRA] = None,
+        batch_size: Optional[int] = None,
+        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
+    """Given a trained CEBRA model, return goodness of fit metric.
+
+    The goodness of fit ranges from 0 (lowest meaningful value)
+    to a positive number with the unit "bits", the higher the
+    better.
+
+    Values lower than 0 bits are possible, but these only occur
+    due to numerical effects. A perfectly collapsed embedding
+    (e.g., because the data cannot be fit with the provided
+    auxiliary variables) will have a goodness of fit of 0.
+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \\log N - \\text{InfoNCE}
+
+    To use this function, either provide a trained CEBRA model or the
+    batch size and number of sessions.
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an iterable of values.
+        model: The trained CEBRA model.
+        batch_size: The batch size used to train the model.
+        num_sessions: The number of sessions used to train the model.
+
+    Returns:
+        Numpy array containing the goodness of fit values, measured in bits.
+
+    Raises:
+        RuntimeError: If the provided model is not fit to data.
+        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided.
+    """
+    if model is not None:
+        if batch_size is not None or num_sessions is not None:
+            raise ValueError(
+                "batch_size and num_sessions should not be provided if model is provided."
+            )
+        if not hasattr(model, "state_dict_"):
+            raise RuntimeError("Fit the CEBRA model first.")
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batch_size = None).")
+        batch_size = model.batch_size
+        num_sessions = model.num_sessions_
+        if num_sessions is None:
+            num_sessions = 1
+
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batch_size = None).")
+    else:
+        if batch_size is None or num_sessions is None:
+            raise ValueError(
+                f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
+                f"should be provided if model is not provided.")
+
+    nats_to_bits = np.log2(np.e)
+    chance_level = np.log(batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],
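
To make the conversion concrete, here is a minimal standalone sketch (not part of the commit) of the arithmetic that infonce_to_goodness_of_fit performs; the batch size of 512 and the loss value of 5.0 nats are illustrative assumptions:

import numpy as np

# Illustrative values: single session, batch size 512, InfoNCE loss in nats.
batch_size, num_sessions = 512, 1
infonce = 5.0

# S = log N - InfoNCE, with N = batch_size * num_sessions, converted to bits.
chance_level = np.log(batch_size * num_sessions)   # ln(512) ≈ 6.238 nats
gof_bits = (chance_level - infonce) * np.log2(np.e)
print(f"{gof_bits:.3f} bits")                      # prints "1.787 bits"

A loss exactly at the chance level log(batch_size * num_sessions) therefore maps to 0 bits, which is the collapsed-embedding baseline described in the docstring.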

tests/test_sklearn_metrics.py

Lines changed: 129 additions & 0 deletions
@@ -383,3 +383,132 @@ def test_sklearn_runs_consistency():
     with pytest.raises(ValueError, match="Invalid.*embeddings"):
         _, _, _ = cebra_sklearn_metrics.consistency_score(
             invalid_embeddings_runs, between="runs")
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_score(seed):
+    """
+    Ensure that the GoF score is close to 0 for a model fit on random data.
+    """
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset1-model",
+        max_iterations=5,
+        batch_size=512,
+    )
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
+    y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
+    cebra_model.fit(X, y)
+    score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
+                                                        X,
+                                                        y,
+                                                        session_id=0,
+                                                        num_batches=500)
+    assert isinstance(score, float)
+    assert np.isclose(score, 0, atol=0.01)
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_history(seed):
+    """
+    Ensure that the GoF score is higher for a model fit on data with underlying
+    structure than for a model fit on random data.
+    """
+
+    # Generate data
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+    y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
+    linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
+    y_linear = X @ linear_map
+
+    def _fit_and_get_history(X, y):
+        cebra_model = cebra_sklearn_cebra.CEBRA(
+            model_architecture="offset1-model",
+            max_iterations=150,
+            batch_size=512,
+            device="cpu")
+        cebra_model.fit(X, y)
+        history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
+        # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+        # due to numerical issues.
+        return history[5:]
+
+    history_random = _fit_and_get_history(X, y_random)
+    history_linear = _fit_and_get_history(X, y_linear)
+
+    assert isinstance(history_random, np.ndarray)
+    assert history_random.shape[0] > 0
+    # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+    # due to numerical issues.
+    history_random_non_negative = history_random[history_random >= 0]
+    np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075)
+
+    assert isinstance(history_linear, np.ndarray)
+    assert history_linear.shape[0] > 0
+
+    assert np.all(history_linear[-20:] > history_random[-20:])
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_infonce_to_goodness_of_fit(seed):
+    """Test the conversion from InfoNCE loss to goodness of fit metric."""
+    # Test with model
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset10-model",
+        max_iterations=5,
+        batch_size=128,
+    )
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+    cebra_model.fit(X)
+
+    # Test single value
+    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                           model=cebra_model)
+    assert isinstance(gof, float)
+
+    # Test array of values
+    infonce_values = np.array([1.0, 2.0, 3.0])
+    gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
+        infonce_values, model=cebra_model)
+    assert isinstance(gof_array, np.ndarray)
+    assert gof_array.shape == infonce_values.shape
+
+    # Test with explicit batch_size and num_sessions
+    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                           batch_size=128,
+                                                           num_sessions=1)
+    assert isinstance(gof, float)
+
+    # Test error cases
+    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=cebra_model,
+                                                         batch_size=128)
+
+    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=cebra_model,
+                                                         num_sessions=1)
+
+    # Test with unfitted model
+    unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5)
+    with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=unfitted_model)
+
+    # Test with model having batch_size=None
+    none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None,
+                                                 max_iterations=5)
+    none_batch_model.fit(X)
+    with pytest.raises(ValueError, match="Computing the goodness of fit"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=none_batch_model)
+
+    # Test missing batch_size or num_sessions when model is None
+    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)
+
+    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)
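
For reference, a minimal end-to-end sketch (not part of the commit) of the two public entry points added here; the cebra.sklearn.metrics module path follows the docstring examples above, and since the data is random the score should land near 0 bits:

import cebra
import numpy as np

# Random data carries no learnable structure, so the GoF stays near chance.
neural_data = np.random.uniform(0, 1, (1000, 20))
model = cebra.CEBRA(max_iterations=100, batch_size=512)
model.fit(neural_data)

# Per-iteration GoF trajectory (bits), derived from the logged InfoNCE loss.
history = cebra.sklearn.metrics.goodness_of_fit_history(model)

# Post-hoc estimate, averaged over 500 batches sampled from the data.
score = cebra.sklearn.metrics.goodness_of_fit_score(model, neural_data,
                                                    num_batches=500)
print(history[-5:], score)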
