@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
33
33
"""Unit tests that check for model conversion and correctness."""
34
34
35
35
def test_toy_model_with_kv_cache (self ):
36
- self .skipTest ("b/338288901" )
37
36
config = toy_model_with_kv_cache .get_model_config ()
38
37
pytorch_model = toy_model_with_kv_cache .ToyModelWithKV (config )
39
38
idx , input_pos = torch .tensor ([[1 ]], dtype = torch .long ), torch .tensor (
@@ -42,19 +41,21 @@ def test_toy_model_with_kv_cache(self):
42
41
43
42
edge_model = ai_edge_torch .convert (pytorch_model , (idx , input_pos ))
44
43
45
- self .assertTrue (
46
- model_coverage .compare_tflite_torch (
47
- edge_model ,
48
- pytorch_model ,
49
- (idx , input_pos ),
50
- num_valid_inputs = 1 ,
51
- atol = 1e-5 ,
52
- rtol = 1e-5 ,
53
- )
54
- )
44
+ # TODO(b/338288901): re-enable test to check output tensors.
45
+ skip_output_check = True
46
+ if skip_output_check is False :
47
+ self .assertTrue (
48
+ model_coverage .compare_tflite_torch (
49
+ edge_model ,
50
+ pytorch_model ,
51
+ (idx , input_pos ),
52
+ num_valid_inputs = 1 ,
53
+ atol = 1e-5 ,
54
+ rtol = 1e-5 ,
55
+ )
56
+ )
55
57
56
58
def test_toy_model_with_kv_cache_with_hlfb (self ):
57
- self .skipTest ("b/338288901" )
58
59
config = toy_model_with_kv_cache .get_model_config ()
59
60
config .enable_hlfb = True
60
61
pytorch_model = toy_model_with_kv_cache .ToyModelWithKV (config )
@@ -64,16 +65,19 @@ def test_toy_model_with_kv_cache_with_hlfb(self):
64
65
65
66
edge_model = ai_edge_torch .convert (pytorch_model , (idx , input_pos ))
66
67
67
- self .assertTrue (
68
- model_coverage .compare_tflite_torch (
69
- edge_model ,
70
- pytorch_model ,
71
- (idx , input_pos ),
72
- num_valid_inputs = 1 ,
73
- atol = 1e-5 ,
74
- rtol = 1e-5 ,
75
- )
76
- )
68
+ # TODO(b/338288901): re-enable test to check output tensors.
69
+ skip_output_check = True
70
+ if skip_output_check is False :
71
+ self .assertTrue (
72
+ model_coverage .compare_tflite_torch (
73
+ edge_model ,
74
+ pytorch_model ,
75
+ (idx , input_pos ),
76
+ num_valid_inputs = 1 ,
77
+ atol = 1e-5 ,
78
+ rtol = 1e-5 ,
79
+ )
80
+ )
77
81
78
82
def test_tiny_llama (self ):
79
83
self .skipTest ("b/338288901" )
@@ -87,19 +91,21 @@ def test_tiny_llama(self):
87
91
88
92
edge_model = ai_edge_torch .convert (pytorch_model , (tokens , input_pos ))
89
93
90
- self .assertTrue (
91
- model_coverage .compare_tflite_torch (
92
- edge_model ,
93
- pytorch_model ,
94
- (tokens , input_pos ),
95
- num_valid_inputs = 1 ,
96
- atol = 1e-5 ,
97
- rtol = 1e-5 ,
98
- )
99
- )
94
+ # TODO(b/338288901): re-enable test to check output tensors.
95
+ skip_output_check = True
96
+ if skip_output_check is False :
97
+ self .assertTrue (
98
+ model_coverage .compare_tflite_torch (
99
+ edge_model ,
100
+ pytorch_model ,
101
+ (tokens , input_pos ),
102
+ num_valid_inputs = 1 ,
103
+ atol = 1e-5 ,
104
+ rtol = 1e-5 ,
105
+ )
106
+ )
100
107
101
108
def test_tiny_llama_multisig (self ):
102
- self .skipTest ("b/338288901" )
103
109
config = tiny_llama .get_fake_model_config_for_test ()
104
110
pytorch_model = tiny_llama .TinyLLamma (config )
105
111
@@ -122,32 +128,30 @@ def test_tiny_llama_multisig(self):
122
128
.convert ()
123
129
)
124
130
125
- # For the pytorch model, the KV cache is a persistent state internal to the model, and it
126
- # will be shared for prefill and decode. However, for tflite, currently we can't share
127
- # kv-cache between the two signatures. prefill will change the content in kv-cache,
128
- # but it won't be readable by the decode tflite model. This means the output of running `decode` after
129
- # running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
130
- copied_model = copy .deepcopy (pytorch_model )
131
-
132
- self .assertTrue (
133
- model_coverage .compare_tflite_torch (
134
- edge_model ,
135
- pytorch_model ,
136
- (prefill_tokens , prefill_input_pos ),
137
- signature_name = "prefill" ,
138
- num_valid_inputs = 1 ,
139
- )
140
- )
141
-
142
- self .assertTrue (
143
- model_coverage .compare_tflite_torch (
144
- edge_model ,
145
- copied_model ,
146
- (decode_token , decode_input_pos ),
147
- signature_name = "decode" ,
148
- num_valid_inputs = 1 ,
149
- )
150
- )
131
+ # TODO(b/338288901): re-enable test to check output tensors.
132
+ skip_output_check = True
133
+ if skip_output_check is False :
134
+ copied_model = copy .deepcopy (pytorch_model )
135
+
136
+ self .assertTrue (
137
+ model_coverage .compare_tflite_torch (
138
+ edge_model ,
139
+ pytorch_model ,
140
+ (prefill_tokens , prefill_input_pos ),
141
+ signature_name = "prefill" ,
142
+ num_valid_inputs = 1 ,
143
+ )
144
+ )
145
+
146
+ self .assertTrue (
147
+ model_coverage .compare_tflite_torch (
148
+ edge_model ,
149
+ copied_model ,
150
+ (decode_token , decode_input_pos ),
151
+ signature_name = "decode" ,
152
+ num_valid_inputs = 1 ,
153
+ )
154
+ )
151
155
152
156
def test_gemma (self ):
153
157
self .skipTest ("b/338288901" )
@@ -161,17 +165,20 @@ def test_gemma(self):
161
165
162
166
edge_model = ai_edge_torch .convert (model , (tokens , input_pos ))
163
167
164
- # TODO(talumbau, haoliang): debug numerical diff.
165
- self .assertTrue (
166
- model_coverage .compare_tflite_torch (
167
- edge_model ,
168
- model ,
169
- (tokens , input_pos ),
170
- num_valid_inputs = 1 ,
171
- atol = 1e-2 ,
172
- rtol = 1e-5 ,
173
- )
174
- )
168
+ # TODO(b/338288901): re-enable test to check output tensors.
169
+ skip_output_check = True
170
+ if skip_output_check is False :
171
+ # TODO(talumbau, haoliang): debug numerical diff.
172
+ self .assertTrue (
173
+ model_coverage .compare_tflite_torch (
174
+ edge_model ,
175
+ model ,
176
+ (tokens , input_pos ),
177
+ num_valid_inputs = 1 ,
178
+ atol = 1e-2 ,
179
+ rtol = 1e-5 ,
180
+ )
181
+ )
175
182
176
183
def test_phi2 (self ):
177
184
self .skipTest ("b/338288901" )
@@ -185,16 +192,19 @@ def test_phi2(self):
185
192
186
193
edge_model = ai_edge_torch .convert (pytorch_model , (tokens , input_pos ))
187
194
188
- self .assertTrue (
189
- model_coverage .compare_tflite_torch (
190
- edge_model ,
191
- pytorch_model ,
192
- (tokens , input_pos ),
193
- num_valid_inputs = 1 ,
194
- atol = 1e-5 ,
195
- rtol = 1e-5 ,
196
- )
197
- )
195
+ # TODO(b/338288901): re-enable test to check output tensors.
196
+ skip_output_check = True
197
+ if skip_output_check is False :
198
+ self .assertTrue (
199
+ model_coverage .compare_tflite_torch (
200
+ edge_model ,
201
+ pytorch_model ,
202
+ (tokens , input_pos ),
203
+ num_valid_inputs = 1 ,
204
+ atol = 1e-5 ,
205
+ rtol = 1e-5 ,
206
+ )
207
+ )
198
208
199
209
200
210
if __name__ == "__main__" :
0 commit comments