Skip to content

Pickle module #176

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 5, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,11 @@ python setup.py develop

All of the code is formatted using [black](https://black.readthedocs.io) with the associated [config file](https://github.com/GRAAL-Research/poutyne/blob/master/pyproject.toml). In order to format the code of your submission, simply run

> See the [styling requirements](https://github.com/GRAAL-Research/poutyne/blob/master/styling_requirements.txt) for the proper black version to use.
> See the [styling requirements](https://github.com/GRAAL-Research/poutyne/blob/master/styling_requirements.txt) for the proper black and isort version to use.

```
black .
isort .
```

We also have our own `pylint` [config file](https://github.com/GRAAL-Research/poutyne/blob/master/.pylintrc). Try not to introduce code incoherences detected by the linting. You can run the linting procedure with
Expand Down
5 changes: 3 additions & 2 deletions poutyne/framework/callbacks/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
<https://www.gnu.org/licenses/>.
"""

import pickle
import warnings
from typing import IO, Dict

Expand Down Expand Up @@ -142,10 +143,10 @@

def save_file(self, fd: IO, epoch_number: int, logs: Dict):
states = {k: v.state_dict() for k, v in self.name_to_stateful.items()}
torch.save(states, fd)
torch.save(states, f=fd, pickle_module=pickle)

Check warning on line 146 in poutyne/framework/callbacks/checkpoint.py

View check run for this annotation

Codecov / codecov/patch

poutyne/framework/callbacks/checkpoint.py#L146

Added line #L146 was not covered by tests

def restore(self, fd: IO):
states = torch.load(fd, map_location='cpu')
states = torch.load(fd, pickle_module=pickle, map_location='cpu')

Check warning on line 149 in poutyne/framework/callbacks/checkpoint.py

View check run for this annotation

Codecov / codecov/patch

poutyne/framework/callbacks/checkpoint.py#L149

Added line #L149 was not covered by tests

unexpected_keys = set(states.keys()) - set(self.name_to_stateful)
missing_keys = set(self.name_to_stateful) - set(states.keys())
Expand Down
5 changes: 3 additions & 2 deletions poutyne/framework/callbacks/lr_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import inspect
import pickle
import sys
from typing import BinaryIO, Dict

Expand Down Expand Up @@ -74,10 +75,10 @@ def state_dict(self):
return self.scheduler.state_dict()

def load_state(self, f: BinaryIO):
self.load_state_dict(torch.load(f, map_location='cpu'))
self.load_state_dict(torch.load(f, pickle_module=pickle, map_location='cpu'))

def save_state(self, f: BinaryIO):
torch.save(self.state_dict(), f)
torch.save(self.state_dict(), f=f, pickle_module=pickle)


def new_init(torch_lr_scheduler):
Expand Down
9 changes: 5 additions & 4 deletions poutyne/framework/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pylint: disable=too-many-lines,too-many-public-methods
import contextlib
import pickle
import timeit
from collections import defaultdict
from typing import Any, Iterable, List, Mapping, Tuple, Union
Expand Down Expand Up @@ -1547,7 +1548,7 @@ def load_weights(self, f, strict=True):
* **missing_keys** is a list of str containing the missing keys
* **unexpected_keys** is a list of str containing the unexpected keys
"""
return self.set_weights(torch.load(f, map_location='cpu'), strict=strict)
return self.set_weights(torch.load(f, pickle_module=pickle, map_location='cpu'), strict=strict)

def save_weights(self, f):
"""
Expand All @@ -1557,7 +1558,7 @@ def save_weights(self, f):
f: File-like object (has to implement fileno that returns a file descriptor) or string
containing a file name.
"""
torch.save(self.network.state_dict(), f)
torch.save(self.network.state_dict(), f=f, pickle_module=pickle)

def load_optimizer_state(self, f):
"""
Expand All @@ -1568,7 +1569,7 @@ def load_optimizer_state(self, f):
f: File-like object (has to implement fileno that returns a file descriptor) or string
containing a file name.
"""
self.optimizer.load_state_dict(torch.load(f, map_location='cpu'))
self.optimizer.load_state_dict(torch.load(f, pickle_module=pickle, map_location='cpu'))

def save_optimizer_state(self, f):
"""
Expand All @@ -1578,7 +1579,7 @@ def save_optimizer_state(self, f):
f: File-like object (has to implement fileno that returns a file descriptor) or string
containing a file name.
"""
torch.save(self.optimizer.state_dict(), f)
torch.save(self.optimizer.state_dict(), f=f, pickle_module=pickle)

def _transfer_optimizer_state_to_right_device(self):
if self.optimizer is None:
Expand Down
3 changes: 2 additions & 1 deletion poutyne/framework/model_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# pylint: disable=too-many-lines
import os
import pickle
import warnings
from typing import Any, Callable, Dict, List, Tuple, Union

Expand Down Expand Up @@ -663,7 +664,7 @@
best_checkpoint.best_filename = best_filename
best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()
else:
best_restore.best_weights = torch.load(best_filename, map_location='cpu')
best_restore.best_weights = torch.load(best_filename, pickle_module=pickle, map_location='cpu')

Check warning on line 667 in poutyne/framework/model_bundle.py

View check run for this annotation

Codecov / codecov/patch

poutyne/framework/model_bundle.py#L667

Added line #L667 was not covered by tests
best_restore.current_best = best_epoch_stats[self.monitor_metric].item()

return callbacks
Expand Down
6 changes: 4 additions & 2 deletions poutyne/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

# -*- coding: utf-8 -*-
import os
import pickle
import random
import warnings
from typing import IO, Any, BinaryIO, Optional, Union
Expand Down Expand Up @@ -231,7 +232,8 @@ def save_random_states(f: Union[str, os.PathLike, BinaryIO, IO[bytes]]):
numpy=np.random.get_state(),
python=random.getstate(),
),
f,
f=f,
pickle_module=pickle,
)


Expand All @@ -243,7 +245,7 @@ def load_random_states(f: Any):
f: a file-like object (has to implement :meth:`read`, :meth:`readline`, :meth:`tell`, and :meth:`seek`),
or a string or os.PathLike object containing a file name
"""
states = torch.load(f)
states = torch.load(f, pickle_module=pickle)
torch.set_rng_state(states["cpu"])
torch.cuda.set_rng_state_all(states["cuda"])
np.random.set_state(states["numpy"])
Expand Down
25 changes: 20 additions & 5 deletions tests/framework/experiment/test_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

import os
import pickle
from tempfile import TemporaryDirectory
from unittest import TestCase, skipIf

Expand Down Expand Up @@ -90,27 +91,41 @@ def test_load_checkpoint_with_int(self):
filename = self.checkpoint_paths[0]
self.test_experiment.load_checkpoint(index)

self.assertEqual(self.test_experiment.model.network.state_dict(), torch.load(filename, map_location="cpu"))
self.assertEqual(
self.test_experiment.model.network.state_dict(),
torch.load(filename, pickle_module=pickle, map_location="cpu"),
)

def test_load_checkpoint_best(self):
filename = self.checkpoint_paths[-1]
self.test_experiment.load_checkpoint("best")

self.assertEqual(self.test_experiment.model.network.state_dict(), torch.load(filename, map_location="cpu"))
self.assertEqual(
self.test_experiment.model.network.state_dict(),
torch.load(filename, pickle_module=pickle, map_location="cpu"),
)

def test_load_checkpoint_last(self):
self.test_experiment.load_checkpoint("last")

self.assertEqual(
self.test_experiment.model.network.state_dict(), torch.load(self.last_checkpoint_path, map_location="cpu")
self.test_experiment.model.network.state_dict(),
torch.load(self.last_checkpoint_path, pickle_module=pickle, map_location="cpu"),
)

def test_load_checkpoint_using_path(self):
cpkt_path = os.path.join(self.test_checkpoints_path, "test_model_weights_state_dict.p")
torch.save(torch.load(self.checkpoint_paths[0], map_location="cpu"), cpkt_path) # change the ckpt path
torch.save(
torch.load(self.checkpoint_paths[0], pickle_module=pickle, map_location="cpu"),
f=cpkt_path,
pickle_module=pickle,
) # change the ckpt path
self.test_experiment.load_checkpoint(cpkt_path)

self.assertEqual(self.test_experiment.model.network.state_dict(), torch.load(cpkt_path, map_location="cpu"))
self.assertEqual(
self.test_experiment.model.network.state_dict(),
torch.load(cpkt_path, pickle_module=pickle, map_location="cpu"),
)

def test_load_invalid_checkpoint(self):
with self.assertRaises(ValueError):
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.17.3
1.17.4