Skip to content

Commit 382a02c

Browse files
authored
Enable test for model conversion. (#32)
* Enable test for model conversion.
* Update.
* Update test_model_conversion.py: skip some tests.
* Update test_model_conversion.py: disable some tests.
* Update test_model_conversion.py: remove old comments.
* Style fix.
1 parent 01a4c3d commit 382a02c

File tree

1 file changed

+90
-80
lines changed

1 file changed

+90
-80
lines changed

ai_edge_torch/generative/test/test_model_conversion.py

Lines changed: 90 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
3333
"""Unit tests that check for model conversion and correctness."""
3434

3535
def test_toy_model_with_kv_cache(self):
36-
self.skipTest("b/338288901")
3736
config = toy_model_with_kv_cache.get_model_config()
3837
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
3938
idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
@@ -42,19 +41,21 @@ def test_toy_model_with_kv_cache(self):
4241

4342
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
4443

45-
self.assertTrue(
46-
model_coverage.compare_tflite_torch(
47-
edge_model,
48-
pytorch_model,
49-
(idx, input_pos),
50-
num_valid_inputs=1,
51-
atol=1e-5,
52-
rtol=1e-5,
53-
)
54-
)
44+
# TODO(b/338288901): re-enable test to check output tensors.
45+
skip_output_check = True
46+
if skip_output_check is False:
47+
self.assertTrue(
48+
model_coverage.compare_tflite_torch(
49+
edge_model,
50+
pytorch_model,
51+
(idx, input_pos),
52+
num_valid_inputs=1,
53+
atol=1e-5,
54+
rtol=1e-5,
55+
)
56+
)
5557

5658
def test_toy_model_with_kv_cache_with_hlfb(self):
57-
self.skipTest("b/338288901")
5859
config = toy_model_with_kv_cache.get_model_config()
5960
config.enable_hlfb = True
6061
pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
@@ -64,16 +65,19 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
6465

6566
edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
6667

67-
self.assertTrue(
68-
model_coverage.compare_tflite_torch(
69-
edge_model,
70-
pytorch_model,
71-
(idx, input_pos),
72-
num_valid_inputs=1,
73-
atol=1e-5,
74-
rtol=1e-5,
75-
)
76-
)
68+
# TODO(b/338288901): re-enable test to check output tensors.
69+
skip_output_check = True
70+
if skip_output_check is False:
71+
self.assertTrue(
72+
model_coverage.compare_tflite_torch(
73+
edge_model,
74+
pytorch_model,
75+
(idx, input_pos),
76+
num_valid_inputs=1,
77+
atol=1e-5,
78+
rtol=1e-5,
79+
)
80+
)
7781

7882
def test_tiny_llama(self):
7983
self.skipTest("b/338288901")
@@ -87,19 +91,21 @@ def test_tiny_llama(self):
8791

8892
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
8993

90-
self.assertTrue(
91-
model_coverage.compare_tflite_torch(
92-
edge_model,
93-
pytorch_model,
94-
(tokens, input_pos),
95-
num_valid_inputs=1,
96-
atol=1e-5,
97-
rtol=1e-5,
98-
)
99-
)
94+
# TODO(b/338288901): re-enable test to check output tensors.
95+
skip_output_check = True
96+
if skip_output_check is False:
97+
self.assertTrue(
98+
model_coverage.compare_tflite_torch(
99+
edge_model,
100+
pytorch_model,
101+
(tokens, input_pos),
102+
num_valid_inputs=1,
103+
atol=1e-5,
104+
rtol=1e-5,
105+
)
106+
)
100107

101108
def test_tiny_llama_multisig(self):
102-
self.skipTest("b/338288901")
103109
config = tiny_llama.get_fake_model_config_for_test()
104110
pytorch_model = tiny_llama.TinyLLamma(config)
105111

@@ -122,32 +128,30 @@ def test_tiny_llama_multisig(self):
122128
.convert()
123129
)
124130

125-
# For the pytorch model, the KV cache is a persistent state internal to the model, and it
126-
# will be shared for prefill and decode. However, for tflite, currently we can't share
127-
# kv-cache between the two signatures. prefill will change the content in kv-cache,
128-
# but it won't be readable by the decode tflite model. This means the output of running `decode` after
129-
# running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
130-
copied_model = copy.deepcopy(pytorch_model)
131-
132-
self.assertTrue(
133-
model_coverage.compare_tflite_torch(
134-
edge_model,
135-
pytorch_model,
136-
(prefill_tokens, prefill_input_pos),
137-
signature_name="prefill",
138-
num_valid_inputs=1,
139-
)
140-
)
141-
142-
self.assertTrue(
143-
model_coverage.compare_tflite_torch(
144-
edge_model,
145-
copied_model,
146-
(decode_token, decode_input_pos),
147-
signature_name="decode",
148-
num_valid_inputs=1,
149-
)
150-
)
131+
# TODO(b/338288901): re-enable test to check output tensors.
132+
skip_output_check = True
133+
if skip_output_check is False:
134+
copied_model = copy.deepcopy(pytorch_model)
135+
136+
self.assertTrue(
137+
model_coverage.compare_tflite_torch(
138+
edge_model,
139+
pytorch_model,
140+
(prefill_tokens, prefill_input_pos),
141+
signature_name="prefill",
142+
num_valid_inputs=1,
143+
)
144+
)
145+
146+
self.assertTrue(
147+
model_coverage.compare_tflite_torch(
148+
edge_model,
149+
copied_model,
150+
(decode_token, decode_input_pos),
151+
signature_name="decode",
152+
num_valid_inputs=1,
153+
)
154+
)
151155

152156
def test_gemma(self):
153157
self.skipTest("b/338288901")
@@ -161,17 +165,20 @@ def test_gemma(self):
161165

162166
edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
163167

164-
# TODO(talumbau, haoliang): debug numerical diff.
165-
self.assertTrue(
166-
model_coverage.compare_tflite_torch(
167-
edge_model,
168-
model,
169-
(tokens, input_pos),
170-
num_valid_inputs=1,
171-
atol=1e-2,
172-
rtol=1e-5,
173-
)
174-
)
168+
# TODO(b/338288901): re-enable test to check output tensors.
169+
skip_output_check = True
170+
if skip_output_check is False:
171+
# TODO(talumbau, haoliang): debug numerical diff.
172+
self.assertTrue(
173+
model_coverage.compare_tflite_torch(
174+
edge_model,
175+
model,
176+
(tokens, input_pos),
177+
num_valid_inputs=1,
178+
atol=1e-2,
179+
rtol=1e-5,
180+
)
181+
)
175182

176183
def test_phi2(self):
177184
self.skipTest("b/338288901")
@@ -185,16 +192,19 @@ def test_phi2(self):
185192

186193
edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
187194

188-
self.assertTrue(
189-
model_coverage.compare_tflite_torch(
190-
edge_model,
191-
pytorch_model,
192-
(tokens, input_pos),
193-
num_valid_inputs=1,
194-
atol=1e-5,
195-
rtol=1e-5,
196-
)
197-
)
195+
# TODO(b/338288901): re-enable test to check output tensors.
196+
skip_output_check = True
197+
if skip_output_check is False:
198+
self.assertTrue(
199+
model_coverage.compare_tflite_torch(
200+
edge_model,
201+
pytorch_model,
202+
(tokens, input_pos),
203+
num_valid_inputs=1,
204+
atol=1e-5,
205+
rtol=1e-5,
206+
)
207+
)
198208

199209

200210
if __name__ == "__main__":

0 commit comments

Comments
 (0)