Commit 3100730

Support numpy 2, upgrade tests to support torch 2.6 (#221)
* Drop numpy constraint
* Implement workaround for pytables
* Better error message
* Pin numpy only for Python 3.9
* Update dependencies
* Upgrade torch version
* Fix based on Python version
* Add support for torch.load with weights_only=True
* Implement safe loading for torch models starting in torch 2.6
* Fix Windows specs
* Fix docstring
* Revert changes to loading logic
1 parent 4e32661 commit 3100730

File tree: 6 files changed (+80, -31)


.github/workflows/build.yml (+1, -1)

@@ -18,7 +18,7 @@ jobs:
         # We aim to support the versions on pytorch.org
         # as well as selected previous versions on
         # https://pytorch.org/get-started/previous-versions/
-        torch-version: ["2.2.2", "2.4.0"]
+        torch-version: ["2.4.0", "2.6.0"]
         sklearn-version: ["latest"]
         include:
           - os: windows-latest

cebra/data/load.py (+22, -4)

@@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool:
     """
     try:
         if ["_i_table", "table"] in df_keys:
-            df = pd.read_hdf(h5_file, key="table")
+            df = read_hdf(h5_file, key="table")
         else:
-            df = pd.read_hdf(h5_file, key=df_keys[0])
+            df = read_hdf(h5_file, key=df_keys[0])
     except KeyError:
-        df = pd.read_hdf(h5_file)
+        df = read_hdf(h5_file)
     return all(value in df.columns.names
                for value in ["scorer", "bodyparts", "coords"])

@@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str,
     Returns:
         A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`.
     """
-    df = pd.read_hdf(file, key=key)
+    df = read_hdf(file, key=key)
     if columns is None:
         loaded_array = df.values
     elif isinstance(columns, list) and df.columns.nlevels == 1:

@@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader:
     if file_ending not in __loaders.keys() or file_ending == "":
         raise OSError(f"File ending {file_ending} not supported.")
     return __loaders[file_ending]
+
+
+def read_hdf(filename, key=None):
+    """Read HDF5 file using pandas, with fallback to h5py if pandas fails.
+
+    Args:
+        filename: Path to HDF5 file
+        key: Optional key to read from HDF5 file. If None, tries "df_with_missing"
+            then falls back to first available key.
+
+    Returns:
+        pandas.DataFrame: The loaded data
+
+    Raises:
+        RuntimeError: If both pandas and h5py fail to load the file
+    """
+
+    return pd.read_hdf(filename, key=key)
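Note that while the docstring advertises an h5py fallback, after the final "Revert changes to loading logic" step the helper simply delegates to pandas. A minimal usage sketch (file name hypothetical; HDF5 support requires the tables package):

    import pandas as pd
    from cebra.data.load import read_hdf

    # Write a small DataFrame, then read it back through the new helper.
    pd.DataFrame({"a": [1, 2], "b": [3, 4]}).to_hdf("example.h5", key="df_A")
    df = read_hdf("example.h5", key="df_A")  # currently equivalent to pd.read_hdf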

cebra/integrations/sklearn/cebra.py (+40, -8)

@@ -27,10 +27,11 @@

 import numpy as np
 import numpy.typing as npt
+import packaging.version
 import pkg_resources
+import sklearn
 import sklearn.utils.validation as sklearn_utils_validation
 import torch
-import sklearn
 from sklearn.base import BaseEstimator
 from sklearn.base import TransformerMixin
 from sklearn.utils.metaestimators import available_if
@@ -43,11 +44,38 @@
 import cebra.models
 import cebra.solver

+# NOTE(stes): From torch 2.6 onwards, we need to specify the following list
+# when loading CEBRA models to allow weights_only = True.
+CEBRA_LOAD_SAFE_GLOBALS = [
+    cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype,
+    np.dtypes.Float64DType, np.dtypes.Int64DType
+]
+
+
 def check_version(estimator):
     # NOTE(stes): required as a check for the old way of specifying tags
     # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
-    from packaging import version
-    return version.parse(sklearn.__version__) < version.parse("1.6.dev")
+    return packaging.version.parse(
+        sklearn.__version__) < packaging.version.parse("1.6.dev")
+
+
+def _safe_torch_load(filename, weights_only, **kwargs):
+    if weights_only is None:
+        if packaging.version.parse(
+                torch.__version__) >= packaging.version.parse("2.6.0"):
+            weights_only = True
+        else:
+            weights_only = False
+
+    if not weights_only:
+        checkpoint = torch.load(filename, weights_only=False, **kwargs)
+    else:
+        # NOTE(stes): This is only supported for torch 2.6+
+        with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS):
+            checkpoint = torch.load(filename, weights_only=True, **kwargs)
+
+    return checkpoint
+

 def _init_loader(
     is_cont: bool,
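For context, a minimal sketch of the torch 2.6+ mechanism that _safe_torch_load relies on (the custom class here is hypothetical): with weights_only=True, torch.load refuses to unpickle arbitrary classes unless they are allowlisted, e.g. via the torch.serialization.safe_globals context manager used above:

    import torch

    class MyMetadata:  # hypothetical class stored inside a checkpoint
        pass

    torch.save({"meta": MyMetadata()}, "ckpt.pt")

    # torch.load("ckpt.pt", weights_only=True)  # raises: MyMetadata is not allowlisted

    # Allowlisting the class for the duration of the block makes the safe load succeed.
    with torch.serialization.safe_globals([MyMetadata]):
        checkpoint = torch.load("ckpt.pt", weights_only=True)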
@@ -1409,15 +1437,22 @@ def save(self,
     def load(cls,
              filename: str,
              backend: Literal["auto", "sklearn", "torch"] = "auto",
+             weights_only: bool = None,
              **kwargs) -> "CEBRA":
         """Load a model from disk.

         Args:
             filename: The path to the file in which to save the trained model.
             backend: A string identifying the used backend.
+            weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types,
+                dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`.
+                See :py:func:`torch.load` with ``weights_only=True`` for more details. It is recommended to leave this
+                at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for
+                higher versions of torch. If you experience issues with loading custom models (specified outside
+                of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model.
             kwargs: Optional keyword arguments passed directly to the loader.

-        Return:
+        Returns:
             The model to load.

         Note:

@@ -1427,7 +1462,6 @@ def load(cls,
         For information about the file format please refer to :py:meth:`cebra.CEBRA.save`.

         Example:
-
             >>> import cebra
             >>> import numpy as np
             >>> import tempfile

@@ -1441,16 +1475,14 @@
             >>> loaded_model = cebra.CEBRA.load(tmp_file)
             >>> embedding = loaded_model.transform(dataset)
             >>> tmp_file.unlink()
-
         """
-
         supported_backends = ["auto", "sklearn", "torch"]
         if backend not in supported_backends:
             raise NotImplementedError(
                 f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}"
             )

-        checkpoint = torch.load(filename, **kwargs)
+        checkpoint = _safe_torch_load(filename, weights_only, **kwargs)

         if backend == "auto":
             backend = "sklearn" if isinstance(checkpoint, dict) else "torch"
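Taken together, loading now behaves as follows (checkpoint path hypothetical): leaving weights_only at None picks the safe default for the installed torch version, and weights_only=False restores the old behavior for trusted files:

    import cebra

    # None resolves to True on torch>=2.6 and False on older versions.
    model = cebra.CEBRA.load("cebra_model.pt")

    # Explicit opt-out, e.g. for custom models from a trusted source.
    model = cebra.CEBRA.load("cebra_model.pt", weights_only=False)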

setup.cfg (+4, -2)

@@ -31,11 +31,13 @@ where =
 python_requires = >=3.9
 install_requires =
     joblib
-    numpy<2.0.0
+    numpy<2.0;platform_system=="Windows"
+    numpy<2.0;platform_system!="Windows" and python_version<"3.10"
+    numpy;platform_system!="Windows" and python_version>="3.10"
     literate-dataclasses
     scikit-learn
     scipy
-    torch
+    torch>=2.4.0
     tqdm
     matplotlib
     requests
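The numpy lines use PEP 508 environment markers, so pip applies exactly one constraint per platform and Python version. A quick way to check how the markers resolve on the current interpreter (a sketch using the packaging library):

    from packaging.markers import Marker

    # Each marker is evaluated against the running interpreter's environment.
    for marker in (
            'platform_system == "Windows"',
            'platform_system != "Windows" and python_version < "3.10"',
            'platform_system != "Windows" and python_version >= "3.10"',
    ):
        print(marker, "->", Marker(marker).evaluate())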

tests/test_dlc.py (+2, -5)

@@ -29,6 +29,7 @@
 import cebra.integrations.deeplabcut as cebra_dlc
 from cebra import CEBRA
 from cebra import load_data
+from cebra.data.load import read_hdf

 # NOTE(stes): The original data URL is
 # https://github.com/DeepLabCut/DeepLabCut/blob/main/examples

@@ -54,11 +55,7 @@ def test_imports():


 def _load_dlc_dataframe(filename):
-    try:
-        df = pd.read_hdf(filename, "df_with_missing")
-    except KeyError:
-        df = pd.read_hdf(filename)
-    return df
+    return read_hdf(filename)


 def _get_annotated_data(url, keypoints):

tests/test_load.py (+11, -11)

@@ -248,7 +248,7 @@ def generate_h5_no_array(filename, dtype):
 def generate_h5_dataframe(filename, dtype):
     A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     loaded_A = cebra_load.load(filename, key="df_A")
     return A, loaded_A

@@ -258,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype):
     A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     A_col = A[:, :2]
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"])
     return A_col, loaded_A

@@ -269,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype):
     B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
     df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
-    df_A.to_hdf(filename, "df_A")
-    df_B.to_hdf(filename, "df_B")
+    df_A.to_hdf(filename, key="df_A")
+    df_B.to_hdf(filename, key="df_B")
     loaded_A = cebra_load.load(filename, key="df_A")
     return A, loaded_A

@@ -279,7 +279,7 @@
 def generate_h5_single_dataframe_no_key(filename, dtype):
     A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     loaded_A = cebra_load.load(filename)
     return A, loaded_A

@@ -290,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype):
     B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
     df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
-    df_A.to_hdf(filename, "df_A")
-    df_B.to_hdf(filename, "df_B")
+    df_A.to_hdf(filename, key="df_A")
+    df_B.to_hdf(filename, key="df_B")
     _ = cebra_load.load(filename)

@@ -304,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
     df_A = pd.DataFrame(A,
                         columns=pd.MultiIndex.from_product([animals,
                                                             keypoints]))
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     loaded_A = cebra_load.load(filename, key="df_A")
     return A, loaded_A

@@ -313,15 +313,15 @@
 def generate_h5_dataframe_invalid_key(filename, dtype):
     A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     _ = cebra_load.load(filename, key="df_B")


 @register_error("h5", "hdf", "hdf5", "h")
 def generate_h5_dataframe_invalid_column(filename, dtype):
     A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
     df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     _ = cebra_load.load(filename, key="df_A", columns=["d", "b"])

@@ -334,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype):
     df_A = pd.DataFrame(A,
                         columns=pd.MultiIndex.from_product([animals,
                                                             keypoints]))
-    df_A.to_hdf(filename, "df_A")
+    df_A.to_hdf(filename, key="df_A")
     _ = cebra_load.load(filename, key="df_A", columns=["a", "b"])
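The repeated change in this file is passing key to DataFrame.to_hdf by keyword: newer pandas releases deprecate positional arguments after path_or_buf in to_hdf, so the keyword form works across old and new pandas. A minimal sketch (requires the tables package):

    import pandas as pd

    df = pd.DataFrame([[1, 2, 3]], columns=["a", "b", "c"])
    df.to_hdf("example.h5", key="df_A")  # keyword form: accepted by old and new pandas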