Support numpy 2, upgrade tests to support torch 2.6 #221

Merged
12 commits merged on Feb 2, 2025
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -18,7 +18,7 @@ jobs:
# We aim to support the versions on pytorch.org
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.2.2", "2.4.0"]
torch-version: ["2.4.0", "2.6.0"]
sklearn-version: ["latest"]
include:
- os: windows-latest
26 changes: 22 additions & 4 deletions cebra/data/load.py
@@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool:
"""
try:
if ["_i_table", "table"] in df_keys:
df = pd.read_hdf(h5_file, key="table")
df = read_hdf(h5_file, key="table")
else:
df = pd.read_hdf(h5_file, key=df_keys[0])
df = read_hdf(h5_file, key=df_keys[0])
except KeyError:
df = pd.read_hdf(h5_file)
df = read_hdf(h5_file)
return all(value in df.columns.names
for value in ["scorer", "bodyparts", "coords"])

@@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str,
Returns:
A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`.
"""
df = pd.read_hdf(file, key=key)
df = read_hdf(file, key=key)
if columns is None:
loaded_array = df.values
elif isinstance(columns, list) and df.columns.nlevels == 1:
@@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader:
if file_ending not in __loaders.keys() or file_ending == "":
raise OSError(f"File ending {file_ending} not supported.")
return __loaders[file_ending]


def read_hdf(filename, key=None):
"""Read HDF5 file using pandas, with fallback to h5py if pandas fails.
Args:
filename: Path to HDF5 file
key: Optional key to read from HDF5 file. If None, tries "df_with_missing"
then falls back to first available key.
Returns:
pandas.DataFrame: The loaded data
Raises:
RuntimeError: If both pandas and h5py fail to load the file
"""

return pd.read_hdf(filename, key=key)
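
Note: the hunk above only shows a direct call to pd.read_hdf, while the docstring describes a fallback to h5py and a default key lookup. A minimal sketch of what such a fallback could look like is given below; it is an illustration under those assumptions, not the implementation merged in this PR, and the helper name _read_hdf_with_fallback is hypothetical.

import h5py
import numpy as np
import pandas as pd


def _read_hdf_with_fallback(filename, key=None):
    # Hypothetical sketch: try pandas first, then fall back to raw h5py access.
    try:
        if key is None:
            try:
                # Default key used by DeepLabCut-style files.
                return pd.read_hdf(filename, key="df_with_missing")
            except (KeyError, ValueError):
                return pd.read_hdf(filename)
        return pd.read_hdf(filename, key=key)
    except Exception as pandas_error:
        try:
            with h5py.File(filename, "r") as f:
                first_key = key if key is not None else list(f.keys())[0]
                return pd.DataFrame(np.asarray(f[first_key]))
        except Exception as h5py_error:
            raise RuntimeError(
                f"Could not read {filename} with pandas ({pandas_error}) "
                f"or h5py ({h5py_error}).")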
48 changes: 40 additions & 8 deletions cebra/integrations/sklearn/cebra.py
@@ -27,10 +27,11 @@

import numpy as np
import numpy.typing as npt
import packaging.version
import pkg_resources
import sklearn
import sklearn.utils.validation as sklearn_utils_validation
import torch
import sklearn
from sklearn.base import BaseEstimator
from sklearn.base import TransformerMixin
from sklearn.utils.metaestimators import available_if
@@ -43,11 +44,38 @@
import cebra.models
import cebra.solver

# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
# when loading CEBRA models to allow weights_only = True.
CEBRA_LOAD_SAFE_GLOBALS = [
cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
np.dtypes.Float64DType, np.dtypes.Int64DType
]


def check_version(estimator):
# NOTE(stes): required as a check for the old way of specifying tags
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
from packaging import version
return version.parse(sklearn.__version__) < version.parse("1.6.dev")
return packaging.version.parse(
sklearn.__version__) < packaging.version.parse("1.6.dev")


def _safe_torch_load(filename, weights_only, **kwargs):
if weights_only is None:
if packaging.version.parse(
torch.__version__) >= packaging.version.parse("2.6.0"):
weights_only = True
else:
weights_only = False

if not weights_only:
checkpoint = torch.load(filename, weights_only=False, **kwargs)
else:
# NOTE(stes): This is only supported for torch 2.6+
with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
checkpoint = torch.load(filename, weights_only=True, **kwargs)

return checkpoint


def _init_loader(
is_cont: bool,
@@ -1409,15 +1437,22 @@ def save(self,
def load(cls,
filename: str,
backend: Literal["auto", "sklearn", "torch"] = "auto",
weights_only: bool = None,
**kwargs) -> "CEBRA":
"""Load a model from disk.

Args:
filename: The path to the file in which to save the trained model.
backend: A string identifying the used backend.
weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types,
dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
See :py:func:`torch.load` with ``weights_only=True`` for more details. It is recommended to leave this
at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for
higher versions of torch. If you experience issues with loading custom models (specified outside
of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
kwargs: Optional keyword arguments passed directly to the loader.

Return:
Returns:
The model to load.

Note:
@@ -1427,7 +1462,6 @@ def load(cls,
For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.

Example:

>>> import cebra
>>> import numpy as np
>>> import tempfile
@@ -1441,16 +1475,14 @@
>>> loaded_model = cebra.CEBRA.load(tmp_file)
>>> embedding = loaded_model.transform(dataset)
>>> tmp_file.unlink()

"""

supported_backends = ["auto", "sklearn", "torch"]
if backend not in supported_backends:
raise NotImplementedError(
f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
)

checkpoint = torch.load(filename, **kwargs)
checkpoint = _safe_torch_load(filename, weights_only, **kwargs)

if backend == "auto":
backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
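
For readers of the diff, a short usage sketch of the new loading behavior follows; the file path and the custom-model scenario are hypothetical and not part of this PR.

import cebra

# Default: weights_only=None resolves to True on torch>=2.6 and False on older torch.
model = cebra.CEBRA.load("cebra_model.pt")

# If a trusted checkpoint contains classes outside CEBRA_LOAD_SAFE_GLOBALS
# (for example a custom model defined outside the CEBRA package), the
# restricted unpickler can be bypassed explicitly:
model = cebra.CEBRA.load("cebra_model.pt", weights_only=False)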
6 changes: 4 additions & 2 deletions setup.cfg
@@ -31,11 +31,13 @@ where =
python_requires = >=3.9
install_requires =
joblib
numpy<2.0.0
numpy<2.0;platform_system=="Windows"
numpy<2.0;platform_system!="Windows" and python_version<"3.10"
numpy;platform_system!="Windows" and python_version>="3.10"
literate-dataclasses
scikit-learn
scipy
torch
torch>=2.4.0
tqdm
matplotlib
requests
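
The environment markers above select exactly one numpy requirement per platform: Windows stays on numpy<2.0, non-Windows installs on Python 3.9 stay on numpy<2.0, and everything else is unconstrained so numpy 2.x can be installed. A small sketch of how such markers evaluate with the packaging library (the interpreter described in the comment is an example):

from packaging.markers import Marker

windows_only = Marker('platform_system == "Windows"')
old_python = Marker('platform_system != "Windows" and python_version < "3.10"')

# On, e.g., a Linux interpreter running Python 3.11 both markers evaluate to
# False, so only the unconstrained "numpy" requirement applies.
print(windows_only.evaluate(), old_python.evaluate())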
7 changes: 2 additions & 5 deletions tests/test_dlc.py
@@ -29,6 +29,7 @@
import cebra.integrations.deeplabcut as cebra_dlc
from cebra import CEBRA
from cebra import load_data
from cebra.data.load import read_hdf

# NOTE(stes): The original data URL is
# https://github.com/DeepLabCut/DeepLabCut/blob/main/examples
@@ -54,11 +55,7 @@ def test_imports():


def _load_dlc_dataframe(filename):
try:
df = pd.read_hdf(filename, "df_with_missing")
except KeyError:
df = pd.read_hdf(filename)
return df
return read_hdf(filename)


def _get_annotated_data(url, keypoints):
22 changes: 11 additions & 11 deletions tests/test_load.py
@@ -248,7 +248,7 @@ def generate_h5_no_array(filename, dtype):
def generate_h5_dataframe(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A

@@ -258,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
A_col = A[:, :2]
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"])
return A_col, loaded_A

@@ -269,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype):
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
df_A.to_hdf(filename, "df_A")
df_B.to_hdf(filename, "df_B")
df_A.to_hdf(filename, key="df_A")
df_B.to_hdf(filename, key="df_B")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A

@@ -279,7 +279,7 @@ def generate_h5_single_dataframe_no_key(filename, dtype):
def generate_h5_single_dataframe_no_key(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename)
return A, loaded_A

@@ -290,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype):
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
df_A.to_hdf(filename, "df_A")
df_B.to_hdf(filename, "df_B")
df_A.to_hdf(filename, key="df_A")
df_B.to_hdf(filename, key="df_B")
_ = cebra_load.load(filename)


@@ -304,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
df_A = pd.DataFrame(A,
columns=pd.MultiIndex.from_product([animals,
keypoints]))
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A

@@ -313,15 +313,15 @@ def generate_h5_multicol_dataframe(filename, dtype):
def generate_h5_dataframe_invalid_key(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_B")


@register_error("h5", "hdf", "hdf5", "h")
def generate_h5_dataframe_invalid_column(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_A", columns=["d", "b"])


@@ -334,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype):
df_A = pd.DataFrame(A,
columns=pd.MultiIndex.from_product([animals,
keypoints]))
df_A.to_hdf(filename, "df_A")
df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_A", columns=["a", "b"])


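
The test fixture changes above consistently pass key to DataFrame.to_hdf as a keyword argument rather than positionally. A minimal round-trip sketch of the updated call pattern (the temporary file path is illustrative):

import pathlib
import tempfile

import numpy as np
import pandas as pd

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "example.h5"
    df = pd.DataFrame(np.arange(9).reshape(3, 3), columns=["a", "b", "c"])
    df.to_hdf(path, key="df_A")  # key passed as a keyword argument
    loaded = pd.read_hdf(path, key="df_A")
    assert (loaded.values == df.values).all()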