-
Notifications
You must be signed in to change notification settings - Fork 67
/
Copy pathpretrain.py
463 lines (442 loc) · 18.1 KB
/
pretrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-Apache2
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO(@mgreaves, @jstjohn, @jomitchell) Consider different abstractions for pretraining, inference, and fine-tuning and see
# how they would address code duplication in the case of ESM2+Geneformer as well as a third hypothetical model that does
# not share the same types/loaders, such as OpenFold. The design should be flexible enough to allow for those different
# use cases and not hide too much complexity that a user would want to customize, while reducing code duplication
# between scripts.
import argparse
import math
from pathlib import Path
from typing import Optional, Sequence, get_args
from megatron.core.optimizer import OptimizerConfig
from nemo import lightning as nl
from nemo.collections import llm
from nemo.lightning import io, resume
from nemo.lightning.pytorch import callbacks as nl_callbacks
from nemo.lightning.pytorch.optim import MegatronOptimizerModule
from nemo.lightning.pytorch.optim.lr_scheduler import CosineAnnealingScheduler
from nemo.utils import logging
from pytorch_lightning.callbacks import LearningRateMonitor, RichModelSummary
from torch.nn import functional as F
from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype
from bionemo.geneformer.api import GeneformerConfig
from bionemo.geneformer.data.singlecell.datamodule import SingleCellDataModule
from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess
from bionemo.llm.lightning import LossLoggingCallback
from bionemo.llm.model.biobert.lightning import BioBertLightningModule
from bionemo.llm.model.biobert.model import BiobertSpecOption
from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger
# Public API of this module: only the `main` training entry point is exported.
__all__: Sequence[str] = ("main",)
def main(
data_dir: Path,
num_nodes: int,
devices: int,
seq_length: int,
result_dir: Path,
wandb_project: Optional[str],
wandb_offline: bool,
num_steps: int,
limit_val_batches: int,
val_check_interval: int,
num_dataset_workers: int,
biobert_spec_option: BiobertSpecOption,
lr: float,
micro_batch_size: int,
cosine_rampup_frac: float,
cosine_hold_frac: float,
experiment_name: str,
resume_if_exists: bool,
precision: PrecisionTypes,
wandb_entity: str = "clara-discovery",
create_tensorboard_logger: bool = False,
nemo1_init_path: Optional[Path] = None,
restore_from_checkpoint_path: Optional[str] = None,
save_best_checkpoint: bool = True,
save_last_checkpoint: bool = True,
metric_to_monitor_for_checkpoints: str = "val_loss",
save_top_k: int = 2,
save_every_n_steps: int = 100,
) -> None:
"""Train a Geneformer model on single cell data.
Args:
data_dir (Path): Base directory for the data.
num_nodes (int): Number of nodes to run on
devices (int): number of devices
seq_length (int): sequence length
result_dir (Path): directory to store results, logs and checkpoints
wandb_project (Optional[str]): weights and biases project name
wandb_offline (bool): if wandb should happen in offline mode
num_steps (int): number of steps to train the model for
limit_val_batches (int): limit the number of validation global batches to this many
val_check_interval (int): number of steps to periodically check the validation loss and save num_dataset_workers (
int): num dataset workers
biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
lr (float): learning rate
micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
cosine_rampup_frac (float): fraction of steps at the beginning of the run to ramp up the learning rate
cosine_hold_frac (float): fraction of steps to hold the minimum learning rate at the end of the run
experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
result_dir that stores the logs and checkpoints.
resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
create_tensorboard_logger (bool): create the tensorboard logger
restore_from_checkpoint_path (path): If set, restores the model from the directory passed in. Expects the
checkpoint to be created by using the ModelCheckpoint class and enable_nemo_ckpt_io=True.
"""
# Create the result directory if it does not exist.
result_dir.mkdir(parents=True, exist_ok=True)
# Setup train/test/val data paths
train_data_path = data_dir / "train"
val_data_path = data_dir / "val"
test_data_path = data_dir / "test"
# Setup the strategy and trainer
pipeline_model_parallel_size = 1
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=pipeline_model_parallel_size,
ddp="megatron",
find_unused_parameters=True,
ckpt_include_optimizer=True,
)
wandb_options: Optional[WandbLoggerOptions] = (
None
if wandb_project is None
else WandbLoggerOptions(
offline=wandb_offline,
project=wandb_project,
entity=wandb_entity,
log_model=False,
)
)
trainer = nl.Trainer(
devices=devices,
max_steps=num_steps,
accelerator="gpu",
strategy=strategy,
limit_val_batches=limit_val_batches, # This controls upsampling and downsampling
val_check_interval=val_check_interval, # TODO(@jstjohn) Checkpoint saving is currently broken, fix and change this.
num_nodes=num_nodes,
callbacks=[
# TODO(@skothenhill-nv) these need to be cleaned up when we have the automatic addition of track_io
io.track_io(LossLoggingCallback)(),
io.track_io(RichModelSummary)(max_depth=4),
io.track_io(LearningRateMonitor)(),
],
plugins=nl.MegatronMixedPrecision(precision=precision, amp_O2=False),
)
preprocessor = GeneformerPreprocess(
download_directory=train_data_path,
medians_file_path=train_data_path / "medians.json",
tokenizer_vocab_path=train_data_path / "geneformer.vocab",
)
match preprocessor.preprocess():
case {"tokenizer": tokenizer, "median_dict": median_dict}:
logging.info("*************** Preprocessing Finished ************")
case _:
logging.error("Preprocessing failed.")
# Configure the data module and model
data = SingleCellDataModule(
seq_length=seq_length,
tokenizer=tokenizer,
train_dataset_path=train_data_path,
val_dataset_path=val_data_path,
test_dataset_path=test_data_path,
random_token_prob=0.02, # changed to represent the incorrect setting we originally used.
median_dict=median_dict,
micro_batch_size=micro_batch_size,
global_batch_size=micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size),
# persistent workers is supported when num_dataset_workers > 0
persistent_workers=num_dataset_workers > 0,
pin_memory=False,
num_workers=num_dataset_workers,
)
geneformer_config = GeneformerConfig(
num_layers=6,
hidden_size=256,
ffn_hidden_size=512,
num_attention_heads=4,
seq_length=seq_length,
fp32_residual_connection=False, # TODO(@jstjohn) check this
hidden_dropout=0.02,
init_method_std=0.02,
kv_channels=None,
apply_query_key_layer_scaling=False,
make_vocab_size_divisible_by=128,
masked_softmax_fusion=True, # TODO(@jstjohn) check this
fp16_lm_cross_entropy=False,
params_dtype=get_autocast_dtype(precision),
pipeline_dtype=get_autocast_dtype(precision),
autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot
gradient_accumulation_fusion=False, # THIS BREAKS STUFF, leave False
layernorm_zero_centered_gamma=False, # TODO(@jstjohn) check this
layernorm_epsilon=1.0e-12,
activation_func=F.gelu, # TODO(@jstjohn) check this
qk_layernorm=False, # TODO(@jstjohn) check this
apply_residual_connection_post_layernorm=False, # False is new default, True was BERT pub.
bias_activation_fusion=True, # TODO(@jstjohn) check this
bias_dropout_fusion=True, # TODO(@jstjohn) check this
get_attention_mask_from_fusion=False,
attention_dropout=0.1,
share_embeddings_and_output_weights=True,
enable_autocast=False, # This has to be set to True if we use the mixed precision plugin
biobert_spec_option=biobert_spec_option,
nemo1_ckpt_path=nemo1_init_path,
)
# The lightning class owns a copy of the actual model, and a loss function, both of which are configured
# and lazily returned by the `geneformer_config` object defined above.
model = BioBertLightningModule(
geneformer_config,
tokenizer=tokenizer,
optimizer=MegatronOptimizerModule(
config=OptimizerConfig(
lr=lr,
# TODO(@jstjohn) try decoupled_lr
optimizer="adam",
use_distributed_optimizer=True,
),
lr_scheduler=CosineAnnealingScheduler(
max_steps=num_steps,
# minimum learning rate is 1/100th of the initial learning rate, so eg lr=1e-3 -> min_lr=1e-5
min_lr=lr / 100,
warmup_steps=int(math.ceil(num_steps * cosine_rampup_frac)),
interval="step",
monitor="val_loss",
constant_steps=int(math.ceil(num_steps * cosine_hold_frac)),
),
),
)
# Configure our custom Checkpointer
checkpoint_callback = nl_callbacks.ModelCheckpoint(
save_best_model=save_best_checkpoint,
save_last=save_last_checkpoint,
monitor=metric_to_monitor_for_checkpoints, # "val_loss",
save_top_k=save_top_k,
every_n_train_steps=save_every_n_steps,
enable_nemo_ckpt_io=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe
async_save=False, # Tries to save asynchronously, previously led to race conditions.
)
# Setup the logger and train the model
nemo_logger = setup_nemo_lightning_logger(
root_dir=result_dir,
name=experiment_name,
initialize_tensorboard_logger=create_tensorboard_logger,
wandb_kwargs=wandb_options,
ckpt_callback=checkpoint_callback,
)
llm.train(
model=model,
data=data,
trainer=trainer,
log=nemo_logger,
resume=resume.AutoResume(
path=restore_from_checkpoint_path, # Overrides the path found by resume_if_exists when set.
resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training.
resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint.
),
)
if __name__ == "__main__":
    # Command line entry point: declare the CLI, parse it, and forward every option to `main`.
    parser = argparse.ArgumentParser(description="Pretrain Geneformer with single cell data.")
    parser.add_argument(
        "--data-dir",
        type=Path,
        required=True,
        help="Path to the data base directory, for example this might be "
        "/workspace/bionemo2/data/cellxgene_2023-12-15_small",
    )
    parser.add_argument(
        "--precision",
        type=str,
        choices=get_args(PrecisionTypes),
        required=False,
        default="bf16-mixed",
        help="Precision type to use for training.",
    )
    parser.add_argument(
        "--lr",
        type=float,
        required=False,
        default=1e-4,
        help="Learning rate for training. Default is 1e-4. With bigger global batches try 1e-3",
    )
    parser.add_argument(
        "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger."
    )
    # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer
    parser.add_argument(
        "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists."
    )
    parser.add_argument(
        "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory."
    )
    parser.add_argument(
        "--experiment-name", type=str, required=False, default="geneformer", help="Name of the experiment."
    )
    parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.")
    parser.add_argument(
        "--wandb-project",
        type=str,
        required=False,
        default=None,
        help="Wandb project name. Wandb will only happen if this is set..",
    )
    parser.add_argument(
        "--cosine-rampup-frac",
        type=float,
        required=False,
        default=0.01,
        help="Fraction of steps in which to ramp up the learning rate. Default is 0.01.",
    )
    parser.add_argument(
        "--cosine-hold-frac",
        type=float,
        required=False,
        default=0.05,
        help="Fraction of final steps in which to hold the minimum LR. Default is 0.05.",
    )
    parser.add_argument(
        "--num-gpus",
        type=int,
        required=False,
        default=1,
        help="Number of GPUs to use for training. Default is 1.",
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        required=False,
        default=1,
        help="Number of nodes to use for training. Default is 1.",
    )
    parser.add_argument(
        "--num-steps",
        type=int,
        required=False,
        default=10000,
        help="Number of steps to use for training. Default is 10000.",
    )
    parser.add_argument(
        "--num-dataset-workers",
        type=int,
        required=False,
        default=0,
        help="Number of dataset workers to use for data loading. Default is 0.",
    )
    parser.add_argument(
        "--val-check-interval",
        type=int,
        required=False,
        default=10000,
        help="Number of steps between validation checks; checkpoints are saved at the same interval. "
        "Default is 10000.",
    )
    parser.add_argument(
        "--seq-length",
        type=int,
        required=False,
        default=2048,
        help="Sequence length of cell. Default is 2048.",
    )
    parser.add_argument(
        "--limit-val-batches",
        type=int,
        required=False,
        default=2,
        help="Limit the number of validation global batches to this many. Default is 2.",
    )
    parser.add_argument(
        "--micro-batch-size",
        type=int,
        required=False,
        default=64,
        help="Micro-batch size. Global batch size is inferred from this.",
    )
    parser.add_argument(
        "--biobert-spec-option",
        type=BiobertSpecOption,
        choices=[e.value for e in BiobertSpecOption],
        required=False,
        default=BiobertSpecOption.bert_layer_local_spec.value,
        help="Biobert spec option to use for the model. Default is 'bert_layer_local_spec'.",
    )
    parser.add_argument(
        "--nemo1-init-path",
        type=Path,
        required=False,
        help="Path to nemo1 file, if desired to load at init time.",
    )
    # These two flags default to True. With a plain `store_true` action they could never be switched off, so
    # BooleanOptionalAction is used: the original `--save-*-checkpoint` spelling still works and a matching
    # `--no-save-*-checkpoint` form is added to disable the behavior.
    parser.add_argument(
        "--save-best-checkpoint",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the best checkpoint based on the metric to monitor.",
    )
    parser.add_argument(
        "--save-last-checkpoint",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the last checkpoint.",
    )
    parser.add_argument(
        "--metric-to-monitor-for-checkpoints",
        type=str,
        required=False,
        default="val_loss",
        help="The metric to monitor for checkpointing.",
    )
    parser.add_argument(
        "--save-top-k",
        type=int,
        required=False,
        default=2,
        help="Save the top k checkpoints.",
    )
    parser.add_argument(
        "--restore-from-checkpoint-path",
        type=Path,
        required=False,
        default=None,
        help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.",
    )

    # Parse the arguments and pull them out into local variables for ease of future refactor to a
    # config management system.
    args = parser.parse_args()
    main(
        data_dir=args.data_dir,
        num_nodes=args.num_nodes,
        devices=args.num_gpus,
        seq_length=args.seq_length,
        result_dir=args.result_dir,
        wandb_project=args.wandb_project,
        wandb_offline=args.wandb_offline,
        num_steps=args.num_steps,
        limit_val_batches=args.limit_val_batches,
        val_check_interval=args.val_check_interval,
        num_dataset_workers=args.num_dataset_workers,
        biobert_spec_option=args.biobert_spec_option,
        lr=args.lr,
        micro_batch_size=args.micro_batch_size,
        cosine_rampup_frac=args.cosine_rampup_frac,
        cosine_hold_frac=args.cosine_hold_frac,
        precision=args.precision,
        experiment_name=args.experiment_name,
        resume_if_exists=args.resume_if_exists,
        nemo1_init_path=args.nemo1_init_path,
        restore_from_checkpoint_path=args.restore_from_checkpoint_path,
        save_best_checkpoint=args.save_best_checkpoint,
        save_last_checkpoint=args.save_last_checkpoint,
        metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints,
        save_top_k=args.save_top_k,
        # There is no dedicated CLI option for the checkpoint cadence: checkpoints are written at the validation
        # interval. NOTE(review): presumably so the monitored validation metric is fresh at save time -- confirm.
        save_every_n_steps=args.val_check_interval,
    )