@@ -22,6 +22,8 @@
 """
 # You can also adapt this script on your own clm task. Pointers for this are left as comments.
 
+import json
+
 # region Imports
 import logging
 import math
@@ -46,8 +48,8 @@
     TF_MODEL_FOR_CAUSAL_LM_MAPPING,
     AutoConfig,
     AutoTokenizer,
-    DefaultDataCollator,
     HfArgumentParser,
+    PushToHubCallback,
     TFAutoModelForCausalLM,
     TFTrainingArguments,
     create_optimizer,
@@ -205,21 +207,6 @@ def __post_init__(self):
                 assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
 
 
-# endregion
-
-# region Helper classes
-class SavePretrainedCallback(tf.keras.callbacks.Callback):
-    # Hugging Face models have a save_pretrained() method that saves both the weights and the necessary
-    # metadata to allow them to be loaded as a pretrained model in future. This is a simple Keras callback
-    # that saves the model with this method after each epoch.
-    def __init__(self, output_dir, **kwargs):
-        super().__init__()
-        self.output_dir = output_dir
-
-    def on_epoch_end(self, epoch, logs=None):
-        self.model.save_pretrained(self.output_dir)
-
-
 # endregion
 
 
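If you still want per-epoch checkpoints on local disk without pushing to the Hub, the behaviour of the helper deleted above can be recreated in a few lines. This is only a sketch based on the removed class; the name LocalSaveCallback is made up here and is not part of this commit.

    import tensorflow as tf

    class LocalSaveCallback(tf.keras.callbacks.Callback):
        # Saves the wrapped Hugging Face model with save_pretrained() after every epoch,
        # mirroring what SavePretrainedCallback did before this change.
        def __init__(self, output_dir):
            super().__init__()
            self.output_dir = output_dir

        def on_epoch_end(self, epoch, logs=None):
            self.model.save_pretrained(self.output_dir)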
@@ -299,19 +286,22 @@ def main():
         raw_datasets = load_dataset(
             data_args.dataset_name,
             data_args.dataset_config_name,
+            cache_dir=model_args.cache_dir,
             use_auth_token=True if model_args.use_auth_token else None,
         )
         if "validation" not in raw_datasets.keys():
             raw_datasets["validation"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
             )
             raw_datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
                 use_auth_token=True if model_args.use_auth_token else None,
             )
     else:
@@ -321,16 +311,39 @@ def main():
             data_files["train"] = data_args.train_file
         if data_args.validation_file is not None:
             data_files["validation"] = data_args.validation_file
-        extension = data_args.train_file.split(".")[-1]
+        extension = (
+            data_args.train_file.split(".")[-1]
+            if data_args.train_file is not None
+            else data_args.validation_file.split(".")[-1]
+        )
         if extension == "txt":
             extension = "text"
             dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
         raw_datasets = load_dataset(
             extension,
             data_files=data_files,
+            cache_dir=model_args.cache_dir,
             use_auth_token=True if model_args.use_auth_token else None,
             **dataset_args,
         )
+        # If no validation data is there, validation_split_percentage will be used to divide the dataset.
+        if "validation" not in raw_datasets.keys():
+            raw_datasets["validation"] = load_dataset(
+                extension,
+                data_files=data_files,
+                split=f"train[:{data_args.validation_split_percentage}%]",
+                cache_dir=model_args.cache_dir,
+                use_auth_token=True if model_args.use_auth_token else None,
+                **dataset_args,
+            )
+            raw_datasets["train"] = load_dataset(
+                extension,
+                data_files=data_files,
+                split=f"train[{data_args.validation_split_percentage}%:]",
+                cache_dir=model_args.cache_dir,
+                use_auth_token=True if model_args.use_auth_token else None,
+                **dataset_args,
+            )
     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.
     # endregion
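The automatic train/validation split added here relies on the percent-slicing syntax of the datasets library's split argument. A minimal sketch of the same idea, with a hypothetical local file and a fixed 5% standing in for --validation_split_percentage:

    from datasets import load_dataset

    data_files = {"train": "my_corpus.txt"}  # hypothetical file name
    validation = load_dataset("text", data_files=data_files, split="train[:5%]")
    train = load_dataset("text", data_files=data_files, split="train[5%:]")
    print(len(train), len(validation))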
@@ -446,7 +459,7 @@ def group_texts(examples):
         eval_dataset = eval_dataset.select(range(max_eval_samples))
 
     # Log a few random samples from the training set:
-    for index in random.sample(range(len(train_dataset)), 3):
+    for index in random.sample(range(len(train_dataset)), min(3, len(train_dataset))):
         logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
     # endregion
 
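The min() guard matters because random.sample() raises ValueError when asked for more elements than the population holds, which can happen with tiny debug datasets. A quick illustration, not part of the script:

    import random

    try:
        random.sample(range(2), 3)  # more samples than items -> ValueError
    except ValueError as err:
        print(err)
    print(random.sample(range(2), min(3, 2)))  # the guarded form always succeeds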
@@ -465,44 +478,88 @@ def group_texts(examples):
 
     # region TF Dataset preparation
     num_replicas = training_args.strategy.num_replicas_in_sync
-    data_collator = DefaultDataCollator(return_tensors="tf")
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
 
-    tf_train_dataset = train_dataset.to_tf_dataset(
-        # labels are passed as input, as we will use the model's internal loss
-        columns=[col for col in train_dataset.features if col != "special_tokens_mask"],
+    # model.prepare_tf_dataset() wraps a Hugging Face dataset in a tf.data.Dataset which is ready to use in
+    # training. This is the recommended way to use a Hugging Face dataset when training with Keras. You can also
+    # use the lower-level dataset.to_tf_dataset() method, but you will have to specify things like column names
+    # yourself if you use this method, whereas they are automatically inferred from the model input names when
+    # using model.prepare_tf_dataset()
+    # For more info see the docs:
+    # https://huggingface.co/docs/transformers/main/en/main_classes/model#transformers.TFPreTrainedModel.prepare_tf_dataset
+    # https://huggingface.co/docs/datasets/main/en/package_reference/main_classes#datasets.Dataset.to_tf_dataset
+
+    tf_train_dataset = model.prepare_tf_dataset(
+        train_dataset,
         shuffle=True,
         batch_size=num_replicas * training_args.per_device_train_batch_size,
-        collate_fn=data_collator,
-        drop_remainder=True,
     ).with_options(options)
 
-    tf_eval_dataset = eval_dataset.to_tf_dataset(
-        # labels are passed as input, as we will use the model's internal loss
-        columns=[col for col in eval_dataset.features if col != "special_tokens_mask"],
+    tf_eval_dataset = model.prepare_tf_dataset(
+        eval_dataset,
         shuffle=False,
-        batch_size=num_replicas * training_args.per_device_train_batch_size,
-        collate_fn=data_collator,
+        batch_size=num_replicas * training_args.per_device_eval_batch_size,
         drop_remainder=True,
     ).with_options(options)
     # endregion
 
     # region Optimizer and loss
-    batches_per_epoch = len(train_dataset) // (num_replicas * training_args.per_device_train_batch_size)
+    num_train_steps = len(tf_train_dataset) * int(training_args.num_train_epochs)
+    if training_args.warmup_steps > 0:
+        num_warmup_steps = training_args.warmup_steps
+    elif training_args.warmup_ratio > 0:
+        num_warmup_steps = int(num_train_steps * training_args.warmup_ratio)
+    else:
+        num_warmup_steps = 0
+
     # Bias and layernorm weights are automatically excluded from the decay
     optimizer, lr_schedule = create_optimizer(
         init_lr=training_args.learning_rate,
-        num_train_steps=int(training_args.num_train_epochs * batches_per_epoch),
-        num_warmup_steps=training_args.warmup_steps,
+        num_train_steps=num_train_steps,
+        num_warmup_steps=num_warmup_steps,
         adam_beta1=training_args.adam_beta1,
         adam_beta2=training_args.adam_beta2,
         adam_epsilon=training_args.adam_epsilon,
         weight_decay_rate=training_args.weight_decay,
+        adam_global_clipnorm=training_args.max_grad_norm,
     )
 
     # no user-specified loss = will use the model internal loss
-    model.compile(optimizer=optimizer)
+    model.compile(optimizer=optimizer, jit_compile=training_args.xla)
+    # endregion
+
+    # region Preparing push_to_hub and model card
+    push_to_hub_model_id = training_args.push_to_hub_model_id
+    model_name = model_args.model_name_or_path.split("/")[-1]
+    if not push_to_hub_model_id:
+        if data_args.dataset_name is not None:
+            push_to_hub_model_id = f"{model_name}-finetuned-{data_args.dataset_name}"
+        else:
+            push_to_hub_model_id = f"{model_name}-finetuned-clm"
+
+    model_card_kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
+    if data_args.dataset_name is not None:
+        model_card_kwargs["dataset_tags"] = data_args.dataset_name
+        if data_args.dataset_config_name is not None:
+            model_card_kwargs["dataset_args"] = data_args.dataset_config_name
+            model_card_kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
+        else:
+            model_card_kwargs["dataset"] = data_args.dataset_name
+
+    if training_args.push_to_hub:
+        callbacks = [
+            PushToHubCallback(
+                output_dir=training_args.output_dir,
+                model_id=push_to_hub_model_id,
+                organization=training_args.push_to_hub_organization,
+                token=training_args.push_to_hub_token,
+                tokenizer=tokenizer,
+                **model_card_kwargs,
+            )
+        ]
+    else:
+        callbacks = []
     # endregion
 
     # region Training and validation
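To make the new step arithmetic concrete: len(tf_train_dataset) already counts batches rather than examples, so the schedule lengths follow directly from it. A small sketch with made-up numbers (1,000 batches per epoch, 3 epochs, a warmup_ratio of 0.06):

    batches_per_epoch = 1_000  # illustrative value of len(tf_train_dataset)
    num_train_epochs = 3
    warmup_ratio = 0.06

    num_train_steps = batches_per_epoch * num_train_epochs  # 3_000 optimizer steps
    num_warmup_steps = int(num_train_steps * warmup_ratio)   # 180 warmup steps
    print(num_train_steps, num_warmup_steps)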
@@ -512,33 +569,45 @@ def group_texts(examples):
     logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
     logger.info(f"  Total train batch size = {training_args.per_device_train_batch_size * num_replicas}")
 
+    # For long training runs, you may wish to use the PushToHub() callback here to save intermediate checkpoints
+    # to the Hugging Face Hub rather than just pushing the finished model.
+    # See https://huggingface.co/docs/transformers/main_classes/keras_callbacks#transformers.PushToHubCallback
+
     history = model.fit(
         tf_train_dataset,
         validation_data=tf_eval_dataset,
         epochs=int(training_args.num_train_epochs),
-        steps_per_epoch=len(train_dataset) // (training_args.per_device_train_batch_size * num_replicas),
-        callbacks=[SavePretrainedCallback(output_dir=training_args.output_dir)],
+        callbacks=callbacks,
     )
+    train_loss = history.history["loss"][-1]
     try:
-        train_perplexity = math.exp(history.history["loss"][-1])
+        train_perplexity = math.exp(train_loss)
     except OverflowError:
         train_perplexity = math.inf
+    logger.info(f"  Final train loss: {train_loss:.3f}")
+    logger.info(f"  Final train perplexity: {train_perplexity:.3f}")
+    validation_loss = history.history["val_loss"][-1]
     try:
-        validation_perplexity = math.exp(history.history["val_loss"][-1])
+        validation_perplexity = math.exp(validation_loss)
     except OverflowError:
         validation_perplexity = math.inf
-    logger.info(f"  Final train loss: {history.history['loss'][-1]:.3f}")
-    logger.info(f"  Final train perplexity: {train_perplexity:.3f}")
-    logger.info(f"  Final validation loss: {history.history['val_loss'][-1]:.3f}")
+    logger.info(f"  Final validation loss: {validation_loss:.3f}")
     logger.info(f"  Final validation perplexity: {validation_perplexity:.3f}")
-    # endregion
 
     if training_args.output_dir is not None:
-        model.save_pretrained(training_args.output_dir)
+        output_eval_file = os.path.join(training_args.output_dir, "all_results.json")
+        results_dict = dict()
+        results_dict["train_loss"] = train_loss
+        results_dict["train_perplexity"] = train_perplexity
+        results_dict["eval_loss"] = validation_loss
+        results_dict["eval_perplexity"] = validation_perplexity
+        with open(output_eval_file, "w") as writer:
+            writer.write(json.dumps(results_dict))
+    # endregion
 
-    if training_args.push_to_hub:
-        # You'll probably want to include some of your own metadata here!
-        model.push_to_hub()
+    if training_args.output_dir is not None and not training_args.push_to_hub:
+        # If we're not pushing to hub, at least save a local copy when we're done
+        model.save_pretrained(training_args.output_dir)
 
 
 if __name__ == "__main__":
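After a run, the metrics file written above can be read back and sanity-checked, since perplexity is simply exp(loss). A brief sketch; the path and values are illustrative:

    import json
    import math

    with open("output_dir/all_results.json") as f:  # whatever --output_dir was
        results = json.load(f)

    print(results["eval_loss"], results["eval_perplexity"])
    print(math.exp(results["eval_loss"]))  # should match eval_perplexity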