Skip to content

Commit cd8527a

Browse files
authored
fix (#307)
1 parent 60ab7f0 commit cd8527a

File tree

2 files changed

+29
-26
lines changed

2 files changed

+29
-26
lines changed

finetrainers/trainer/sft_trainer/trainer.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -918,15 +918,17 @@ def _prepare_data(
918918
else:
919919
logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
920920

921-
parallel_backend = self.state.parallel_backend
922-
train_state = self.state.train_state
923-
self.checkpointer.save(
924-
train_state.step,
925-
force=True,
926-
_device=parallel_backend.device,
927-
_is_main_process=parallel_backend.is_main_process,
928-
)
929-
self._delete_components(component_names=["transformer", "unet"])
921+
# TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
922+
# had become None after this but should have been loaded back from the checkpoint.
923+
# parallel_backend = self.state.parallel_backend
924+
# train_state = self.state.train_state
925+
# self.checkpointer.save(
926+
# train_state.step,
927+
# force=True,
928+
# _device=parallel_backend.device,
929+
# _is_main_process=parallel_backend.is_main_process,
930+
# )
931+
# self._delete_components(component_names=["transformer", "unet"])
930932

931933
if self.args.precomputation_once:
932934
consume_fn = preprocessor.consume_once
@@ -967,7 +969,8 @@ def _prepare_data(
967969
self._delete_components(component_names)
968970
del latent_components, component_names, component_modules
969971

970-
self.checkpointer.load()
972+
# self.checkpointer.load()
973+
# self.transformer = self.checkpointer.states["model"].model[0]
971974

972975
return condition_iterator, latent_iterator
973976

tests/trainer/test_sft_trainer.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -113,55 +113,55 @@ def get_args(self) -> BaseArgs:
113113
args.target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
114114
return args
115115

116-
@parameterized("enable_precomputation", [False, True])
116+
@parameterized.expand([(False,), (True,)])
117117
def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
118118
args = self.get_args()
119119
args.dp_degree = 1
120120
args.batch_size = 1
121121
args.enable_precomputation = enable_precomputation
122122
self._test_training(args)
123123

124-
@parameterized("enable_precomputation", [False, True])
124+
@parameterized.expand([(False,), (True,)])
125125
def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
126126
args = self.get_args()
127127
args.dp_degree = 1
128128
args.batch_size = 2
129129
args.enable_precomputation = enable_precomputation
130130
self._test_training(args)
131131

132-
@parameterized("enable_precomputation", [False, True])
132+
@parameterized.expand([(False,), (True,)])
133133
def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
134134
args = self.get_args()
135135
args.dp_degree = 2
136136
args.batch_size = 1
137137
args.enable_precomputation = enable_precomputation
138138
self._test_training(args)
139139

140-
@parameterized("enable_precomputation", [False, True])
140+
@parameterized.expand([(False,), (True,)])
141141
def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
142142
args = self.get_args()
143143
args.dp_degree = 2
144144
args.batch_size = 2
145145
args.enable_precomputation = enable_precomputation
146146
self._test_training(args)
147147

148-
@parameterized("enable_precomputation", [False, True])
148+
@parameterized.expand([(False,), (True,)])
149149
def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
150150
args = self.get_args()
151151
args.dp_shards = 2
152152
args.batch_size = 1
153153
args.enable_precomputation = enable_precomputation
154154
self._test_training(args)
155155

156-
@parameterized("enable_precomputation", [False, True])
156+
@parameterized.expand([(False,), (True,)])
157157
def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
158158
args = self.get_args()
159159
args.dp_shards = 2
160160
args.batch_size = 1
161161
args.enable_precomputation = enable_precomputation
162162
self._test_training(args)
163163

164-
@parameterized("enable_precomputation", [False, True])
164+
@parameterized.expand([(False,), (True,)])
165165
def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
166166
args = self.get_args()
167167
args.dp_degree = 2
@@ -170,7 +170,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
170170
args.enable_precomputation = enable_precomputation
171171
self._test_training(args)
172172

173-
@parameterized("enable_precomputation", [False, True])
173+
@parameterized.expand([(False,), (True,)])
174174
def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
175175
args = self.get_args()
176176
args.tp_degree = 2
@@ -186,55 +186,55 @@ def get_args(self) -> BaseArgs:
186186
args.training_type = TrainingType.FULL_FINETUNE
187187
return args
188188

189-
@parameterized("enable_precomputation", [False, True])
189+
@parameterized.expand([(False,), (True,)])
190190
def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
191191
args = self.get_args()
192192
args.dp_degree = 1
193193
args.batch_size = 1
194194
args.enable_precomputation = enable_precomputation
195195
self._test_training(args)
196196

197-
@parameterized("enable_precomputation", [False, True])
197+
@parameterized.expand([(False,), (True,)])
198198
def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
199199
args = self.get_args()
200200
args.dp_degree = 1
201201
args.batch_size = 2
202202
args.enable_precomputation = enable_precomputation
203203
self._test_training(args)
204204

205-
@parameterized("enable_precomputation", [False, True])
205+
@parameterized.expand([(False,), (True,)])
206206
def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
207207
args = self.get_args()
208208
args.dp_degree = 2
209209
args.batch_size = 1
210210
args.enable_precomputation = enable_precomputation
211211
self._test_training(args)
212212

213-
@parameterized("enable_precomputation", [False, True])
213+
@parameterized.expand([(False,), (True,)])
214214
def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
215215
args = self.get_args()
216216
args.dp_degree = 2
217217
args.batch_size = 2
218218
args.enable_precomputation = enable_precomputation
219219
self._test_training(args)
220220

221-
@parameterized("enable_precomputation", [False, True])
221+
@parameterized.expand([(False,), (True,)])
222222
def test___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
223223
args = self.get_args()
224224
args.dp_shards = 2
225225
args.batch_size = 1
226226
args.enable_precomputation = enable_precomputation
227227
self._test_training(args)
228228

229-
@parameterized("enable_precomputation", [False, True])
229+
@parameterized.expand([(False,), (True,)])
230230
def test___dp_shards_2___batch_size_2(self, enable_precomputation: bool):
231231
args = self.get_args()
232232
args.dp_shards = 2
233233
args.batch_size = 1
234234
args.enable_precomputation = enable_precomputation
235235
self._test_training(args)
236236

237-
@parameterized("enable_precomputation", [False, True])
237+
@parameterized.expand([(False,), (True,)])
238238
def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation: bool):
239239
args = self.get_args()
240240
args.dp_degree = 2
@@ -243,7 +243,7 @@ def test___dp_degree_2___dp_shards_2___batch_size_1(self, enable_precomputation:
243243
args.enable_precomputation = enable_precomputation
244244
self._test_training(args)
245245

246-
@parameterized("enable_precomputation", [False, True])
246+
@parameterized.expand([(False,), (True,)])
247247
def test___tp_degree_2___batch_size_2(self, enable_precomputation: bool):
248248
args = self.get_args()
249249
args.tp_degree = 2

0 commit comments

Comments
 (0)