Skip to content

Commit 76b691d

Browse files
dominicgkerrdominicgkerr
and
dominicgkerr
authored
Support pathlib.Path file paths when saving ONNX models (#19727)
Co-authored-by: dominicgkerr <[email protected]>
1 parent ce88483 commit 76b691d

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

src/lightning/pytorch/core/module.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1395,7 +1395,7 @@ def forward(self, x):
13951395
input_sample = self._on_before_batch_transfer(input_sample)
13961396
input_sample = self._apply_batch_transfer_handler(input_sample)
13971397

1398-
torch.onnx.export(self, input_sample, file_path, **kwargs)
1398+
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
13991399
self.train(mode)
14001400

14011401
@torch.no_grad()

tests/tests_pytorch/models/test_onnx.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import operator
1515
import os
16+
from pathlib import Path
1617
from unittest.mock import patch
1718

1819
import numpy as np
@@ -32,11 +33,14 @@
3233
def test_model_saves_with_input_sample(tmp_path):
3334
"""Test that ONNX model saves with input sample and size is greater than 3 MB."""
3435
model = BoringModel()
35-
trainer = Trainer(fast_dev_run=True)
36-
trainer.fit(model)
37-
38-
file_path = os.path.join(tmp_path, "model.onnx")
3936
input_sample = torch.randn((1, 32))
37+
38+
file_path = os.path.join(tmp_path, "os.path.onnx")
39+
model.to_onnx(file_path, input_sample)
40+
assert os.path.isfile(file_path)
41+
assert os.path.getsize(file_path) > 4e2
42+
43+
file_path = Path(tmp_path) / "pathlib.onnx"
4044
model.to_onnx(file_path, input_sample)
4145
assert os.path.isfile(file_path)
4246
assert os.path.getsize(file_path) > 4e2

0 commit comments

Comments
 (0)