Commit 7ae5e1e

CeliaBenquet, gonlairo, stes, MMathisLab, and icarosadero authored
Batched inference CEBRA & padding at the Solver level (#168)
* first proposal for batching in transform method
* first running version of padding with batched inference
* start tests
* add pad_before_transform to fit function and add support for convolutional models in _transform
* remove print statements
* first passing test
* add support for hybrid models
* rewrite transform in sklearn API
* baseline version of a torch.Dataset
* move batching logic outside solver
* move functionality to base file in solver and separate in functions
* add test_select_model for single session
* add checks and test for _process_batch
* add test_select_model for multisession
* make self.num_sessions compatible with single session training
* improve test_batched_transform_singlesession
* make it work with small batches
* make test with multisession work
* change to torch padding
* add argument to sklearn api
* add torch padding to _transform
* convert to torch if numpy array as inputs
* add distinction between pad with data and pad with zeros and modify test accordingly
* differentiate between data padding and zero padding
* remove float16
* change argument position
* clean test
* clean test
* Fix warning
* Improve modularity remove duplicate code and todos
* Add tests to solver
* Remove unused import in solver/utils
* Fix test plot
* Add some coverage
* Fix save/load
* Remove duplicate configure_for in multi dataset
* Make save/load cleaner
* Fix codespell errors
* Fix docs compilation errors
* Fix formatting
* Fix extra docs errors
* Fix offset in docs
* Remove attribute ref
* Add review updates
* apply ruff auto-fixes
* Concatenate last batches for batched inference (#200)
* Concatenate last two batches for batched inference
* Add test case
* Fix linting errors in tests (#188)
* apply auto-fixes
* Fix linting errors in tests/
* Fix version check
* Fix `scikit-learn` reference in conda environment files (#195)
* Add support for new __sklearn_tags__ (#205)
* Add support for new __sklearn_tags__
* fix inheritance order
* Add more tests
* fix added test
* Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)
* Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (#206)
* Add tests to check legacy model loading (#214)
* Add improved goodness of fit implementation (#190)
* Started implementing improved goodness of fit implementation
* add tests and improve implementation
* Fix examples
* Fix docstring error
* Handle batch size = None for goodness of fit computation
* adapt GoF implementation
* Fix docstring tests
* Update docstring for goodness_of_fit_score
Co-authored-by: Célia Benquet <[email protected]>
* add annotations to goodness_of_fit_history
Co-authored-by: Célia Benquet <[email protected]>
* fix typo
Co-authored-by: Célia Benquet <[email protected]>
* improve err message
Co-authored-by: Célia Benquet <[email protected]>
* make numerical test less conservative
* Add tests for exception handling
* fix tests
---------
Co-authored-by: Célia Benquet <[email protected]>
* Support numpy 2, upgrade tests to support torch 2.6 (#221)
* Drop numpy constraint
* Implement workaround for pytables
* better error message
* pin numpy only for python 3.9
* update dependencies
* Upgrade torch version
* Fix based on python version
* Add support for torch.load with weights_only=True
* Implement safe loading for torch models starting in torch 2.6
* Fix windows specs
* fix docstring
* Revert changes to loading logic
* Release 0.5.0rc1 (#189)
* Make bump_version script runnable on MacOS
* Bump version to 0.5.0rc1
* fix minor formatting issues
* remove commented code
---------
Co-authored-by: Mackenzie Mathis <[email protected]>
* Fix pypi action (#222)
* force packaging upgrade to 24.2 for twine
* Bump version to 0.5.0rc2
* remove universal compatibility option
* revert tag
* adapt files to new wheel name due to py3
* Update base.py (#224) This is a lazy solution to #223
* Change max consistency value to 100 instead of 99 (#227)
* Change text consistency max from 99 to 100
* Update cebra/integrations/matplotlib.py
---------
Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
* Update assets.py --> force check for parent dir (#230) Update assets.py - mkdir was failing in 0.5.0rc1; attempt to fix
* User docs minor edit (#229)
* user note added to usage.rst - link added
* Update usage.rst - more detailed note on the effect of temp.
* Update usage.rst - add in temp to demo model - testout put thanks @stes
* Update docs/source/usage.rst
Co-authored-by: Steffen Schneider <[email protected]>
* Update docs/source/usage.rst
Co-authored-by: Steffen Schneider <[email protected]>
* Update docs/source/usage.rst
Co-authored-by: Steffen Schneider <[email protected]>
---------
Co-authored-by: Steffen Schneider <[email protected]>
* General Doc refresher (#232)
* Update installation.rst - python 3.9+
* Update index.rst
* Update figures.rst
* Update index.rst - typo fix
* Update usage.rst - update suggestion on data split
* Update docs/source/usage.rst
Co-authored-by: Steffen Schneider <[email protected]>
* Update usage.rst - indent error fixed
* Update usage.rst - changed infoNCE to new GoF
* Update usage.rst - fix numpy() doctest
* Update usage.rst - small typo fix (label)
* Update usage.rst
---------
Co-authored-by: Steffen Schneider <[email protected]>
* render plotly in our docs, show code/doc version (#231)
* Update layout.html (#233)
* Update conf.py (#234) - adding link to new notebook icon
* Refactoring setup.cfg (#228)
* Home page landing update (#235)
* website refresh
* v0.5.0 (#238)
* Upgrade docs build (#241)
* Improve build setup for docs
* update pydata theme options
* Add README for docs folder
* Fix demo notebook build
* Finish build setup
* update git workflow
* add timeout to workflow
* add timeout also to docs build
* switch build back to sphinx for gh actions
* attempt to fix build workflow
* update to sphinx-build
* fix build workflow
* fix indent error
* fix build system
* revert demos to main
* increase timeout to 30
* Allow indexing of the cebra docs (#242)
* Allow indexing of the cebra docs
* Fix docs workflow
* Fix broken docs coverage workflows (#246)
* Add xCEBRA implementation (AISTATS 2025) (#225)
* Add multiobjective solver and regularized training (#783)
* Add multiobjective solver and regularized training
* Add example for multiobjective training
* Add jacobian regularizer and SAM
* update license headers
* add api draft for multiobjective training
* add all necessary modules to run the complete xcebra pipeline
* add notebooks to reproduce xcebra pipeline
* add first working notebook
* add notebook with hybrid learning
* add notebook with creation of synthetic data
* add notebook with hybrid training
* add plot with R2 for different parts of the embedding
* add new API
* update api wrapper with more checks and messages
* add tests and notebook with new api
* merge xcebra into attribution
* separate xcebra dataset from cebra
* some minor refactoring of cebra dataset
* separate xcebra loader from cebra
* remove xcebra distributions from cebra
* minor refactoring with distributions
* separate xcebra criterions from cebra
* minor refactoring on criterion
* separate xcebra models/criterions/layers from cebra
* refactoring multiobjective
* more refactoring...
* separate xcebra solvers from cebra
* more refactoring
* move xcebra to its own package
* move more files into xcebra package
* more files and remove changes with the registry
* remove unnecessary import
* add folder structure
* move back distributions
* add missing init
* remove wrong init
* make loader and dataset run with new imports
* making it run!
* make attribution run
* Run pre-commit
* move xcebra repo one level up
* update gitignore and add __init__ from data
* add init to distributions
* add correct init for attribution package
* add correct init for model package
* fix remaining imports
* fix tests
* add examples back to xcebra repo
* update imports from graphs_xcebra
* add setup.py to create a package
* update imports of graph_xcebra
* update notebooks
* Formatting code for submission
Co-authored-by: Rodrigo Gonzalez <[email protected]>
* move test into xcebra
* Add README
* move distributions back to main package
* clean up examples
* adapt tests
* Add LICENSE
* add train/eval notebook again
* add notebook with clean results
* rm synthetic data
* change name from xcebra to regcl
* change names of modules and adapt imports
* change name from graphs_xcebra to synthetic_data
* Integrate into CEBRA
* Fix remaining imports and make notebook runnable
* Add dependencies, add version flag
* Remove synthetic data files
* reset dockerfile, move vmf
* apply pre-commit
* Update notice
* add some docstrings
* Apply license headers
* add new scd notebook
* add notebook with scd
---------
Co-authored-by: Steffen Schneider <[email protected]>
* Fix tests
* bump version
* update dockerfile
* fix progress bar
* remove outdated test
* rename models
* Apply fixes to pass ruff tests
* Fix typos
* Update license headers, fix additional ruff errors
* remove unused comment
* rename regcl in codebase
* change regcl name in dockerfile
* Improve attribution module
* Fix imports name naming
* add basic integration test
* temp disable of binary check
* Add legacy multiobjective model for backward compat
* add synth import back in
* Fix docstrings and type annot in cebra/models/jacobian_regularizer.py
* add xcebra to tests
* add missing cvxpy dep
* fix docstrings
* more docstrings to fix attr error
* Improve build setup for docs
* update pydata theme options
* Add README for docs folder
* Fix demo notebook build
* Finish build setup
* update git workflow
* Move demo notebooks to CEBRA-demos repo See AdaptiveMotorControlLab/CEBRA-demos#28
* revert unneeded changes in solver
* formatting in solver
* further minimize solver diff
* Revert unneeded updates to the solver
* fix citation
* fix docs build, missing refs
* remove file dependency from xcebra int test
* remove unneeded change in registry
* update gitignore
* update docs
* exclude some assets
* include binary file check again
* add timeout to workflow
* add timeout also to docs build
* switch build back to sphinx for gh actions
* pin sphinx version in setup.cfg
* attempt workflow fix
* attempt to fix build workflow
* update to sphinx-build
* fix build workflow
* fix indent error
* fix build system
* revert demos to main
* adapt workflow for testing
* bump version to 0.6.0rc1
* format imports
* docs writing
* enable build on dev branch
* fix some review comments
* extend multiobjective docs
* Set version to alpha
* make tempdir platform independent
* Remove ratinabox and ephysiopy as deps
* Apply review comments
* Update Makefile - setting coverage threshold to 80% to not delay good code being made public. In the near future this can be fixed and raised again to 90%.
---------
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
* start tests
* remove print statements
* first passing test
* move functionality to base file in solver and separate in functions
* add test_select_model for multisession
* remove float16
* Improve modularity remove duplicate code and todos
* Add tests to solver
* Fix save/load
* Fix extra docs errors
* Add review updates
* apply ruff auto-fixes
* fix linting errors
* Run isort, ruff, yapf
* Fix gaussian mixture dataset import
* Fix all tests but xcebra tests
* Fix pytorch API usage example
* Make xCEBRA compatible with the batched inference & padding in solver
* Add some tests on transform() with xCEBRA
* Add some docstrings and typings and clean unnecessary changes
* Implement review comments
* Fix sklearn test
* Add name in NOTE
Co-authored-by: Steffen Schneider <[email protected]>
* Implement reviews on tests and typing
* Fix import errors
* Add select_model to aux solvers
* Fix docs error
* Add tests on the private functions in base solver
* Update tests and duplicate code based on review
---------
Co-authored-by: Rodrigo <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Ícaro <[email protected]>
Co-authored-by: Mackenzie Mathis <[email protected]>
Co-authored-by: Steffen Schneider <[email protected]>
Co-authored-by: Rodrigo González Laiz <[email protected]>
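For reference, the batched inference path added by this commit is used through the sklearn API roughly as follows. This is a minimal sketch based on the doctest updated in cebra/integrations/sklearn/cebra.py below; the input array and its shape are illustrative only.

    >>> import numpy as np
    >>> import cebra
    >>> dataset = np.random.uniform(0, 1, (1000, 30)).astype("float32")
    >>> cebra_model = cebra.CEBRA(max_iterations=10)
    >>> cebra_model.fit(dataset)
    CEBRA(max_iterations=10)
    >>> # Embed in chunks of 200 samples; padding at the solver level keeps
    >>> # the embedding length equal to the input length.
    >>> embedding = cebra_model.transform(dataset, batch_size=200)

Passing batch_size=None (the default) keeps the behaviour of transforming the whole sequence in a single pass.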
1 parent a5814bb commit 7ae5e1e

File tree

17 files changed: +2007 −244 lines changed

cebra/data/base.py

Lines changed: 2 additions & 0 deletions
@@ -227,6 +227,8 @@ class Loader(abc.ABC, cebra.io.HasDevice):
         doc="""A dataset instance specifying a ``__getitem__`` function.""",
     )
 
+    time_offset: int = dataclasses.field(default=10)
+
     num_steps: int = dataclasses.field(
         default=None,
         doc=
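With time_offset promoted to the base Loader, every loader subclass exposes the same keyword instead of redefining it. A hypothetical construction, assuming a SingleSessionDataset instance named dataset (keyword names follow the dataclass fields shown in these diffs):

    import cebra.data

    loader = cebra.data.ContinuousDataLoader(
        dataset=dataset,   # any cebra.data.SingleSessionDataset instance
        num_steps=1000,
        batch_size=512,
        time_offset=10,    # now defined once on cebra.data.Loader
    )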

cebra/data/multi_session.py

Lines changed: 25 additions & 11 deletions
@@ -26,9 +26,10 @@
 
 import literate_dataclasses as dataclasses
 import torch
+import torch.nn as nn
 
 import cebra.data as cebra_data
-import cebra.distributions as cebra_distr
+import cebra.distributions
 from cebra.data.datatypes import Batch
 from cebra.data.datatypes import BatchIndex
 
@@ -104,10 +105,25 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
             ) for session_id, session in enumerate(self.iter_sessions())
         ]
 
-    def configure_for(self, model):
-        self.offset = model.get_offset()
-        for session in self.iter_sessions():
-            session.configure_for(model)
+    def configure_for(self, model: "cebra.models.Model"):
+        """Configure the dataset offset for the provided model.
+
+        Call this function before indexing the dataset. This sets the
+        :py:attr:`~.Dataset.offset` attribute of the dataset.
+
+        Args:
+            model: The model to configure the dataset for.
+        """
+        if not isinstance(model, nn.ModuleList):
+            raise ValueError(
+                "The model must be a nn.ModuleList to configure the dataset.")
+        if len(model) != self.num_sessions:
+            raise ValueError(
+                f"The model must have {self.num_sessions} sessions, but got {len(model)}."
+            )
+
+        for i, session in enumerate(self.iter_sessions()):
+            session.configure_for(model[i])
 
 
 @dataclasses.dataclass
@@ -119,12 +135,10 @@ class MultiSessionLoader(cebra_data.Loader):
     dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`.
     """
 
-    time_offset: int = dataclasses.field(default=10)
-
     def __post_init__(self):
         super().__post_init__()
-        self.sampler = cebra_distr.MultisessionSampler(self.dataset,
-                                                       self.time_offset)
+        self.sampler = cebra.distributions.MultisessionSampler(
+            self.dataset, self.time_offset)
 
     def get_indices(self, num_samples: int) -> List[BatchIndex]:
         ref_idx = self.sampler.sample_prior(self.batch_size)
@@ -149,7 +163,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader):
     """Contrastive learning conditioned on a continuous behavior variable."""
 
     conditional: str = "time_delta"
-    time_offset: int = dataclasses.field(default=10)
 
     @property
     def index(self):
@@ -163,7 +176,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
     # Overwrite sampler with the discrete implementation
     # Generalize MultisessionSampler to avoid doing this?
     def __post_init__(self):
-        self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)
+        self.sampler = cebra.distributions.DiscreteMultisessionSampler(
+            self.dataset)
 
     @property
     def index(self):
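The new configure_for contract on the multisession dataset side expects one network per session, wrapped in a torch.nn.ModuleList whose length matches num_sessions. A rough sketch of that call pattern (the dataset variable and per-session dimensions are placeholders; cebra.models.init is the registry helper used elsewhere in CEBRA):

    import torch.nn as nn
    import cebra.models

    # multi_dataset: a multisession dataset (e.g. cebra.data.DatasetCollection)
    # with two sessions of 100 and 120 neurons (placeholder values).
    models = nn.ModuleList([
        cebra.models.init("offset10-model", num_neurons=n, num_units=32, num_output=8)
        for n in (100, 120)
    ])
    multi_dataset.configure_for(models)  # raises ValueError on a length mismatch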

cebra/data/single_session.py

Lines changed: 0 additions & 2 deletions
@@ -189,7 +189,6 @@ class ContinuousDataLoader(cebra_data.Loader):
             and become equivalent to time contrastive learning.
         """,
     )
-    time_offset: int = dataclasses.field(default=10)
     delta: float = dataclasses.field(default=0.1)
 
     def __post_init__(self):
@@ -278,7 +277,6 @@ class MixedDataLoader(cebra_data.Loader):
     """
 
     conditional: str = dataclasses.field(default="time_delta")
-    time_offset: int = dataclasses.field(default=10)
 
     @property
     def dindex(self):

cebra/datasets/demo.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@
 import cebra.io
 from cebra.datasets import register
 
-_DEFAULT_NUM_TIMEPOINTS = 100000
+_DEFAULT_NUM_TIMEPOINTS = 1_000
 
 
 class DemoDataset(cebra.data.SingleSessionDataset):

cebra/integrations/sklearn/cebra.py

Lines changed: 40 additions & 69 deletions
@@ -51,6 +51,7 @@
     np.dtypes.Float64DType, np.dtypes.Int64DType
 ]
 
+
 def check_version(estimator):
     # NOTE(stes): required as a check for the old way of specifying tags
     # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
@@ -76,7 +77,6 @@ def _safe_torch_load(filename, weights_only, **kwargs):
     return checkpoint
 
 
-
 def _init_loader(
     is_cont: bool,
     is_disc: bool,
@@ -129,7 +129,7 @@ def _init_loader(
         (not is_cont, not is_disc, is_multi),
     ]
     if any(all(combination) for combination in incompatible_combinations):
-        raise ValueError(f"Invalid index combination.\n"
+        raise ValueError("Invalid index combination.\n"
                          f"Continuous: {is_cont},\n"
                          f"Discrete: {is_disc},\n"
                          f"Hybrid training: {is_hybrid},\n"
@@ -293,7 +293,7 @@ def _require_arg(key):
             "single-session",
         )
 
-    error_message = (f"Invalid index combination.\n"
+    error_message = ("Invalid index combination.\n"
                      f"Continuous: {is_cont},\n"
                      f"Discrete: {is_disc},\n"
                      f"Hybrid training: {is_hybrid},\n"
@@ -340,7 +340,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
     if missing_keys:
         raise ValueError(
             f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
-            f"You can try loading the CEBRA model with the torch backend.")
+            "You can try loading the CEBRA model with the torch backend.")
 
     args, state, state_dict = cebra_info['args'], cebra_info[
         'state'], cebra_info['state_dict']
@@ -656,12 +656,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
             # TODO(celia): to make it work for multiple set of index. For now, y should be a tuple of one list only
             if isinstance(y, tuple) and len(y) > 1:
                 raise NotImplementedError(
-                    f"Support for multiple set of index is not implemented in multissesion training, "
+                    "Support for multiple set of index is not implemented in multissesion training, "
                     f"got {len(y)} sets of indexes.")
 
             if not _are_sessions_equal(X, y):
                 raise ValueError(
-                    f"Invalid number of sessions: number of sessions in X and y need to match, "
+                    "Invalid number of sessions: number of sessions in X and y need to match, "
                     f"got X:{len(X)} and y:{[len(y_i) for y_i in y]}.")
 
             for session in range(len(X)):
@@ -685,8 +685,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
         else:
             if not _are_sessions_equal(X, y):
                 raise ValueError(
-                    f"Invalid number of samples or labels sessions: provide one session for single-session training, "
-                    f"and make sure the number of samples in X and y need match, "
+                    "Invalid number of samples or labels sessions: provide one session for single-session training, "
+                    "and make sure the number of samples in X and y match, "
                     f"got {len(X)} and {[len(y_i) for y_i in y]}.")
             is_multisession = False
             dataset = _get_dataset(X, y)
@@ -813,8 +813,6 @@ def _configure_for_all(
                             "receptive fields/offsets larger than 1 via the sklearn API. "
                             "Please use a different model, or revert to the pytorch "
                             "API for training.")
-
-                d.configure_for(model[n])
         else:
             if not isinstance(model, cebra.models.ConvolutionalModelMixin):
                 if len(model.get_offset()) > 1:
@@ -824,37 +822,13 @@
                         "Please use a different model, or revert to the pytorch "
                         "API for training.")
 
-            dataset.configure_for(model)
+        dataset.configure_for(model)
 
     def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
                       session_id: int):
-        # Choose the model and get its corresponding offset
-        if self.num_sessions is not None:  # multisession implementation
-            if session_id is None:
-                raise RuntimeError(
-                    "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
-                )
-            if session_id >= self.num_sessions or session_id < 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
-                )
-            if self.n_features_[session_id] != X.shape[1]:
-                raise ValueError(
-                    f"Invalid input shape: model for session {session_id} requires an input of shape"
-                    f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})."
-                )
-
-            model = self.model_[session_id]
-            model.to(self.device_)
-        else:  # single session
-            if session_id is not None and session_id > 0:
-                raise RuntimeError(
-                    f"Invalid session_id {session_id}: single session models only takes an optional null session_id."
-                )
-            model = self.model_
-
-        offset = model.get_offset()
-        return model, offset
+        if isinstance(X, np.ndarray):
+            X = torch.from_numpy(X)
+        return self.solver_._select_model(X, session_id=session_id)
 
     def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
         """Check that the input labels are compatible with the labels used to fit the model.
@@ -876,7 +850,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
         # Check that same number of index
         if len(self.label_types_) != n_idx:
             raise ValueError(
-                f"Number of index invalid: labels must have the same number of index as for fitting,"
+                "Number of index invalid: labels must have the same number of index as for fitting,"
                 f"expects {len(self.label_types_)}, got {n_idx} idx.")
 
         for i in range(len(self.label_types_)):  # for each index
@@ -889,12 +863,12 @@
                     > 1):  # is there more than one feature in the index
                 if label_types_idx[1][1] != y[i].shape[1]:
                     raise ValueError(
-                        f"Labels invalid: must have the same number of features as the ones used for fitting,"
+                        "Labels invalid: must have the same number of features as the ones used for fitting,"
                         f"expects {label_types_idx[1]}, got {y[i].shape}.")
 
             if label_types_idx[0] != y[i].dtype:
                 raise ValueError(
-                    f"Labels invalid: must have the same type of features as the ones used for fitting,"
+                    "Labels invalid: must have the same type of features as the ones used for fitting,"
                     f"expects {label_types_idx[0]}, got {y[i].dtype}.")
 
     def _prepare_fit(
@@ -1081,14 +1055,13 @@ def _partial_fit(
 
         # Save variables of interest as semi-private attributes
         self.model_ = model
-        self.n_features_ = ([
-            loader.dataset.get_input_dimension(session_id)
-            for session_id in range(loader.dataset.num_sessions)
-        ] if is_multisession else loader.dataset.input_dimension)
+
+        self.n_features_ = solver.n_features
+        self.num_sessions_ = solver.num_sessions if hasattr(
+            solver, "num_sessions") else None
         self.solver_ = solver
         self.n_features_in_ = ([model[n].num_input for n in range(len(model))]
                                if is_multisession else model.num_input)
-        self.num_sessions_ = loader.dataset.num_sessions if is_multisession else None
 
         return self
 
@@ -1236,11 +1209,13 @@
 
     def transform(self,
                   X: Union[npt.NDArray, torch.Tensor],
+                  batch_size: Optional[int] = None,
                   session_id: Optional[int] = None) -> npt.NDArray:
         """Transform an input sequence and return the embedding.
 
         Args:
             X: A numpy array or torch tensor of size ``time x dimension``.
+            batch_size:
             session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
                 multisession, set to ``None`` for single session.
 
@@ -1255,37 +1230,28 @@
         >>> cebra_model = cebra.CEBRA(max_iterations=10)
        >>> cebra_model.fit(dataset)
         CEBRA(max_iterations=10)
-        >>> embedding = cebra_model.transform(dataset)
+        >>> embedding = cebra_model.transform(dataset, batch_size=200)
 
         """
-
         sklearn_utils_validation.check_is_fitted(self, "n_features_")
-        model, offset = self._select_model(X, session_id)
+        self.solver_._check_is_session_id_valid(session_id=session_id)
 
-        # Input validation
-        X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
-        input_dtype = X.dtype
+        if torch.is_tensor(X):
+            X = X.detach().cpu()
 
-        with torch.no_grad():
-            model.eval()
-
-            if self.pad_before_transform:
-                X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
-                           mode="edge")
-            X = torch.from_numpy(X).float().to(self.device_)
+        X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
 
-            if isinstance(model, cebra.models.ConvolutionalModelMixin):
-                # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
-                X = X.transpose(1, 0).unsqueeze(0)
-                output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
-            else:
-                # Standard evaluation, (T, C, dt)
-                output = model(X).cpu().numpy()
+        if isinstance(X, np.ndarray):
+            X = torch.from_numpy(X)
 
-            if input_dtype == "float64":
-                return output.astype(input_dtype)
+        with torch.no_grad():
+            output = self.solver_.transform(
+                inputs=X,
+                pad_before_transform=self.pad_before_transform,
+                session_id=session_id,
+                batch_size=batch_size)
 
-        return output
+        return output.detach().cpu().numpy()
 
     def fit_transform(
         self,
@@ -1501,6 +1467,11 @@ def load(cls,
         else:
             cebra_ = _check_type_checkpoint(checkpoint)
 
+        n_features = cebra_.n_features_
+        cebra_.solver_.n_features = ([
+            session_n_features for session_n_features in n_features
+        ] if isinstance(n_features, list) else n_features)
+
         return cebra_
 
     def to(self, device: Union[str, torch.device]):
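Conceptually, the solver-level transform that replaces the inlined logic above splits the input into chunks, pads each chunk with neighbouring samples (or replicated edge values at the sequence boundaries) so a convolutional model with receptive field left + right still emits one embedding per input sample, and concatenates the per-chunk outputs. The following is a simplified, self-contained illustration of that idea only; it is not the CEBRA solver code, and the function and variable names are invented for this sketch:

    import torch
    import torch.nn.functional as F

    def batched_conv_transform(model, x, left, right, batch_size):
        # Toy batched inference for a fully convolutional encoder.
        # x: (time, channels) float32 tensor; (left, right) is the model offset.
        outputs = []
        n = len(x)
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            # Use real neighbouring samples where available...
            lo, hi = max(0, start - left), min(n, end + right - 1)
            chunk = x[lo:hi].T.unsqueeze(0)                  # (1, C, T)
            # ...and replicate the sequence edges where we run out of data.
            pad_l = left - (start - lo)
            pad_r = (right - 1) - (hi - end)
            chunk = F.pad(chunk, (pad_l, pad_r), mode="replicate")
            with torch.no_grad():
                out = model(chunk)                           # (1, D, end - start)
            outputs.append(out.squeeze(0).T)                 # (end - start, D)
        return torch.cat(outputs, dim=0)                     # (time, D)

The actual solver additionally merges a too-small final chunk into the previous one (see "Concatenate last batches for batched inference (#200)" in the commit message above), which this sketch omits.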

cebra/integrations/sklearn/metrics.py

Lines changed: 2 additions & 1 deletion
@@ -83,7 +83,8 @@ def infonce_loss(
             f"got {len(y[0])} sessions.")
 
     model, _ = cebra_model._select_model(
-        X, session_id)  # check session_id validity and corresponding model
+        X, session_id=session_id
+    )  # check session_id validity and corresponding model
     cebra_model._check_labels_types(y, session_id=session_id)
 
     dataset, is_multisession = cebra_model._prepare_data(X, y)  # single session

cebra/integrations/sklearn/utils.py

Lines changed: 2 additions & 1 deletion
@@ -92,7 +92,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
         X,
         accept_sparse=False,
         accept_large_sparse=False,
-        dtype=("float16", "float32", "float64"),
+        # NOTE(celia): remove float16 because F.pad does not allow float16.
+        dtype=("float32", "float64"),
         order=None,
         copy=False,
         ensure_2d=True,
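The dtype change above follows from the move from numpy edge padding to torch padding: per the NOTE in the diff, torch's padding does not accept float16 for this use, so float16 arrays are now rejected up front. A minimal illustration of the kind of padding call the accepted dtypes have to support (shapes arbitrary, padding mode chosen for this sketch, not taken from the solver code):

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 30, 100, dtype=torch.float32)   # (batch, channels, time)
    padded = F.pad(x, (5, 4), mode="replicate")        # fine for float32/float64
    print(padded.shape)                                # torch.Size([1, 30, 109])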
