@@ -113,55 +113,55 @@ def get_args(self) -> BaseArgs:
113
113
args .target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
114
114
return args
115
115
116
- @parameterized ( "enable_precomputation" , [ False , True ])
116
+ @parameterized . expand ([( False ,), ( True ,) ])
117
117
def test___dp_degree_1___batch_size_1 (self , enable_precomputation : bool ):
118
118
args = self .get_args ()
119
119
args .dp_degree = 1
120
120
args .batch_size = 1
121
121
args .enable_precomputation = enable_precomputation
122
122
self ._test_training (args )
123
123
124
- @parameterized ( "enable_precomputation" , [ False , True ])
124
+ @parameterized . expand ([( False ,), ( True ,) ])
125
125
def test___dp_degree_1___batch_size_2 (self , enable_precomputation : bool ):
126
126
args = self .get_args ()
127
127
args .dp_degree = 1
128
128
args .batch_size = 2
129
129
args .enable_precomputation = enable_precomputation
130
130
self ._test_training (args )
131
131
132
- @parameterized ( "enable_precomputation" , [ False , True ])
132
+ @parameterized . expand ([( False ,), ( True ,) ])
133
133
def test___dp_degree_2___batch_size_1 (self , enable_precomputation : bool ):
134
134
args = self .get_args ()
135
135
args .dp_degree = 2
136
136
args .batch_size = 1
137
137
args .enable_precomputation = enable_precomputation
138
138
self ._test_training (args )
139
139
140
- @parameterized ( "enable_precomputation" , [ False , True ])
140
+ @parameterized . expand ([( False ,), ( True ,) ])
141
141
def test___dp_degree_2___batch_size_2 (self , enable_precomputation : bool ):
142
142
args = self .get_args ()
143
143
args .dp_degree = 2
144
144
args .batch_size = 2
145
145
args .enable_precomputation = enable_precomputation
146
146
self ._test_training (args )
147
147
148
- @parameterized ( "enable_precomputation" , [ False , True ])
148
+ @parameterized . expand ([( False ,), ( True ,) ])
149
149
def test___dp_shards_2___batch_size_1 (self , enable_precomputation : bool ):
150
150
args = self .get_args ()
151
151
args .dp_shards = 2
152
152
args .batch_size = 1
153
153
args .enable_precomputation = enable_precomputation
154
154
self ._test_training (args )
155
155
156
- @parameterized ( "enable_precomputation" , [ False , True ])
156
+ @parameterized . expand ([( False ,), ( True ,) ])
157
157
def test___dp_shards_2___batch_size_2 (self , enable_precomputation : bool ):
158
158
args = self .get_args ()
159
159
args .dp_shards = 2
160
160
args .batch_size = 1
161
161
args .enable_precomputation = enable_precomputation
162
162
self ._test_training (args )
163
163
164
- @parameterized ( "enable_precomputation" , [ False , True ])
164
+ @parameterized . expand ([( False ,), ( True ,) ])
165
165
def test___dp_degree_2___dp_shards_2___batch_size_1 (self , enable_precomputation : bool ):
166
166
args = self .get_args ()
167
167
args .dp_degree = 2
@@ -170,7 +170,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
170
170
args .enable_precomputation = enable_precomputation
171
171
self ._test_training (args )
172
172
173
- @parameterized ( "enable_precomputation" , [ False , True ])
173
+ @parameterized . expand ([( False ,), ( True ,) ])
174
174
def test___tp_degree_2___batch_size_2 (self , enable_precomputation : bool ):
175
175
args = self .get_args ()
176
176
args .tp_degree = 2
@@ -186,55 +186,55 @@ def get_args(self) -> BaseArgs:
186
186
args .training_type = TrainingType .FULL_FINETUNE
187
187
return args
188
188
189
- @parameterized ( "enable_precomputation" , [ False , True ])
189
+ @parameterized . expand ([( False ,), ( True ,) ])
190
190
def test___dp_degree_1___batch_size_1 (self , enable_precomputation : bool ):
191
191
args = self .get_args ()
192
192
args .dp_degree = 1
193
193
args .batch_size = 1
194
194
args .enable_precomputation = enable_precomputation
195
195
self ._test_training (args )
196
196
197
- @parameterized ( "enable_precomputation" , [ False , True ])
197
+ @parameterized . expand ([( False ,), ( True ,) ])
198
198
def test___dp_degree_1___batch_size_2 (self , enable_precomputation : bool ):
199
199
args = self .get_args ()
200
200
args .dp_degree = 1
201
201
args .batch_size = 2
202
202
args .enable_precomputation = enable_precomputation
203
203
self ._test_training (args )
204
204
205
- @parameterized ( "enable_precomputation" , [ False , True ])
205
+ @parameterized . expand ([( False ,), ( True ,) ])
206
206
def test___dp_degree_2___batch_size_1 (self , enable_precomputation : bool ):
207
207
args = self .get_args ()
208
208
args .dp_degree = 2
209
209
args .batch_size = 1
210
210
args .enable_precomputation = enable_precomputation
211
211
self ._test_training (args )
212
212
213
- @parameterized ( "enable_precomputation" , [ False , True ])
213
+ @parameterized . expand ([( False ,), ( True ,) ])
214
214
def test___dp_degree_2___batch_size_2 (self , enable_precomputation : bool ):
215
215
args = self .get_args ()
216
216
args .dp_degree = 2
217
217
args .batch_size = 2
218
218
args .enable_precomputation = enable_precomputation
219
219
self ._test_training (args )
220
220
221
- @parameterized ( "enable_precomputation" , [ False , True ])
221
+ @parameterized . expand ([( False ,), ( True ,) ])
222
222
def test___dp_shards_2___batch_size_1 (self , enable_precomputation : bool ):
223
223
args = self .get_args ()
224
224
args .dp_shards = 2
225
225
args .batch_size = 1
226
226
args .enable_precomputation = enable_precomputation
227
227
self ._test_training (args )
228
228
229
- @parameterized ( "enable_precomputation" , [ False , True ])
229
+ @parameterized . expand ([( False ,), ( True ,) ])
230
230
def test___dp_shards_2___batch_size_2 (self , enable_precomputation : bool ):
231
231
args = self .get_args ()
232
232
args .dp_shards = 2
233
233
args .batch_size = 1
234
234
args .enable_precomputation = enable_precomputation
235
235
self ._test_training (args )
236
236
237
- @parameterized ( "enable_precomputation" , [ False , True ])
237
+ @parameterized . expand ([( False ,), ( True ,) ])
238
238
def test___dp_degree_2___dp_shards_2___batch_size_1 (self , enable_precomputation : bool ):
239
239
args = self .get_args ()
240
240
args .dp_degree = 2
@@ -243,7 +243,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
243
243
args .enable_precomputation = enable_precomputation
244
244
self ._test_training (args )
245
245
246
- @parameterized ( "enable_precomputation" , [ False , True ])
246
+ @parameterized . expand ([( False ,), ( True ,) ])
247
247
def test___tp_degree_2___batch_size_2 (self , enable_precomputation : bool ):
248
248
args = self .get_args ()
249
249
args .tp_degree = 2
0 commit comments