Skip to content

Commit 0408ab1

Browse files
warn for invalid keys (#3613)
* warn for invalid keys * add test for check_device_map invalid keys * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 55e518a commit 0408ab1

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/accelerate/utils/modeling.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,6 +1607,14 @@ def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, tor
16071607
model (`torch.nn.Module`): The model to check the device map against.
16081608
device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
16091609
"""
1610+
all_module_names = dict(model.named_modules())
1611+
invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
1612+
1613+
if invalid_keys:
1614+
warnings.warn(
1615+
f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
1616+
)
1617+
16101618
all_model_tensors = [name for name, _ in model.state_dict().items()]
16111619
for module_name in device_map.keys():
16121620
if module_name == "":

tests/test_modeling_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,26 @@ def test_check_device_map(self):
349349

350350
check_device_map(model, {"linear1": 0, "linear2": 1, "batchnorm": 1})
351351

352+
def test_check_device_map_invalid_keys(self):
353+
model = ModelForTest()
354+
355+
device_map = {
356+
"linear1": "cpu", # Valid module
357+
"batchnorm": "cpu", # Valid module
358+
"linear2": "cpu", # Valid module
359+
"invalid_module": 0, # Invalid - should trigger warning
360+
"another_invalid": 1, # Invalid - should trigger warning
361+
}
362+
363+
# Test for the warning about invalid keys
364+
with self.assertWarns(UserWarning) as cm:
365+
check_device_map(model, device_map)
366+
367+
warning_msg = str(cm.warning)
368+
self.assertIn("device_map keys do not match any submodules", warning_msg)
369+
self.assertIn("invalid_module", warning_msg)
370+
self.assertIn("another_invalid", warning_msg)
371+
352372
def shard_test_model(self, model, tmp_dir):
353373
module_index = {
354374
"linear1": "checkpoint_part1.bin",

0 commit comments

Comments
 (0)