Skip to content

Commit 167cb5e

Browse files
authored
[tests] fix bug in torch_device (#2909)
1 parent 947f64e commit 167cb5e

File tree

1 file changed

+1
-4
lines changed

1 file changed

+1
-4
lines changed

tests/test_modeling_utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -748,10 +748,7 @@ def test_load_state_dict(self):
748748

749749
for param, device in device_map.items():
750750
device = device if device != "disk" else "cpu"
751-
expected_device = (
752-
torch.device(f"{torch_device}:{device}") if isinstance(device, int) else torch.device(device)
753-
)
754-
assert loaded_state_dict[param].device == expected_device
751+
assert loaded_state_dict[param].device == torch.device(device)
755752

756753
def test_convert_file_size(self):
757754
result = convert_file_size_to_int("0MB")

0 commit comments

Comments
 (0)