Skip to content

Commit 892ba83

Browse files
committed
ci: add unittest for tie-embedding empty_init
1 parent a695579 commit 892ba83

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tests/test_big_modeling.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,13 @@ def test_init_empty_weights(self):
188188
assert module.weight.device == torch.device("cpu")
189189
assert module.running_mean.device == torch.device("cpu")
190190

191+
def test_init_empty_weights_with_tie_embedding(self):
192+
with init_empty_weights():
193+
module = torch.nn.ModuleList([torch.nn.Embedding(12, 12), torch.nn.Linear(12, 12)])
194+
# tie embedding
195+
module[0].weight = module[1].weight
196+
assert module[0].weight is module[1].weight
197+
191198
def test_init_empty_weights_very_large_model(self):
192199
# This is a 100 billion parameters model.
193200
with init_empty_weights():

0 commit comments

Comments
 (0)