File tree Expand file tree Collapse file tree 2 files changed +24
-0
lines changed
src/lightning/pytorch/core
tests/tests_pytorch/models Expand file tree Collapse file tree 2 files changed +24
-0
lines changed Original file line number Diff line number Diff line change @@ -1472,6 +1472,10 @@ def forward(self, x):
1472
1472
)
1473
1473
example_inputs = self .example_input_array
1474
1474
1475
+ if kwargs .get ("check_inputs" ) is not None :
1476
+ kwargs ["check_inputs" ] = self ._on_before_batch_transfer (kwargs ["check_inputs" ])
1477
+ kwargs ["check_inputs" ] = self ._apply_batch_transfer_handler (kwargs ["check_inputs" ])
1478
+
1475
1479
# automatically send example inputs to the right device and use trace
1476
1480
example_inputs = self ._on_before_batch_transfer (example_inputs )
1477
1481
example_inputs = self ._apply_batch_transfer_handler (example_inputs )
Original file line number Diff line number Diff line change @@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105
105
assert script_output .device == device
106
106
107
107
108
+ @pytest .mark .parametrize (
109
+ "device_str" ,
110
+ [
111
+ "cpu" ,
112
+ pytest .param ("cuda:0" , marks = RunIf (min_cuda_gpus = 1 )),
113
+ pytest .param ("mps:0" , marks = RunIf (mps = True )),
114
+ ],
115
+ )
116
+ def test_torchscript_device_with_check_inputs (device_str ):
117
+ """Test that scripted module is on the correct device."""
118
+ device = torch .device (device_str )
119
+ model = BoringModel ().to (device )
120
+ model .example_input_array = torch .randn (5 , 32 )
121
+
122
+ check_inputs = torch .rand (5 , 32 )
123
+
124
+ script = model .to_torchscript (method = "trace" , check_inputs = check_inputs )
125
+ assert isinstance (script , torch .jit .ScriptModule )
126
+
127
+
108
128
def test_torchscript_retain_training_state ():
109
129
"""Test that torchscript export does not alter the training mode of original model."""
110
130
model = BoringModel ()
You can’t perform that action at this time.
0 commit comments