diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py
index ccecaa11..0af44ecb 100644
--- a/cebra/integrations/sklearn/metrics.py
+++ b/cebra/integrations/sklearn/metrics.py
@@ -108,6 +108,147 @@ def infonce_loss(
     return avg_loss
 
 
+def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA,
+                          X: Union[npt.NDArray, torch.Tensor],
+                          *y,
+                          session_id: Optional[int] = None,
+                          num_batches: int = 500) -> float:
+    """Compute the goodness of fit score of the model on a *single session* dataset.
+
+    This function uses :func:`infonce_loss` to compute the InfoNCE loss
+    of the given ``cebra_model`` on the samples, and :func:`infonce_to_goodness_of_fit`
+    to derive the goodness of fit from that loss.
+
+    Args:
+        cebra_model: The model to use to compute the InfoNCE loss on the samples.
+        X: A 2D data matrix, corresponding to a *single session* recording.
+        y: An arbitrary number of continuous indices passed as 2D matrices, and up to one
+            discrete index passed as a 1D array. Each index has to match the length of ``X``.
+        session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions`
+            for multisession, set to ``None`` for single session.
+        num_batches: The number of batches to sample to evaluate the model on the new data.
+            Higher values give a more accurate estimate; use at least 500 batches.
+
+    Returns:
+        The average goodness of fit score estimated over ``num_batches`` batches from the data distribution.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data)
+    """
+    loss = infonce_loss(cebra_model,
+                        X,
+                        *y,
+                        session_id=session_id,
+                        num_batches=num_batches,
+                        correct_by_batchsize=False)
+    return infonce_to_goodness_of_fit(loss, cebra_model)
+
+
+def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray:
+    """Return the history of the goodness of fit score over training.
+
+    Args:
+        model: A trained CEBRA model.
+
+    Returns:
+        A numpy array containing the goodness of fit values, measured in bits.
+
+    Related:
+        :func:`infonce_to_goodness_of_fit`
+
+    Example:
+
+        >>> import cebra
+        >>> import numpy as np
+        >>> neural_data = np.random.uniform(0, 1, (1000, 20))
+        >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size=512)
+        >>> cebra_model.fit(neural_data)
+        CEBRA(batch_size=512, max_iterations=10)
+        >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model)
+    """
+    infonce = np.array(model.state_dict_["log"]["total"])
+    return infonce_to_goodness_of_fit(infonce, model)
+
+
+def infonce_to_goodness_of_fit(
+        infonce: Union[float, np.ndarray],
+        model: Optional[cebra_sklearn_cebra.CEBRA] = None,
+        batch_size: Optional[int] = None,
+        num_sessions: Optional[int] = None) -> Union[float, np.ndarray]:
+    """Convert InfoNCE loss value(s) into the goodness of fit metric.
+
+    The goodness of fit ranges from 0 (lowest meaningful value)
+    to a positive number with the unit "bits"; the higher, the
+    better.
+
+    Values lower than 0 bits are possible, but these only occur
+    due to numerical effects. A perfectly collapsed embedding
+    (e.g., because the data cannot be fit with the provided
+    auxiliary variables) will have a goodness of fit of 0.
+
+    The conversion between the generalized InfoNCE metric that
+    CEBRA is trained with and the goodness of fit computed with this
+    function is
+
+    .. math::
+
+        S = \\log N - \\text{InfoNCE}
+
+    where :math:`N` is the batch size multiplied by the number of sessions.
+
+    To use this function, either provide a trained CEBRA model or the
+    batch size and number of sessions.
+
+    Args:
+        infonce: The InfoNCE loss, either a single value or an array of values.
+        model: The trained CEBRA model.
+        batch_size: The batch size used to train the model.
+        num_sessions: The number of sessions used to train the model.
+
+    Returns:
+        The goodness of fit measured in bits, matching the shape of ``infonce``.
+
+    Raises:
+        RuntimeError: If the provided model was not fit to data.
+        ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided,
+            if neither is provided, or if the model was trained with ``batch_size=None``.
+    """
+    if model is not None:
+        if batch_size is not None or num_sessions is not None:
+            raise ValueError(
+                "batch_size and num_sessions should not be provided if model is provided."
+            )
+        if not hasattr(model, "state_dict_"):
+            raise RuntimeError("Fit the CEBRA model first.")
+        if model.batch_size is None:
+            raise ValueError(
+                "Computing the goodness of fit is not yet supported for "
+                "models trained on the full dataset (batch_size=None).")
+        batch_size = model.batch_size
+        num_sessions = model.num_sessions_
+        if num_sessions is None:
+            num_sessions = 1
+    else:
+        if batch_size is None or num_sessions is None:
+            raise ValueError(
+                f"batch_size ({batch_size}) and num_sessions ({num_sessions}) "
+                "should be provided if model is not provided.")
+
+    nats_to_bits = np.log2(np.e)
+    chance_level = np.log(batch_size * num_sessions)
+    return (chance_level - infonce) * nats_to_bits
+
+
 def _consistency_scores(
     embeddings: List[Union[npt.NDArray, torch.Tensor]],
     datasets: List[Union[int, str]],
diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py
index 58e12010..4e765ba7 100644
--- a/tests/test_sklearn_metrics.py
+++ b/tests/test_sklearn_metrics.py
@@ -383,3 +383,132 @@ def test_sklearn_runs_consistency():
     with pytest.raises(ValueError, match="Invalid.*embeddings"):
         _, _, _ = cebra_sklearn_metrics.consistency_score(
             invalid_embeddings_runs, between="runs")
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_score(seed):
+    """
+    Ensure that the GoF score is close to 0 for a model fit on random data.
+    """
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset1-model",
+        max_iterations=5,
+        batch_size=512,
+    )
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
+    y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
+    cebra_model.fit(X, y)
+    score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
+                                                        X,
+                                                        y,
+                                                        session_id=0,
+                                                        num_batches=500)
+    assert isinstance(score, float)
+    assert np.isclose(score, 0, atol=0.01)
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_history(seed):
+    """
+    Ensure that the GoF score is higher for a model fit on data with underlying
+    structure than for a model fit on random data.
+    """
+
+    # Generate data
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+    y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
+    linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
+    y_linear = X @ linear_map
+
+    def _fit_and_get_history(X, y):
+        cebra_model = cebra_sklearn_cebra.CEBRA(
+            model_architecture="offset1-model",
+            max_iterations=150,
+            batch_size=512,
+            device="cpu")
+        cebra_model.fit(X, y)
+        history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
+        # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+        # due to numerical issues.
+        return history[5:]
+
+    history_random = _fit_and_get_history(X, y_random)
+    history_linear = _fit_and_get_history(X, y_linear)
+
+    assert isinstance(history_random, np.ndarray)
+    assert history_random.shape[0] > 0
+    # Slightly negative values can occur due to numerical effects; only check
+    # that the non-negative part of the history stays close to zero.
+    history_random_non_negative = history_random[history_random >= 0]
+    np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075)
+
+    assert isinstance(history_linear, np.ndarray)
+    assert history_linear.shape[0] > 0
+
+    assert np.all(history_linear[-20:] > history_random[-20:])
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_infonce_to_goodness_of_fit(seed):
+    """Test the conversion from InfoNCE loss to goodness of fit metric."""
+    # Test with model
+    cebra_model = cebra_sklearn_cebra.CEBRA(
+        model_architecture="offset10-model",
+        max_iterations=5,
+        batch_size=128,
+    )
+    generator = torch.Generator().manual_seed(seed)
+    X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+    cebra_model.fit(X)
+
+    # Test single value
+    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                           model=cebra_model)
+    assert isinstance(gof, float)
+
+    # Test array of values
+    infonce_values = np.array([1.0, 2.0, 3.0])
+    gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
+        infonce_values, model=cebra_model)
+    assert isinstance(gof_array, np.ndarray)
+    assert gof_array.shape == infonce_values.shape
+
+    # Test with explicit batch_size and num_sessions
+    gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                           batch_size=128,
+                                                           num_sessions=1)
+    assert isinstance(gof, float)
+
+    # Test error cases
+    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=cebra_model,
+                                                         batch_size=128)
+
+    with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=cebra_model,
+                                                         num_sessions=1)
+
+    # Test with unfitted model
+    unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5)
+    with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=unfitted_model)
+
+    # Test with model having batch_size=None
+    none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None,
+                                                 max_iterations=5)
+    none_batch_model.fit(X)
+    with pytest.raises(ValueError, match="Computing the goodness of fit"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+                                                         model=none_batch_model)
+
+    # Test missing batch_size or num_sessions when model is None
+    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)
+
+    with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+        cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)
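
As a quick sanity check of the conversion implemented above, here is a minimal
usage sketch (not part of the patch): it estimates the InfoNCE loss with the
existing ``infonce_loss`` helper and converts it to bits both via the new
``infonce_to_goodness_of_fit`` and by hand with S = log(N) - InfoNCE, where
N = batch_size * num_sessions. It assumes the ``cebra.sklearn.metrics`` alias
used in the docstring examples above.

    import numpy as np
    import cebra

    # Fit a small model on random data, as in the docstring examples.
    neural_data = np.random.uniform(0, 1, (1000, 20))
    model = cebra.CEBRA(max_iterations=10, batch_size=512)
    model.fit(neural_data)

    # Estimate the raw InfoNCE loss on the data (no batch-size correction,
    # matching the call inside goodness_of_fit_score).
    infonce = cebra.sklearn.metrics.infonce_loss(model,
                                                 neural_data,
                                                 num_batches=500,
                                                 correct_by_batchsize=False)

    # Convert to goodness of fit: single session, batch size 512.
    gof = cebra.sklearn.metrics.infonce_to_goodness_of_fit(infonce,
                                                           batch_size=512,
                                                           num_sessions=1)

    # Same conversion by hand: the chance level is log(512 * 1) nats,
    # and nats are converted to bits via log2(e).
    by_hand = (np.log(512 * 1) - infonce) * np.log2(np.e)
    assert np.isclose(gof, by_hand)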