Skip to content

Commit 5994600

Browse files
authored
Add test for the loader module. (#11)
1 parent a03a13b commit 5994600

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2024 The AI Edge Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
# Testing weight loader utilities.
16+
17+
import os
18+
import tempfile
19+
import unittest
20+
21+
import safetensors.torch
22+
import torch
23+
24+
from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
25+
from ai_edge_torch.generative.utilities import loader as loading_utils
26+
27+
28+
class TestLoader(unittest.TestCase):
29+
"""Unit tests that check weight loader."""
30+
31+
def test_load_safetensors(self):
32+
with tempfile.TemporaryDirectory() as temp_dir:
33+
file_path = os.path.join(temp_dir, "test.safetensors")
34+
test_data = {"weight": torch.randn(20, 10), "bias": torch.randn(20)}
35+
safetensors.torch.save_file(test_data, file_path)
36+
37+
loaded_tensors = loading_utils.load_safetensors(file_path)
38+
self.assertIn("weight", loaded_tensors)
39+
self.assertIn("bias", loaded_tensors)
40+
41+
def test_load_statedict(self):
42+
with tempfile.TemporaryDirectory() as temp_dir:
43+
file_path = os.path.join(temp_dir, "test.pt")
44+
model = torch.nn.Linear(10, 5)
45+
state_dict = model.state_dict()
46+
torch.save(state_dict, file_path)
47+
48+
loaded_tensors = loading_utils.load_pytorch_statedict(file_path)
49+
self.assertIn("weight", loaded_tensors)
50+
self.assertIn("bias", loaded_tensors)
51+
52+
def test_model_loader(self):
53+
with tempfile.TemporaryDirectory() as temp_dir:
54+
file_path = os.path.join(temp_dir, "test.safetensors")
55+
test_weights = {
56+
"lm_head.weight": torch.randn((32000, 2048)),
57+
"model.embed_tokens.weight": torch.randn((32000, 2048)),
58+
"model.layers.0.input_layernorm.weight": torch.randn((2048,)),
59+
"model.layers.0.mlp.down_proj.weight": torch.randn((2048, 5632)),
60+
"model.layers.0.mlp.gate_proj.weight": torch.randn((5632, 2048)),
61+
"model.layers.0.mlp.up_proj.weight": torch.randn((5632, 2048)),
62+
"model.layers.0.post_attention_layernorm.weight": torch.randn((2048,)),
63+
"model.layers.0.self_attn.k_proj.weight": torch.randn((256, 2048)),
64+
"model.layers.0.self_attn.o_proj.weight": torch.randn((2048, 2048)),
65+
"model.layers.0.self_attn.q_proj.weight": torch.randn((2048, 2048)),
66+
"model.layers.0.self_attn.v_proj.weight": torch.randn((256, 2048)),
67+
"model.norm.weight": torch.randn((2048,)),
68+
}
69+
safetensors.torch.save_file(test_weights, file_path)
70+
cfg = tiny_llama.get_model_config()
71+
cfg.num_layers = 1
72+
model = tiny_llama.TinyLLamma(cfg)
73+
74+
loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
75+
# if returns successfully, it means all the tensors were initiallized.
76+
loader.load(model, strict=True)
77+
78+
79+
if __name__ == "__main__":
80+
unittest.main()

0 commit comments

Comments
 (0)