File tree Expand file tree Collapse file tree 2 files changed +9
-5
lines changed
src/lightning/pytorch/core
tests/tests_pytorch/models Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -1395,7 +1395,7 @@ def forward(self, x):
1395
1395
input_sample = self ._on_before_batch_transfer (input_sample )
1396
1396
input_sample = self ._apply_batch_transfer_handler (input_sample )
1397
1397
1398
- torch .onnx .export (self , input_sample , file_path , ** kwargs )
1398
+ torch .onnx .export (self , input_sample , str ( file_path ) , ** kwargs )
1399
1399
self .train (mode )
1400
1400
1401
1401
@torch .no_grad ()
Original file line number Diff line number Diff line change 13
13
# limitations under the License.
14
14
import operator
15
15
import os
16
+ from pathlib import Path
16
17
from unittest .mock import patch
17
18
18
19
import numpy as np
32
33
def test_model_saves_with_input_sample (tmp_path ):
33
34
"""Test that ONNX model saves with input sample and size is greater than 3 MB."""
34
35
model = BoringModel ()
35
- trainer = Trainer (fast_dev_run = True )
36
- trainer .fit (model )
37
-
38
- file_path = os .path .join (tmp_path , "model.onnx" )
39
36
input_sample = torch .randn ((1 , 32 ))
37
+
38
+ file_path = os .path .join (tmp_path , "os.path.onnx" )
39
+ model .to_onnx (file_path , input_sample )
40
+ assert os .path .isfile (file_path )
41
+ assert os .path .getsize (file_path ) > 4e2
42
+
43
+ file_path = Path (tmp_path ) / "pathlib.onnx"
40
44
model .to_onnx (file_path , input_sample )
41
45
assert os .path .isfile (file_path )
42
46
assert os .path .getsize (file_path ) > 4e2
You can’t perform that action at this time.
0 commit comments