Skip to content

Commit ab7b118

Browse files
authored
fix: move check_inputs to target device if available during to_torchscript. (#20873)
1 parent 980ec50 commit ab7b118

File tree

2 files changed

+24
-0
lines changed

2 files changed

+24
-0
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1472,6 +1472,10 @@ def forward(self, x):
14721472
)
14731473
example_inputs = self.example_input_array
14741474

1475+
if kwargs.get("check_inputs") is not None:
1476+
kwargs["check_inputs"] = self._on_before_batch_transfer(kwargs["check_inputs"])
1477+
kwargs["check_inputs"] = self._apply_batch_transfer_handler(kwargs["check_inputs"])
1478+
14751479
# automatically send example inputs to the right device and use trace
14761480
example_inputs = self._on_before_batch_transfer(example_inputs)
14771481
example_inputs = self._apply_batch_transfer_handler(example_inputs)

tests/tests_pytorch/models/test_torchscript.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,26 @@ def test_torchscript_device(device_str):
105105
assert script_output.device == device
106106

107107

108+
@pytest.mark.parametrize(
109+
"device_str",
110+
[
111+
"cpu",
112+
pytest.param("cuda:0", marks=RunIf(min_cuda_gpus=1)),
113+
pytest.param("mps:0", marks=RunIf(mps=True)),
114+
],
115+
)
116+
def test_torchscript_device_with_check_inputs(device_str):
117+
"""Test that scripted module is on the correct device."""
118+
device = torch.device(device_str)
119+
model = BoringModel().to(device)
120+
model.example_input_array = torch.randn(5, 32)
121+
122+
check_inputs = torch.rand(5, 32)
123+
124+
script = model.to_torchscript(method="trace", check_inputs=check_inputs)
125+
assert isinstance(script, torch.jit.ScriptModule)
126+
127+
108128
def test_torchscript_retain_training_state():
109129
"""Test that torchscript export does not alter the training mode of original model."""
110130
model = BoringModel()

0 commit comments

Comments
 (0)