diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 1a2a0ef87..e500cc709 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -12,7 +12,7 @@ "target": "dev" }, "mounts": [ - // Mount the local ~/.aws config to pass along AWS credentials for PBSS + // Mount the local ~/.aws config to pass along AWS credentials for PBSS. "source=${localEnv:HOME}/.aws,target=/home/bionemo/.aws,type=bind,consistency=cached", "source=${localEnv:HOME}/.ssh,target=/home/bionemo/.ssh,readonly,type=bind,consistency=cached" ], diff --git a/scripts/protein/esm2/esm2_pretrain.py b/scripts/protein/esm2/esm2_pretrain.py new file mode 100644 index 000000000..d4ca9514a --- /dev/null +++ b/scripts/protein/esm2/esm2_pretrain.py @@ -0,0 +1,483 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +from pathlib import Path +from typing import Optional, Sequence, get_args + +from megatron.core.optimizer import OptimizerConfig +from nemo import lightning as nl +from nemo.collections import llm +from nemo.lightning import resume +from nemo.lightning.pytorch import callbacks as nl_callbacks +from nemo.lightning.pytorch.optim import MegatronOptimizerModule +from pytorch_lightning.callbacks import LearningRateMonitor, RichModelSummary + +from bionemo.core.utils.dtypes import PrecisionTypes, get_autocast_dtype +from bionemo.esm2.api import ESM2Config +from bionemo.esm2.data.datamodule import ESMDataModule +from bionemo.esm2.data.tokenizer import get_tokenizer +from bionemo.esm2.model.lr_scheduler import WarmupAnnealDecayHoldScheduler +from bionemo.llm.lightning import LossLoggingCallback +from bionemo.llm.model.biobert.lightning import BioBertLightningModule +from bionemo.llm.model.biobert.model import BiobertSpecOption +from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size +from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger + + +__all__: Sequence[str] = ("main", "parser") + + +def main( + train_cluster_path: Path, + train_database_path: Path, + valid_cluster_path: Path, + valid_database_path: Path, + num_nodes: int, + devices: int, + seq_length: int, + result_dir: Path, + wandb_project: Optional[str], + wandb_offline: bool, + num_steps: int, + warmup_steps: int, + limit_val_batches: int, + val_check_interval: int, + num_dataset_workers: int, + biobert_spec_option: BiobertSpecOption, # TODO(@farhadrgh) clarify how to parse this. 
+    lr: float,
+    micro_batch_size: int,
+    accumulate_grad_batches: int,
+    experiment_name: str,
+    resume_if_exists: bool,
+    precision: PrecisionTypes,
+    wandb_entity: str = "clara-discovery",
+    create_tensorboard_logger: bool = False,
+    nemo1_init_path: Optional[Path] = None,
+    restore_from_checkpoint_path: Optional[str] = None,
+    save_best_checkpoint: bool = True,
+    save_last_checkpoint: bool = True,
+    metric_to_monitor_for_checkpoints: str = "val_loss",
+    save_top_k: int = 2,
+    save_every_n_steps: int = 100,
+    num_layers: int = 33,
+    hidden_size: int = 1280,
+    num_attention_heads: int = 20,
+    ffn_hidden_size: int = 1280 * 4,
+) -> None:
+    """Train an ESM2 model on UR data.
+
+    Args:
+        train_cluster_path (Path): path to train cluster parquet
+        train_database_path (Path): path to train sequence database
+        valid_cluster_path (Path): path to validation cluster parquet
+        valid_database_path (Path): path to validation sequence database
+        num_nodes (int): number of nodes to run on
+        devices (int): number of devices
+        seq_length (int): sequence length
+        result_dir (Path): directory to store results, logs and checkpoints
+        wandb_project (Optional[str]): weights and biases project name
+        wandb_offline (bool): if wandb should happen in offline mode
+        num_steps (int): number of steps to train the model for
+        warmup_steps (int): number of steps for the learning rate warmup
+        limit_val_batches (int | float | None): limit validation to this many global batches if an int, to this
+            fraction of the validation set if a float, or apply no limit if None
+        val_check_interval (int): number of steps between checking the validation loss and saving a checkpoint
+        num_dataset_workers (int): number of dataset workers
+        biobert_spec_option (BiobertSpecOption): the biobert spec option (architecture) to use for this run
+        lr (float): learning rate
+        micro_batch_size (int): micro batch size, from this and parallelism settings we infer the global batch size
+        accumulate_grad_batches (int): number of batches to accumulate gradients over before stepping the optimizer
+        experiment_name (str): experiment name, this is the name used for the wandb run, and the sub-directory of the
+            result_dir that stores the logs and checkpoints.
+        resume_if_exists (bool): attempt to resume if the checkpoint exists [FIXME @skothenhill this doesn't work yet]
+        precision (PrecisionTypes): precision to use for training (e.g. "bf16-mixed")
+        wandb_entity (str): the group to use for the wandb run, sometimes called a team, could also be your username
+        create_tensorboard_logger (bool): create the tensorboard logger
+        restore_from_checkpoint_path (Optional[str]): If set, restores the model from the directory passed in. Expects the
+            checkpoint to be created by using the ModelCheckpoint class and enable_nemo_ckpt_io=True.
+    """
+    # Create the result directory if it does not exist.
+ result_dir.mkdir(parents=True, exist_ok=True) + + # Setup the strategy and trainer + pipeline_model_parallel_size = 1 + tensor_model_parallel_size = 1 + global_batch_size = infer_global_batch_size( + micro_batch_size=micro_batch_size, + num_nodes=num_nodes, + devices=devices, + accumulate_grad_batches=accumulate_grad_batches, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ddp="megatron", + find_unused_parameters=True, + ckpt_include_optimizer=True, + ) + + wandb_options: Optional[WandbLoggerOptions] = ( + None + if wandb_project is None + else WandbLoggerOptions( + offline=wandb_offline, + project=wandb_project, + entity=wandb_entity, + log_model=False, + ) + ) + trainer = nl.Trainer( + devices=devices, + max_steps=num_steps, + accelerator="gpu", + strategy=strategy, + limit_val_batches=limit_val_batches, # This controls upsampling and downsampling + val_check_interval=val_check_interval, + num_nodes=num_nodes, + callbacks=[ + LossLoggingCallback(), + RichModelSummary(max_depth=4), + LearningRateMonitor(), + ], + plugins=nl.MegatronMixedPrecision(precision=precision), + ) + + tokenizer = get_tokenizer() + + # Initialize the data module. + data = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=train_database_path, + valid_cluster_path=valid_cluster_path, + valid_database_path=valid_database_path, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + min_seq_length=None, + max_seq_length=seq_length, + num_workers=num_dataset_workers, + ) + + # Configure the model + esm2_config = ESM2Config( + seq_length=seq_length, + num_layers=num_layers, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ffn_hidden_size=ffn_hidden_size, + params_dtype=get_autocast_dtype(precision), + pipeline_dtype=get_autocast_dtype(precision), + autocast_dtype=get_autocast_dtype(precision), # setting this speeds things up a lot + biobert_spec_option=biobert_spec_option, + nemo1_ckpt_path=nemo1_init_path, + ) + + model = BioBertLightningModule( + esm2_config, + tokenizer=tokenizer, + optimizer=MegatronOptimizerModule( + config=OptimizerConfig( + lr=lr, + optimizer="adam", # fused_adam not supported + use_distributed_optimizer=True, + weight_decay=0.01, + adam_beta1=0.9, + adam_beta2=0.98, + ), + lr_scheduler=WarmupAnnealDecayHoldScheduler( + warmup_steps=warmup_steps, max_steps=num_steps, max_lr=lr, min_lr=lr / 10.0, anneal_percentage=0.10 + ), + ), + ) + + # Configure our custom Checkpointer + checkpoint_callback = nl_callbacks.ModelCheckpoint( + save_best_model=save_best_checkpoint, + save_last=save_last_checkpoint, + monitor=metric_to_monitor_for_checkpoints, # "val_loss", + save_top_k=save_top_k, + every_n_train_steps=save_every_n_steps, + enable_nemo_ckpt_io=True, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe + ) + + # Setup the logger and train the model + nemo_logger = setup_nemo_lightning_logger( + root_dir=result_dir, + name=experiment_name, + initialize_tensorboard_logger=create_tensorboard_logger, + wandb_kwargs=wandb_options, + ckpt_callback=checkpoint_callback, + ) + + llm.train( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + resume=resume.AutoResume( + path=restore_from_checkpoint_path, # Overrides the path found by resume_if_exists when set. 
+ resume_if_exists=resume_if_exists, # Looks for the -last checkpoint to continue training. + resume_ignore_no_checkpoint=True, # When false this will throw an error with no existing checkpoint. + ), + ) + + +# TODO migrate to hydra config +# Parse the arguments and pull them out into local variables for ease of future refactor to a +# config management system. +parser = argparse.ArgumentParser(description="Pretrain ESM2 with UR data.") +parser.add_argument( + "--train-cluster-path", + type=Path, + required=True, + help="Path to the train cluster data parquet file", +) +parser.add_argument( + "--train-database-path", + type=Path, + required=True, + help="Path to the train sequence database file", +) +parser.add_argument( + "--valid-cluster-path", + type=Path, + required=True, + help="Path to the valid cluster data parquet file", +) +parser.add_argument( + "--valid-database-path", + type=Path, + required=True, + help="Path to the vali sequence database file", +) +parser.add_argument( + "--precision", + type=str, + choices=get_args(PrecisionTypes), + required=False, + default="bf16-mixed", + help="Precision type to use for training.", +) +parser.add_argument( + "--lr", + type=float, + required=False, + default=4e-4, + help="Learning rate for training. Default is 4e-4", +) +parser.add_argument( + "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger." +) +# FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer +parser.add_argument( + "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists." +) +parser.add_argument( + "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory." +) +parser.add_argument("--experiment-name", type=str, required=False, default="esm2", help="Name of the experiment.") +parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.") +parser.add_argument( + "--wandb-project", + type=str, + required=False, + default=None, + help="Wandb project name. Wandb will only happen if this is set.", +) +parser.add_argument( + "--num-gpus", + type=int, + required=False, + default=1, + help="Number of GPUs to use for training. Default is 1.", +) +parser.add_argument( + "--num-nodes", + type=int, + required=False, + default=1, + help="Number of nodes to use for training. Default is 1.", +) +parser.add_argument( + "--num-steps", + type=int, + required=False, + default=500000, + help="Number of steps to use for training. Default is 500000.", +) +parser.add_argument( + "--warmup-steps", + type=int, + required=False, + default=2000, + help="Number of warmup steps for WarmupAnnealDecayHold Scheduler. Default is 2000.", +) +parser.add_argument( + "--num-dataset-workers", + type=int, + required=False, + default=1, + help="Number of workers to use for training. Default is 1.", +) +parser.add_argument( + "--val-check-interval", + type=int, + required=False, + default=10000, + help="Number of steps between validation. Default is 10000.", +) +parser.add_argument( + "--seq-length", + type=int, + required=False, + default=1024, + help="Sequence length of cell. Default is 1024.", +) +parser.add_argument( + "--limit-val-batches", + type=float_or_int_or_none, + required=False, + default=2, + help="Number of global batches used for validation if int. Fraction of validation dataset if float. 
Default is 2.", +) +parser.add_argument( + "--micro-batch-size", + type=int, + required=False, + default=64, + help="Micro-batch size. Global batch size is inferred from this.", +) +parser.add_argument( + "--accumulate-grad-batches", + type=int, + required=False, + default=1, + help="Gradient accumulation steps. Global batch size is inferred from this.", +) +parser.add_argument( + "--biobert-spec-option", + type=BiobertSpecOption, + choices=[e.value for e in BiobertSpecOption], + required=False, + default=BiobertSpecOption.esm2_bert_layer_local_spec.value, + help="Biobert spec option to use for the model. Default is 'esm2_bert_layer_local_spec'.", +) +parser.add_argument( + "--nemo1-init-path", + type=Path, + required=False, + help="Path to nemo1 file, if desired to load at init time.", +) +parser.add_argument( + "--save-best-checkpoint", + action="store_true", + default=True, + help="Save the best checkpoint based on the metric to monitor.", +) +parser.add_argument( + "--save-last-checkpoint", + action="store_true", + default=True, + help="Save the last checkpoint.", +) +parser.add_argument( + "--metric-to-monitor-for-checkpoints", + type=str, + required=False, + default="val_loss", + help="The metric to monitor for checkpointing.", +) +parser.add_argument( + "--save-top-k", + type=int, + required=False, + default=2, + help="Save the top k checkpoints.", +) +parser.add_argument( + "--restore-from-checkpoint-path", + type=Path, + required=False, + default=None, + help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", +) + +# ESM2 specific configuration (default: 650M) +parser.add_argument( + "--num-layers", + type=int, + required=False, + default=33, + help="Number of layers in the model. Default is 33.", +) +parser.add_argument( + "--hidden-size", + type=int, + required=False, + default=1280, + help="Hidden size of the model. Default is 1280.", +) +parser.add_argument( + "--num-attention-heads", + type=int, + required=False, + default=20, + help="Number of attention heads in the model. Default is 20.", +) +parser.add_argument( + "--ffn-hidden-size", + type=int, + required=False, + default=4 * 1280, + help="FFN hidden size of the model. 
Default is 4 * 1280.", +) + +if __name__ == "__main__": + args = parser.parse_args() + main( + train_cluster_path=args.train_cluster_path, + train_database_path=args.train_database_path, + valid_cluster_path=args.valid_cluster_path, + valid_database_path=args.valid_database_path, + num_nodes=args.num_nodes, + devices=args.num_gpus, + seq_length=args.seq_length, + result_dir=args.result_dir, + wandb_project=args.wandb_project, + wandb_offline=args.wandb_offline, + num_steps=args.num_steps, + warmup_steps=args.warmup_steps, + limit_val_batches=args.limit_val_batches, + val_check_interval=args.val_check_interval, + num_dataset_workers=args.num_dataset_workers, + biobert_spec_option=args.biobert_spec_option, + lr=args.lr, + micro_batch_size=args.micro_batch_size, + accumulate_grad_batches=args.accumulate_grad_batches, + precision=args.precision, + experiment_name=args.experiment_name, + resume_if_exists=args.resume_if_exists, + nemo1_init_path=args.nemo1_init_path, + restore_from_checkpoint_path=args.restore_from_checkpoint_path, + save_best_checkpoint=args.save_best_checkpoint, + save_last_checkpoint=args.save_last_checkpoint, + metric_to_monitor_for_checkpoints=args.metric_to_monitor_for_checkpoints, + save_top_k=args.save_top_k, + save_every_n_steps=args.val_check_interval, + num_layers=args.num_layers, + hidden_size=args.hidden_size, + num_attention_heads=args.num_attention_heads, + ffn_hidden_size=args.ffn_hidden_size, + ) diff --git a/scripts/protein/esm2/test_esm2_pretrain.py b/scripts/protein/esm2/test_esm2_pretrain.py new file mode 100644 index 000000000..7bd4f55a5 --- /dev/null +++ b/scripts/protein/esm2/test_esm2_pretrain.py @@ -0,0 +1,389 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shlex +import sqlite3 +import subprocess +from pathlib import Path +from typing import Dict + +import pandas as pd +import pytest +from esm2_pretrain import main, parser # TODO: needs to be refactored to a package and imported! +from lightning.fabric.plugins.environments.lightning import find_free_network_port + +from bionemo import esm2 +from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.utils.datamodule_utils import parse_kwargs_to_arglist +from bionemo.testing import megatron_parallel_state_utils + + +# python scripts/download_artifacts.py --models all --model_dir ./models --data all --data_dir ./ --verbose --source pbss +bionemo2_root: Path = ( + # esm2 module's path is the most dependable --> don't expect this to change! 
+ Path(esm2.__file__) + # This gets us from 'sub-packages/bionemo-esm2/src/bionemo/esm2/__init__.py' to 'sub-packages/bionemo-esm2' + .parent.parent.parent.parent + # From here, we want to get to the root of the repository: _before_ sub-packages/ + .parent.parent +).absolute() + + +@pytest.mark.skip("duplicate unittest") +@pytest.fixture +def dummy_protein_dataset(tmp_path): + """Create a mock protein dataset.""" + db_file = tmp_path / "protein_dataset.db" + conn = sqlite3.connect(str(db_file)) + cursor = conn.cursor() + + cursor.execute( + """ + CREATE TABLE protein ( + id TEXT PRIMARY KEY, + sequence TEXT + ) + """ + ) + + proteins = [ + ("UniRef90_A", "ACDEFGHIKLMNPQRSTVWY"), + ("UniRef90_B", "DEFGHIKLMNPQRSTVWYAC"), + ("UniRef90_C", "MGHIKLMNPQRSTVWYACDE"), + ("UniRef50_A", "MKTVRQERLKSIVRI"), + ("UniRef50_B", "MRILERSKEPVSGAQLA"), + ] + cursor.executemany("INSERT INTO protein VALUES (?, ?)", proteins) + + conn.commit() + conn.close() + + return db_file + + +@pytest.mark.skip("duplicate unittest") +@pytest.fixture +def dummy_parquet_train_val_inputs(tmp_path): + """Create a mock protein train and val cluster parquet.""" + train_cluster_path = tmp_path / "train_clusters.parquet" + train_clusters = pd.DataFrame( + { + "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], + } + ) + train_clusters.to_parquet(train_cluster_path) + + valid_cluster_path = tmp_path / "valid_clusters.parquet" + valid_clusters = pd.DataFrame( + { + "ur50_id": ["UniRef50_A", "UniRef50_B", "UniRef50_A", "UniRef50_B"], # 2 IDs more than confest + } + ) + valid_clusters.to_parquet(valid_cluster_path) + return train_cluster_path, valid_cluster_path + + +def test_bionemo2_rootdir(): + assert (bionemo2_root / "sub-packages").exists(), "Could not find bionemo2 root directory." + assert (bionemo2_root / "sub-packages").is_dir(), "sub-packages is supposed to be a directory." + + +@pytest.mark.skip("duplicate with argparse, model and data unittests") +def test_main_runs(tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + result_dir = Path(tmpdir.mkdir("results")) + + with megatron_parallel_state_utils.distributed_model_parallel_state(): + main( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + num_nodes=1, + devices=1, + seq_length=128, + result_dir=result_dir, + wandb_project=None, + wandb_offline=True, + num_steps=55, + warmup_steps=5, + limit_val_batches=1, + val_check_interval=1, + num_dataset_workers=1, + biobert_spec_option=BiobertSpecOption.esm2_bert_layer_local_spec, + lr=1e-4, + micro_batch_size=2, + accumulate_grad_batches=2, + precision="bf16-mixed", + experiment_name="test_experiment", + resume_if_exists=False, + create_tensorboard_logger=False, + num_layers=2, + num_attention_heads=2, + hidden_size=4, + ffn_hidden_size=4 * 4, + ) + + assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." + assert (result_dir / "test_experiment").is_dir(), "Test experiment directory is supposed to be a directory." + children = list((result_dir / "test_experiment").iterdir()) + assert len(children) == 1, f"Expected 1 child in test experiment directory, found {children}." + uq_rundir = children[0] # it will be some date. + assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).exists(), "Could not find test experiment checkpoints directory." 
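+    # The checks above and below encode the on-disk layout expected from the NeMo logger:
+    #   <result_dir>/test_experiment/<timestamped run dir>/checkpoints/
+    #   <result_dir>/test_experiment/<timestamped run dir>/nemo_log_globalrank-0_localrank-0.txt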
+ assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).is_dir(), "Test experiment checkpoints directory is supposed to be a directory." + assert ( + result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt" + ).is_file(), "Could not find experiment log." + + +@pytest.mark.parametrize("limit_val_batches", [1.0, 4, None]) +def test_val_dataloader_in_main_runs_with_limit_val_batches( + tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs, limit_val_batches +): + """Ensures doesn't run out of validation samples whenever updating limit_val_batches logic. + + Args: + tmpdir (str): Temporary directory. + dummy_protein_dataset (str): Path to dummy protein dataset. + dummy_parquet_train_val_inputs (tuple[str, str]): Tuple of dummy protein train and val cluster parquet paths. + limit_val_batches (Union[int, float, None]): Limit validation batches. None implies 1.0 as in PTL. + """ + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + result_dir = Path(tmpdir.mkdir("results")) + + with megatron_parallel_state_utils.distributed_model_parallel_state(): + main( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + num_nodes=1, + devices=1, + seq_length=128, + result_dir=result_dir, + wandb_project=None, + wandb_offline=True, + num_steps=10, + warmup_steps=2, + limit_val_batches=limit_val_batches, + val_check_interval=1, + num_dataset_workers=1, + biobert_spec_option=BiobertSpecOption.esm2_bert_layer_local_spec, + lr=1e-4, + micro_batch_size=2, + accumulate_grad_batches=1, + precision="bf16-mixed", + experiment_name="test_experiment", + resume_if_exists=False, + create_tensorboard_logger=False, + num_layers=2, + num_attention_heads=2, + hidden_size=4, + ffn_hidden_size=4 * 4, + ) + + assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." + assert (result_dir / "test_experiment").is_dir(), "Test experiment directory is supposed to be a directory." + children = list((result_dir / "test_experiment").iterdir()) + assert len(children) == 1, f"Expected 1 child in test experiment directory, found {children}." + uq_rundir = children[0] # it will be some date. + assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).exists(), "Could not find test experiment checkpoints directory." + assert ( + result_dir / "test_experiment" / uq_rundir / "checkpoints" + ).is_dir(), "Test experiment checkpoints directory is supposed to be a directory." + assert ( + result_dir / "test_experiment" / uq_rundir / "nemo_log_globalrank-0_localrank-0.txt" + ).is_file(), "Could not find experiment log." + + +@pytest.mark.skip("duplicate with argparse, model and data unittests") +def test_pretrain_cli(tmpdir, dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + result_dir = Path(tmpdir.mkdir("results")) + open_port = find_free_network_port() + # NOTE: if you need to change the following command, please update the README.md example. 
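+    # A free MASTER_PORT is exported into the subprocess environment below so that the spawned
+    # run's torch.distributed rendezvous does not collide with other tests on the same machine.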
+ cmd_str = f"""python \ + {bionemo2_root}/scripts/protein/esm2/esm2_pretrain.py \ + --train-cluster-path {train_cluster_path} \ + --train-database-path {dummy_protein_dataset} \ + --valid-cluster-path {valid_cluster_path} \ + --valid-database-path {dummy_protein_dataset} \ + --result-dir {result_dir} \ + --experiment-name test_experiment \ + --num-gpus 1 \ + --num-nodes 1 \ + --val-check-interval 10 \ + --num-dataset-workers 1 \ + --num-steps 55 \ + --seq-length 128 \ + --limit-val-batches 2 \ + --micro-batch-size 2 \ + --accumulate-grad-batches 2 + """.strip() + env = dict(**os.environ) # a local copy of the environment + env["MASTER_PORT"] = str(open_port) + cmd = shlex.split(cmd_str) + result = subprocess.run( + cmd, + cwd=tmpdir, + env=env, + capture_output=True, + ) + assert result.returncode == 0, f"Pretrain script failed: {cmd_str}" + assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." + + +@pytest.fixture(scope="function") +def required_args_reference() -> Dict[str, str]: + """ + This fixture provides a dictionary of required arguments for the pretraining script. + + It includes the following keys: + - train_cluster_path: The path to the training cluster parquet file. + - train_database_path: The path to the training database file. + - valid_cluster_path: The path to the validation cluster parquet file. + - valid_database_path: The path to the validation database file. + + The values for these keys are placeholders and should be replaced with actual file paths. + + Returns: + A dictionary with the required arguments for the pretraining script. + """ + return { + "train_cluster_path": "path/to/train_cluster.parquet", + "train_database_path": "path/to/train.db", + "valid_cluster_path": "path/to/valid_cluster.parquet", + "valid_database_path": "path/to/valid.db", + } + + +# TODO(@sichu) add test on dataset/datamodule on invalid path +def test_required_train_cluster_path(required_args_reference): + """ + Test train_cluster_path is required. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference.pop("train_cluster_path") + arglist = parse_kwargs_to_arglist(required_args_reference) + with pytest.raises(SystemExit): + parser.parse_args(arglist) + + +def test_required_train_database_path(required_args_reference): + """ + Test train_database_path is required. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference.pop("train_database_path") + arglist = parse_kwargs_to_arglist(required_args_reference) + with pytest.raises(SystemExit): + parser.parse_args(arglist) + + +def test_required_valid_cluster_path(required_args_reference): + """ + Test valid_cluster_path is required. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference.pop("valid_cluster_path") + arglist = parse_kwargs_to_arglist(required_args_reference) + with pytest.raises(SystemExit): + parser.parse_args(arglist) + + +def test_required_valid_database_path(required_args_reference): + """ + Test valid_database_path is required. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. 
+ """ + required_args_reference.pop("valid_database_path") + arglist = parse_kwargs_to_arglist(required_args_reference) + with pytest.raises(SystemExit): + parser.parse_args(arglist) + + +#### test expected behavior on parser #### +@pytest.mark.parametrize("limit_val_batches", [0.1, 0.5, 1.0]) +def test_limit_val_batches_is_float(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as a float. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (float): The value of limit_val_batches. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) + + +@pytest.mark.parametrize("limit_val_batches", ["0.1", "0.5", "1.0"]) +def test_limit_val_batches_is_float_string(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as a string of float. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (float): The value of limit_val_batches. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) + + +@pytest.mark.parametrize("limit_val_batches", [None, "None"]) +def test_limit_val_batches_is_none(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as none. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + args = parser.parse_args(arglist) + assert args.limit_val_batches is None + + +@pytest.mark.parametrize("limit_val_batches", [1, 2]) +def test_limit_val_batches_is_int(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as integer. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (int): The value of limit_val_batches. 
+ """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) diff --git a/scripts/singlecell/geneformer/pretrain.py b/scripts/singlecell/geneformer/pretrain.py index 3dba54d2a..14376fcbf 100644 --- a/scripts/singlecell/geneformer/pretrain.py +++ b/scripts/singlecell/geneformer/pretrain.py @@ -43,10 +43,11 @@ from bionemo.llm.lightning import LossLoggingCallback from bionemo.llm.model.biobert.lightning import BioBertLightningModule from bionemo.llm.model.biobert.model import BiobertSpecOption +from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size from bionemo.llm.utils.logger_utils import WandbLoggerOptions, setup_nemo_lightning_logger -__all__: Sequence[str] = ("main",) +__all__: Sequence[str] = ("main", "parser") def main( @@ -64,6 +65,7 @@ def main( biobert_spec_option: BiobertSpecOption, lr: float, micro_batch_size: int, + accumulate_grad_batches: int, cosine_rampup_frac: float, cosine_hold_frac: float, experiment_name: str, @@ -116,8 +118,18 @@ def main( # Setup the strategy and trainer pipeline_model_parallel_size = 1 + tensor_model_parallel_size = 1 + global_batch_size = infer_global_batch_size( + micro_batch_size=micro_batch_size, + num_nodes=num_nodes, + devices=devices, + accumulate_grad_batches=accumulate_grad_batches, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + strategy = nl.MegatronStrategy( - tensor_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, pipeline_model_parallel_size=pipeline_model_parallel_size, ddp="megatron", find_unused_parameters=True, @@ -172,7 +184,7 @@ def main( random_token_prob=0.02, # changed to represent the incorrect setting we originally used. median_dict=median_dict, micro_batch_size=micro_batch_size, - global_batch_size=micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size), + global_batch_size=global_batch_size, # persistent workers is supported when num_dataset_workers > 0 persistent_workers=num_dataset_workers > 0, pin_memory=False, @@ -265,170 +277,177 @@ def main( ) -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Pretrain Geneformer with single cell data.") - parser.add_argument( - "--data-dir", - type=Path, - required=True, - help="Path to the data base directory, for example this might be " - "/workspace/bionemo2/data/cellxgene_2023-12-15_small", - ) - parser.add_argument( - "--precision", - type=str, - choices=get_args(PrecisionTypes), - required=False, - default="bf16-mixed", - help="Precision type to use for training.", - ) - parser.add_argument( - "--lr", - type=float, - required=False, - default=1e-4, - help="Learning rate for training. Default is 1e-4. With bigger global batches try 1e-3", - ) - parser.add_argument( - "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger." - ) - # FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer - parser.add_argument( - "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists." - ) - parser.add_argument( - "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory." - ) - parser.add_argument( - "--experiment-name", type=str, required=False, default="geneformer", help="Name of the experiment." 
- ) - parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.") - parser.add_argument( - "--wandb-project", - type=str, - required=False, - default=None, - help="Wandb project name. Wandb will only happen if this is set..", - ) - parser.add_argument( - "--cosine-rampup-frac", - type=float, - required=False, - default=0.01, - help="Fraction of steps in which to ramp up the learning rate. Default is 0.01.", - ) - parser.add_argument( - "--cosine-hold-frac", - type=float, - required=False, - default=0.05, - help="Fraction of final steps in which to hold the minimum LR. Default is 0.05.", - ) +parser = argparse.ArgumentParser(description="Pretrain Geneformer with single cell data.") +parser.add_argument( + "--data-dir", + type=Path, + required=True, + help="Path to the data base directory, for example this might be " + "/workspace/bionemo2/data/cellxgene_2023-12-15_small", +) +parser.add_argument( + "--precision", + type=str, + choices=get_args(PrecisionTypes), + required=False, + default="bf16-mixed", + help="Precision type to use for training.", +) +parser.add_argument( + "--lr", + type=float, + required=False, + default=1e-4, + help="Learning rate for training. Default is 1e-4. With bigger global batches try 1e-3", +) +parser.add_argument( + "--create-tensorboard-logger", action="store_true", default=False, help="Create a tensorboard logger." +) +# FIXME (@skothenhill) figure out how checkpointing and resumption should work with the new nemo trainer +parser.add_argument( + "--resume-if-exists", action="store_true", default=False, help="Resume training if a checkpoint exists." +) +parser.add_argument( + "--result-dir", type=Path, required=False, default=Path("./results"), help="Path to the result directory." +) +parser.add_argument( + "--experiment-name", type=str, required=False, default="geneformer", help="Name of the experiment." +) +parser.add_argument("--wandb-offline", action="store_true", default=False, help="Use wandb in offline mode.") +parser.add_argument( + "--wandb-project", + type=str, + required=False, + default=None, + help="Wandb project name. Wandb will only happen if this is set..", +) +parser.add_argument( + "--cosine-rampup-frac", + type=float, + required=False, + default=0.01, + help="Fraction of steps in which to ramp up the learning rate. Default is 0.01.", +) +parser.add_argument( + "--cosine-hold-frac", + type=float, + required=False, + default=0.05, + help="Fraction of final steps in which to hold the minimum LR. Default is 0.05.", +) - parser.add_argument( - "--num-gpus", - type=int, - required=False, - default=1, - help="Number of GPUs to use for training. Default is 1.", - ) - parser.add_argument( - "--num-nodes", - type=int, - required=False, - default=1, - help="Number of nodes to use for training. Default is 1.", - ) - parser.add_argument( - "--num-steps", - type=int, - required=False, - default=10000, - help="Number of steps to use for training. Default is 10000.", - ) - parser.add_argument( - "--num-dataset-workers", - type=int, - required=False, - default=0, - help="Number of steps to use for training. Default is 0.", - ) - parser.add_argument( - "--val-check-interval", - type=int, - required=False, - default=10000, - help="Number of steps to use for training. Default is 10000.", - ) - parser.add_argument( - "--seq-length", - type=int, - required=False, - default=2048, - help="Sequence length of cell. 
Default is 2048.", - ) - parser.add_argument( - "--limit-val-batches", - type=int, - required=False, - default=2, - help="Number of steps to use for training. Default is 2.", - ) - parser.add_argument( - "--micro-batch-size", - type=int, - required=False, - default=64, - help="Micro-batch size. Global batch size is inferred from this.", - ) - parser.add_argument( - "--biobert-spec-option", - type=BiobertSpecOption, - choices=[e.value for e in BiobertSpecOption], - required=False, - default=BiobertSpecOption.bert_layer_local_spec.value, - help="Biobert spec option to use for the model. Default is 'bert_layer_local_spec'.", - ) - parser.add_argument( - "--nemo1-init-path", - type=Path, - required=False, - help="Path to nemo1 file, if desired to load at init time.", - ) - parser.add_argument( - "--save-best-checkpoint", - action="store_true", - default=True, - help="Save the best checkpoint based on the metric to monitor.", - ) - parser.add_argument( - "--save-last-checkpoint", - action="store_true", - default=True, - help="Save the last checkpoint.", - ) - parser.add_argument( - "--metric-to-monitor-for-checkpoints", - type=str, - required=False, - default="val_loss", - help="The metric to monitor for checkpointing.", - ) - parser.add_argument( - "--save-top-k", - type=int, - required=False, - default=2, - help="Save the top k checkpoints.", - ) - parser.add_argument( - "--restore-from-checkpoint-path", - type=Path, - required=False, - default=None, - help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", - ) +parser.add_argument( + "--num-gpus", + type=int, + required=False, + default=1, + help="Number of GPUs to use for training. Default is 1.", +) +parser.add_argument( + "--num-nodes", + type=int, + required=False, + default=1, + help="Number of nodes to use for training. Default is 1.", +) +parser.add_argument( + "--num-steps", + type=int, + required=False, + default=10000, + help="Number of steps to use for training. Default is 10000.", +) +parser.add_argument( + "--num-dataset-workers", + type=int, + required=False, + default=0, + help="Number of steps to use for training. Default is 0.", +) +parser.add_argument( + "--val-check-interval", + type=int, + required=False, + default=10000, + help="Number of steps to use for training. Default is 10000.", +) +parser.add_argument( + "--seq-length", + type=int, + required=False, + default=2048, + help="Sequence length of cell. Default is 2048.", +) +parser.add_argument( + "--limit-val-batches", + type=float_or_int_or_none, + required=False, + default=2, + help="Number of global batches used for validation if int. Fraction of validation dataset if float. Default is 2.", +) +parser.add_argument( + "--micro-batch-size", + type=int, + required=False, + default=64, + help="Micro-batch size. Global batch size is inferred from this.", +) +parser.add_argument( + "--accumulate-grad-batches", + type=int, + required=False, + default=1, + help="Gradient accumulation steps. Global batch size is inferred from this.", +) +parser.add_argument( + "--biobert-spec-option", + type=BiobertSpecOption, + choices=[e.value for e in BiobertSpecOption], + required=False, + default=BiobertSpecOption.bert_layer_local_spec.value, + help="Biobert spec option to use for the model. 
Default is 'bert_layer_local_spec'.", +) +parser.add_argument( + "--nemo1-init-path", + type=Path, + required=False, + help="Path to nemo1 file, if desired to load at init time.", +) +parser.add_argument( + "--save-best-checkpoint", + action="store_true", + default=True, + help="Save the best checkpoint based on the metric to monitor.", +) +parser.add_argument( + "--save-last-checkpoint", + action="store_true", + default=True, + help="Save the last checkpoint.", +) +parser.add_argument( + "--metric-to-monitor-for-checkpoints", + type=str, + required=False, + default="val_loss", + help="The metric to monitor for checkpointing.", +) +parser.add_argument( + "--save-top-k", + type=int, + required=False, + default=2, + help="Save the top k checkpoints.", +) +parser.add_argument( + "--restore-from-checkpoint-path", + type=Path, + required=False, + default=None, + help="Path to the checkpoint directory to restore from. Will override `--resume-if-exists` when set.", +) +if __name__ == "__main__": # Parse the arguments and pull them out into local variables for ease of future refactor to a # config management system. args = parser.parse_args() @@ -447,6 +466,7 @@ def main( biobert_spec_option=args.biobert_spec_option, lr=args.lr, micro_batch_size=args.micro_batch_size, + accumulate_grad_batches=args.accumulate_grad_batches, cosine_rampup_frac=args.cosine_rampup_frac, cosine_hold_frac=args.cosine_hold_frac, precision=args.precision, diff --git a/scripts/singlecell/geneformer/test_pretrain.py b/scripts/singlecell/geneformer/test_pretrain.py index cd188d7ce..8eb485e04 100644 --- a/scripts/singlecell/geneformer/test_pretrain.py +++ b/scripts/singlecell/geneformer/test_pretrain.py @@ -17,12 +17,16 @@ import shlex import subprocess from pathlib import Path +from typing import Dict +import pytest from lightning.fabric.plugins.environments.lightning import find_free_network_port -from pretrain import main # TODO: needs to be refactored to a package and imported! +from pretrain import main, parser # TODO: needs to be refactored to a package and imported! 
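+# `parse_kwargs_to_arglist` (imported below) is assumed to expand a kwargs dict such as
+# {"limit_val_batches": 0.5} into ["--limit-val-batches", "0.5"]; this is how the parser
+# fixtures further down feed `parser.parse_args` in these tests.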
from bionemo import geneformer from bionemo.llm.model.biobert.transformer_specs import BiobertSpecOption +from bionemo.llm.utils.datamodule_utils import parse_kwargs_to_arglist +from bionemo.testing import megatron_parallel_state_utils # TODO(@jstjohn) use fixtures for pulling down data and checkpoints @@ -35,7 +39,6 @@ # From here, we want to get to the root of the repository: _before_ sub-packages/ .parent.parent ).absolute() -assert bionemo2_root != Path("/") data_path: Path = bionemo2_root / "test_data/cellxgene_2023-12-15_small/processed_data" @@ -50,31 +53,34 @@ def test_bionemo2_rootdir(): assert data_path.is_dir(), f"Test data directory is supposed to be a directory.\n{data_error_str}" +@pytest.mark.skip("duplicate unittest") def test_main_runs(tmpdir): result_dir = Path(tmpdir.mkdir("results")) - main( - data_dir=data_path, - num_nodes=1, - devices=1, - seq_length=128, - result_dir=result_dir, - wandb_project=None, - wandb_offline=True, - num_steps=55, - limit_val_batches=1, - val_check_interval=1, - num_dataset_workers=0, - biobert_spec_option=BiobertSpecOption.bert_layer_local_spec, - lr=1e-4, - micro_batch_size=2, - cosine_rampup_frac=0.01, - cosine_hold_frac=0.01, - precision="bf16-mixed", - experiment_name="test_experiment", - resume_if_exists=False, - create_tensorboard_logger=False, - ) + with megatron_parallel_state_utils.distributed_model_parallel_state(): + main( + data_dir=data_path, + num_nodes=1, + devices=1, + seq_length=128, + result_dir=result_dir, + wandb_project=None, + wandb_offline=True, + num_steps=55, + limit_val_batches=1, + val_check_interval=1, + num_dataset_workers=0, + biobert_spec_option=BiobertSpecOption.bert_layer_local_spec, + lr=1e-4, + micro_batch_size=2, + accumulate_grad_batches=2, + cosine_rampup_frac=0.01, + cosine_hold_frac=0.01, + precision="bf16-mixed", + experiment_name="test_experiment", + resume_if_exists=False, + create_tensorboard_logger=False, + ) assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." assert (result_dir / "test_experiment").is_dir(), "Test experiment directory is supposed to be a directory." @@ -92,12 +98,13 @@ def test_main_runs(tmpdir): ).is_file(), "Could not find experiment log." +@pytest.mark.skip("duplicate unittest") def test_pretrain_cli(tmpdir): result_dir = Path(tmpdir.mkdir("results")) open_port = find_free_network_port() # NOTE: if you need to change the following command, please update the README.md example. cmd_str = f"""python \ - scripts/singlecell/geneformer/pretrain.py \ + {bionemo2_root}/scripts/singlecell/geneformer/pretrain.py \ --data-dir {data_path} \ --result-dir {result_dir} \ --experiment-name test_experiment \ @@ -108,16 +115,103 @@ def test_pretrain_cli(tmpdir): --num-steps 55 \ --seq-length 128 \ --limit-val-batches 2 \ - --micro-batch-size 2 + --micro-batch-size 2 \ + --accumulate-grad-batches 2 """.strip() env = dict(**os.environ) # a local copy of the environment env["MASTER_PORT"] = str(open_port) cmd = shlex.split(cmd_str) result = subprocess.run( cmd, - cwd=bionemo2_root, + cwd=tmpdir, env=env, capture_output=True, ) assert result.returncode == 0, f"Pretrain script failed: {cmd_str}" assert (result_dir / "test_experiment").exists(), "Could not find test experiment directory." + + +@pytest.fixture(scope="function") +def required_args_reference() -> Dict[str, str]: + """ + This fixture provides a dictionary of required arguments for the pretraining script. + + It includes the following keys: + - data_dir: The path to the data directory. 
+ + Returns: + A dictionary with the required arguments for the pretraining script. + """ + return { + "data_dir": "path/to/cellxgene_2023-12-15_small", + } + + +def test_required_data_dir(required_args_reference): + """ + Test data_dir is required. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference.pop("data_dir") + arglist = parse_kwargs_to_arglist(required_args_reference) + with pytest.raises(SystemExit): + parser.parse_args(arglist) + + +#### test expected behavior on parser #### +@pytest.mark.parametrize("limit_val_batches", [0.1, 0.5, 1.0]) +def test_limit_val_batches_is_float(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as a float. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (float): The value of limit_val_batches. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) + + +@pytest.mark.parametrize("limit_val_batches", ["0.1", "0.5", "1.0"]) +def test_limit_val_batches_is_float_string(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as a string of float. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (float): The value of limit_val_batches. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) + + +@pytest.mark.parametrize("limit_val_batches", [None, "None"]) +def test_limit_val_batches_is_none(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as none. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + args = parser.parse_args(arglist) + assert args.limit_val_batches is None + + +@pytest.mark.parametrize("limit_val_batches", [1, 2]) +def test_limit_val_batches_is_int(required_args_reference, limit_val_batches): + """ + Test whether limit_val_batches can be parsed as integer. + + Args: + required_args_reference (Dict[str, str]): A dictionary with the required arguments for the pretraining script. + limit_val_batches (int): The value of limit_val_batches. + """ + required_args_reference["limit_val_batches"] = limit_val_batches + arglist = parse_kwargs_to_arglist(required_args_reference) + parser.parse_args(arglist) diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py b/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py index e9848eb9d..490a8867a 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py @@ -21,10 +21,10 @@ from torch.utils.data import Dataset -class PRNGDatasetShuffler(Dataset): +class PRNGResampleDataset(Dataset): """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset. - PRNGDatasetShuffler shuffles a given dataset using a pseudo-random number generator (PRNG). 
This allows for + PRNGResampleDataset shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. It works by generating random indices assuming that the requesting function asks for them sequentially. Although random lookups are supported, random lookups will involve recomputing state which is slow, and involves linearly advancing @@ -34,7 +34,7 @@ class PRNGDatasetShuffler(Dataset): """ def __init__(self, dataset: Dataset, seed: int = 42, num_samples: Optional[int] = None): - """Initializes the PRNGDatasetShuffler. + """Initializes the PRNGResampleDataset. Args: dataset (Dataset): The dataset to be shuffled. @@ -105,4 +105,4 @@ def __len__(self) -> int: return self.num_samples -__all__ = ["PRNGDatasetShuffler"] +__all__ = ["PRNGResampleDataset"] diff --git a/sub-packages/bionemo-core/tests/bionemo/pytorch/data/test_resamplers.py b/sub-packages/bionemo-core/tests/bionemo/pytorch/data/test_resamplers.py index 534efef3d..06348f7f3 100644 --- a/sub-packages/bionemo-core/tests/bionemo/pytorch/data/test_resamplers.py +++ b/sub-packages/bionemo-core/tests/bionemo/pytorch/data/test_resamplers.py @@ -18,14 +18,14 @@ import pytest -from bionemo.core.data.resamplers import PRNGDatasetShuffler +from bionemo.core.data.resamplers import PRNGResampleDataset def test_prng_dataset_sequential_shuffler_full(): - """Test that the PRNGDatasetShuffler returns the same results as a random number generator when accessed sequentially.""" + """Test that the PRNGResampleDataset returns the same results as a random number generator when accessed sequentially.""" dataset = list(range(10)) seed = 42 - shuffled_dataset = PRNGDatasetShuffler(dataset, seed=seed, num_samples=100) + shuffled_dataset = PRNGResampleDataset(dataset, seed=seed, num_samples=100) rng = random.Random(seed) expected_output_full = [rng.randint(0, 9) for _ in range(100)] full_output = [shuffled_dataset[i] for i in range(100)] @@ -34,12 +34,12 @@ def test_prng_dataset_sequential_shuffler_full(): @pytest.mark.parametrize("modulo_remainder", [0, 1, 4, 9]) def test_prng_dataset_sequential_shuffler_skips(modulo_remainder: int): - """Test that the PRNGDatasetShuffler returns the same results as a random number generator when accessed sequentially but with + """Test that the PRNGResampleDataset returns the same results as a random number generator when accessed sequentially but with some indices skipped, as would happen in a parallel dataloader context. """ dataset = list(range(100)) seed = 42 - shuffled_dataset = PRNGDatasetShuffler(dataset, seed=seed, num_samples=1000) + shuffled_dataset = PRNGResampleDataset(dataset, seed=seed, num_samples=1000) rng = random.Random(seed) expected_output_full = [rng.randint(0, 99) for _ in range(1000)] every_10th_output = [shuffled_dataset[i] for i in range(1000) if i % 10 == modulo_remainder] @@ -48,13 +48,13 @@ def test_prng_dataset_sequential_shuffler_skips(modulo_remainder: int): @pytest.mark.parametrize("modulo_remainder", [0, 1]) def test_prng_dataset_random_shuffler_skips(modulo_remainder: int): - """Test that the PRNGDatasetShuffler returns the same results as a random number generator when accessed in a random order + """Test that the PRNGResampleDataset returns the same results as a random number generator when accessed in a random order and with some indices skipped as well. 
This is what would happen if a user did an unexpected thing and called this on a dataset in a random order. This is expected to be slower but we still want it to work. """ dataset = list(range(100)) seed = 42 - shuffled_dataset = PRNGDatasetShuffler(dataset, seed=seed, num_samples=1000) + shuffled_dataset = PRNGResampleDataset(dataset, seed=seed, num_samples=1000) rng = random.Random(seed) expected_output_full = [rng.randint(0, 99) for _ in range(1000)] indices_to_check = [i for i in range(1000) if i % 10 == modulo_remainder] @@ -69,7 +69,7 @@ def test_repeated_lookups(): """Test that repeated lookups of the same index return the same value.""" dataset = list(range(100)) seed = 42 - shuffled_dataset = PRNGDatasetShuffler(dataset, seed=seed, num_samples=1000) + shuffled_dataset = PRNGResampleDataset(dataset, seed=seed, num_samples=1000) rng = random.Random(seed) expected_output_full = [rng.randint(0, 99) for _ in range(1000)] indices_to_check = [i for i in range(1000) if i % 10 == 3] diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py index 009bb8901..5e4b99e63 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/datamodule.py @@ -25,9 +25,11 @@ from nemo.utils import logging from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from bionemo.core.data.resamplers import PRNGResampleDataset from bionemo.core.utils import random_utils from bionemo.esm2.data import dataset, tokenizer from bionemo.llm.data import collate +from bionemo.llm.utils.datamodule_utils import infer_num_samples class ESMDataModule(pl.LightningDataModule): @@ -51,7 +53,7 @@ def __init__( mask_prob: float = 0.15, mask_token_prob: float = 0.8, mask_random_prob: float = 0.1, - tokenizer: tokenizer.HFTokenizer = tokenizer.get_tokenizer(), + tokenizer: tokenizer.BioNeMoAutoTokenizer = tokenizer.get_tokenizer(), ) -> None: """Initialize the ESMDataModule. @@ -97,7 +99,7 @@ def __init__( seq_len=max_seq_length, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, - dataloader_type="cyclic", # This should attach a `MegatronPretrainingRandomSampler`. + dataloader_type="single", # `MegatronPretrainingRandomSampler` from "cyclic" is failing. 
rampup_batch_size=rampup_batch_size, ) @@ -126,15 +128,11 @@ def setup(self, stage: str = "") -> None: if max_train_steps <= 0: raise RuntimeError("Please specify trainer.max_steps") - eval_iters = int((max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches) - num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size) - num_val_samples = int(eval_iters * self.data_sampler.global_batch_size) - - if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - # This is to make sure we only have one epoch on every validation iteration - num_val_samples = 1 - - self._train_ds = dataset.create_train_dataset( + # Create training dataset + num_train_samples = int( + max_train_steps * self.data_sampler.global_batch_size + ) # training data requires upsampling (multiply by max_train_steps) on single MegatronPretrainingRandomSampler + _train_ds = dataset.create_train_dataset( cluster_file=self._train_cluster_path, db_path=self._train_database_path, total_samples=num_train_samples, @@ -145,9 +143,20 @@ def setup(self, stage: str = "") -> None: mask_random_prob=self._mask_random_prob, tokenizer=self._tokenizer, ) - - self._valid_ds = dataset.create_valid_dataset( - cluster_file=self._valid_cluster_path, + self._train_ds = self._sample_and_shuffle_dataset( + _train_ds, None, "train" + ) # shuffle manually without cyclic MegatronPretrainingRandomSampler + + # Create validation dataset + val_clusters = dataset.create_valid_clusters(self._valid_cluster_path) + num_val_samples = infer_num_samples( + limit_batches=self.trainer.limit_val_batches, + num_samples_in_dataset=len(val_clusters), + global_batch_size=self.data_sampler.global_batch_size, + stage="val", + ) + _valid_ds = dataset.create_valid_dataset( + clusters=self._valid_cluster_path, db_path=self._valid_database_path, total_samples=num_val_samples, seed=random_utils.get_seed_from_rng(rng), @@ -157,6 +166,9 @@ def setup(self, stage: str = "") -> None: mask_random_prob=self._mask_random_prob, tokenizer=self._tokenizer, ) + self._valid_ds = self._sample_and_shuffle_dataset( + _valid_ds, None, "val" + ) # shuffle manually without cyclic MegatronPretrainingRandomSampler assert ( hasattr(self, "trainer") and self.trainer is not None @@ -190,3 +202,20 @@ def val_dataloader(self) -> EVAL_DATALOADERS: def test_dataloader(self) -> EVAL_DATALOADERS: """Raises a not implemented error.""" raise NotImplementedError("No test dataset provided for ESM2") + + def _sample_and_shuffle_dataset(self, dataset: dataset.ESMMaskedResidueDataset, num_samples: int, stage: str): # noqa: D417 + """Sample the training dataset. + + Args: + dataset (torch.utils.data.Dataset): The dataset to sample from + + Returns: + ResamplingMappedDataset: Resampled dataset + + """ + # This is where re-sampling occurs. + return PRNGResampleDataset( + dataset, + num_samples=num_samples, + seed=self._seed + len(stage), + ) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/dataset.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/dataset.py index 20823d206..7744af5d7 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/dataset.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/dataset.py @@ -40,6 +40,7 @@ def __init__(self, db_path: str | os.PathLike): """ self.conn = sqlite3.connect(str(db_path)) self.cursor = self.conn.cursor() + self._len = None def __len__(self) -> int: """Returns the number of proteins in the dataset. 
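The hunk above adds a `_len` cache slot on `ProteinSQLiteDataset`, and the hunk below fills it
lazily, so `len(dataset)` only ever issues one `SELECT COUNT(*)` query. A minimal sketch of the
cached-count pattern, with illustrative names:

    import sqlite3

    class CachedCountDataset:
        """Toy stand-in for the lazy COUNT(*) caching shown in this diff."""

        def __init__(self, db_path: str):
            self.conn = sqlite3.connect(db_path)
            self.cursor = self.conn.cursor()
            self._len = None  # populated on the first len() call

        def __len__(self) -> int:
            if self._len is None:  # hit the database only once
                self.cursor.execute("SELECT COUNT(*) FROM protein")
                self._len = int(self.cursor.fetchone()[0])
            return self._len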
@@ -47,8 +48,10 @@ def __len__(self) -> int: Returns: Number of proteins in the dataset. """ - self.cursor.execute("SELECT COUNT(*) FROM protein") - return int(self.cursor.fetchone()[0]) + if self._len is None: + self.cursor.execute("SELECT COUNT(*) FROM protein") + self._len = int(self.cursor.fetchone()[0]) + return self._len def __getitem__(self, idx: str) -> str: """Returns the sequence of a protein at a given index. @@ -87,7 +90,7 @@ def __init__( mask_prob: float = 0.15, mask_token_prob: float = 0.8, mask_random_prob: float = 0.1, - tokenizer: tokenizer.HFTokenizer = tokenizer.get_tokenizer(), + tokenizer: tokenizer.BioNeMoAutoTokenizer = tokenizer.get_tokenizer(), ) -> None: """Initializes the dataset. @@ -202,7 +205,7 @@ def create_train_dataset( mask_prob: float = 0.15, mask_token_prob: float = 0.8, mask_random_prob: float = 0.1, - tokenizer: tokenizer.HFTokenizer = tokenizer.get_tokenizer(), + tokenizer: tokenizer.BioNeMoAutoTokenizer = tokenizer.get_tokenizer(), ): """Creates a training dataset for ESM pretraining. @@ -249,8 +252,30 @@ def create_train_dataset( ) -def create_valid_dataset( - cluster_file: str | os.PathLike, +def create_valid_clusters(cluster_file: str | os.PathLike) -> pd.Series: + """Create a pandas series of UniRef50 cluster IDs from a cluster parquet file. + + Args: + cluster_file: Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 + IDs, with one UniRef50 ID per row. + + Returns: + A pandas series of UniRef50 cluster IDs. + """ + if not Path(cluster_file).exists(): + raise ValueError(f"Cluster file {cluster_file} not found.") + + cluster_df = pd.read_parquet(cluster_file) + if "ur50_id" not in cluster_df.columns: + raise ValueError( + f"Validation cluster file must contain a 'ur50_id' column. Found columns {cluster_df.columns}." + ) + clusters = cluster_df["ur50_id"].apply(lambda x: [x]) + return clusters + + +def create_valid_dataset( # noqa: D417 + clusters: pd.Series | str | os.PathLike, db_path: str | os.PathLike, total_samples: int, seed: int, @@ -258,12 +283,12 @@ def create_valid_dataset( mask_prob: float = 0.15, mask_token_prob: float = 0.8, mask_random_prob: float = 0.1, - tokenizer: tokenizer.HFTokenizer = tokenizer.get_tokenizer(), + tokenizer: tokenizer.BioNeMoAutoTokenizer = tokenizer.get_tokenizer(), ): """Creates a validation dataset for ESM pretraining. Args: - cluster_file: Path to the cluster file. The file should contain a single column named "ur50_id" with UniRef50 + clusters: A pd.Series of clusters, or a path to the cluster parquet file. The file should contain a single column named "ur50_id" with UniRef50 IDs, with one UniRef50 ID per row. db_path: Path to the SQLite database. total_samples: Total number of samples to draw from the dataset. @@ -281,22 +306,17 @@ def create_valid_dataset( ValueError: If the cluster file does not exist, the database file does not exist, or the cluster file does not contain a "ur50_id" column. """ - if not Path(cluster_file).exists(): - raise ValueError(f"Cluster file {cluster_file} not found.") + if isinstance(clusters, (str, os.PathLike)): + clusters = create_valid_clusters(clusters) + elif not isinstance(clusters, pd.Series): + raise ValueError(f"Clusters must be a pandas Series.
Got {type(clusters)}.") if not Path(db_path).exists(): raise ValueError(f"Database file {db_path} not found.") protein_dataset = ProteinSQLiteDataset(db_path) - cluster_df = pd.read_parquet(cluster_file) - if "ur50_id" not in cluster_df.columns: - raise ValueError( - f"Validation cluster file must contain a 'ur50_id' column. Found columns {cluster_df.columns}." - ) - # Create a single bucket for each UniRef50 cluster. - clusters = cluster_df["ur50_id"].apply(lambda x: [x]) return ESMMaskedResidueDataset( protein_dataset=protein_dataset, clusters=clusters, @@ -315,7 +335,7 @@ def create_valid_dataset( def _random_crop(s: _T, crop_length: int, rng: np.random.Generator) -> _T: """Randomly crops a input to a maximum length.""" - if crop_length > len(s): + if crop_length >= len(s): return s start_index = rng.integers(0, len(s) - crop_length) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/tokenizer/__init__.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/tokenizer/__init__.py index 07059d52c..8e2fbd4e7 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/data/tokenizer/__init__.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/data/tokenizer/__init__.py @@ -13,17 +13,29 @@ # See the License for the specific language governing permissions and # limitations under the License. - import functools from pathlib import Path import transformers +from nemo.lightning.io import IOMixin + +class BioNeMoAutoTokenizer(transformers.AutoTokenizer, IOMixin): # noqa D101 + def __init__(self, pretrained_model_name, use_fast=True): + """A wrapper to make AutoTokenizer serializable. -HFTokenizer = transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast + Args: + pretrained_model_name: A string, the *model id* of a predefined tokenizer hosted on huggingface + use_fast: Use a [fast Rust-based tokenizer](https://huggingface.co/docs/tokenizers/index) + if it is supported for a given model. Defaults to True. 
+ """ + other = self.from_pretrained(pretrained_model_name, use_fast=use_fast) + for attr in dir(other): + if not attr.startswith("_"): + setattr(self, attr, getattr(other, attr)) @functools.cache -def get_tokenizer() -> HFTokenizer: +def get_tokenizer() -> BioNeMoAutoTokenizer: """Get the tokenizer for the ESM2 model.""" - return transformers.AutoTokenizer.from_pretrained(Path(__file__).parent.resolve(), use_fast=True) + return BioNeMoAutoTokenizer(Path(__file__).parent.resolve().as_posix(), use_fast=True) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py index 794a31876..ef1eb23b3 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py @@ -32,13 +32,14 @@ from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.transformer.utils import get_linear_layer -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.lightning import get_vocab_size +from nemo.lightning.io import IOMixin from nemo.lightning.megatron_parallel import MegatronLossReduction from torch import Tensor from torch.optim import Optimizer from bionemo.core.model.config import BionemoModelConfig +from bionemo.esm2.data.tokenizer import BioNeMoAutoTokenizer from bionemo.esm2.model.attention import ESM2DotProductAttention from bionemo.esm2.model.embedding import ESM2Embedding from bionemo.llm.model.biobert.model import MegatronBioBertModel @@ -63,7 +64,7 @@ def __init__( transformer_layer_spec: spec_utils.ModuleSpec, vocab_size: int, max_sequence_length: int, - tokenizer: AutoTokenizer = None, + tokenizer: Optional[BioNeMoAutoTokenizer] = None, pre_process: bool = True, post_process: bool = True, fp16_lm_cross_entropy: bool = False, @@ -136,7 +137,7 @@ def __init__( # ESM2 NEW ARGS token_dropout=self.config.token_dropout, use_attention_mask=self.config.use_attention_mask, - mask_token_id=tokenizer.mask_id, + mask_token_id=tokenizer.mask_token_id, ) if self.position_embedding_type == "rope": @@ -220,7 +221,7 @@ def esm_gelu_func(x: Tensor) -> Tensor: @dataclass -class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig): +class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig, IOMixin): """Configuration class for ESM2 model. 
Attributes: diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/conftest.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py similarity index 70% rename from sub-packages/bionemo-esm2/tests/bionemo/esm2/data/conftest.py rename to sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py index 90d97f3a8..a5d332758 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/conftest.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/conftest.py @@ -16,6 +16,7 @@ import sqlite3 +import pandas as pd import pytest from bionemo.esm2.data.tokenizer import get_tokenizer @@ -56,3 +57,24 @@ def dummy_protein_dataset(tmp_path): conn.close() return db_file + + +@pytest.fixture +def dummy_parquet_train_val_inputs(tmp_path): + """Create a mock protein train and val cluster parquet.""" + train_cluster_path = tmp_path / "train_clusters.parquet" + train_clusters = pd.DataFrame( + { + "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], + } + ) + train_clusters.to_parquet(train_cluster_path) + + valid_cluster_path = tmp_path / "valid_clusters.parquet" + valid_clusters = pd.DataFrame( + { + "ur50_id": ["UniRef50_A", "UniRef50_B"], + } + ) + valid_clusters.to_parquet(valid_cluster_path) + return train_cluster_path, valid_cluster_path diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_datamodule.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_datamodule.py index b15fd59fc..5ea654107 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_datamodule.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_datamodule.py @@ -15,39 +15,21 @@ from unittest import mock -import pandas as pd import pytest import torch.utils.data from bionemo.esm2.data.datamodule import ESMDataModule +from bionemo.llm.utils.datamodule_utils import tensor_dict_hash -def _create_dummy_parquet_inputs(tmp_path): - train_cluster_path = tmp_path / "train_clusters.parquet" - train_clusters = pd.DataFrame( - { - "ur90_id": [["UniRef90_A"], ["UniRef90_B", "UniRef90_C"]], - } - ) - train_clusters.to_parquet(train_cluster_path) - - valid_cluster_path = tmp_path / "valid_clusters.parquet" - valid_clusters = pd.DataFrame( - { - "ur50_id": ["UniRef50_A", "UniRef50_B"], - } - ) - valid_clusters.to_parquet(valid_cluster_path) - - -def test_create_esm_datamodule_raises_without_trainer(tmp_path, dummy_protein_dataset): - _create_dummy_parquet_inputs(tmp_path) +def test_create_esm_datamodule_raises_without_trainer(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs # Initialize the data module. data_module = ESMDataModule( - train_cluster_path=tmp_path / "train_clusters.parquet", + train_cluster_path=train_cluster_path, train_database_path=dummy_protein_dataset, - valid_cluster_path=tmp_path / "valid_clusters.parquet", + valid_cluster_path=valid_cluster_path, valid_database_path=dummy_protein_dataset, ) assert data_module is not None @@ -56,14 +38,14 @@ def test_create_esm_datamodule_raises_without_trainer(tmp_path, dummy_protein_da data_module.setup() -def test_create_esm_datamodule_raises_without_trainer_max_steps(tmp_path, dummy_protein_dataset): - _create_dummy_parquet_inputs(tmp_path) +def test_create_esm_datamodule_raises_without_trainer_max_steps(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs # Initialize the data module. 
data_module = ESMDataModule( - train_cluster_path=tmp_path / "train_clusters.parquet", + train_cluster_path=train_cluster_path, train_database_path=dummy_protein_dataset, - valid_cluster_path=tmp_path / "valid_clusters.parquet", + valid_cluster_path=valid_cluster_path, valid_database_path=dummy_protein_dataset, ) assert data_module is not None @@ -76,14 +58,14 @@ def test_create_esm_datamodule_raises_without_trainer_max_steps(tmp_path, dummy_ data_module.setup() -def test_create_esm_datamodule_creates_valid_dataloaders(tmp_path, dummy_protein_dataset): - _create_dummy_parquet_inputs(tmp_path) +def test_create_esm_datamodule_creates_valid_dataloaders(dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs # Initialize the data module. data_module = ESMDataModule( - train_cluster_path=tmp_path / "train_clusters.parquet", + train_cluster_path=train_cluster_path, train_database_path=dummy_protein_dataset, - valid_cluster_path=tmp_path / "valid_clusters.parquet", + valid_cluster_path=valid_cluster_path, valid_database_path=dummy_protein_dataset, global_batch_size=8, micro_batch_size=4, @@ -107,7 +89,7 @@ def test_create_esm_datamodule_creates_valid_dataloaders(tmp_path, dummy_protein assert isinstance(val_dataloader, torch.utils.data.DataLoader) assert len(train_dataloader) == 10 * 8 # max steps * global batch size - assert len(val_dataloader) == (10 // 2 + 1) * 8 # number of eval steps * global batch size + assert len(val_dataloader) == 8 # global batch size; index reset every val epoch for batch in train_dataloader: assert isinstance(batch, dict) @@ -124,3 +106,272 @@ def test_create_esm_datamodule_creates_valid_dataloaders(tmp_path, dummy_protein assert isinstance(batch["labels"], torch.Tensor) assert isinstance(batch["loss_mask"], torch.Tensor) assert isinstance(batch["is_random"], torch.Tensor) + + +def test_create_esm_datamodule_creates_valid_dataloaders_with_fractional_limit_val_batches( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.5 # fractional value + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == int(2 * 0.5) // 1 # number of validation clusters // global batch size + + +def test_create_esm_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_smaller_than_global_batch_size( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. 
+ data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.5 # fractional value + + # num_val_cluster * limit_val_batches = 2 * 0.5 = 1 < global_batch_size + with pytest.raises(ValueError, match="The limited number of val samples 1 is less than the global batch size 8"): + data_module.setup() + + +@pytest.mark.parametrize("limit_val_batches", [0, 0.0]) +def test_create_esm_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_0( + dummy_protein_dataset, dummy_parquet_train_val_inputs, limit_val_batches +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = limit_val_batches + + with pytest.raises(ValueError, match="Invalid choice of limit_val_batches size: %s" % limit_val_batches): + data_module.setup() + + +def test_create_esm_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_not_multiple_of_global_batch_size( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. + data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 0.7 # fractional value + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == int(2 * 0.7) // 1 # number of validation clusters // global batch size + + +def test_create_esm_datamodule_creates_valid_dataloaders_fractional_limit_val_batches_1p0( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module. 
+ data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 1.0 # fractional value to use the whole dataset + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + assert len(train_dataloader) == 10 * 1 # max steps * global batch size + assert len(val_dataloader) == 2 // 1 # number of validation clusters // global batch size + + +def test_create_esm_datamodule_limit_val_batches_none_equals_limit_val_batches_1p0( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + # Initialize the data module with limit_val_batches = 1.0 + data_module_one = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module_one is not None + + data_module_one.trainer = mock.Mock() + data_module_one.trainer.max_epochs = 1 + data_module_one.trainer.max_steps = 10 + data_module_one.trainer.val_check_interval = 2 + data_module_one.trainer.limit_val_batches = 1.0 # fractional value to use the whole dataset + + data_module_one.setup() + + # Initialize the data module with limit_val_batches = None + data_module_none = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=1, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module_none is not None + + data_module_none.trainer = mock.Mock() + data_module_none.trainer.max_epochs = 1 + data_module_none.trainer.max_steps = 10 + data_module_none.trainer.val_check_interval = 2 + data_module_none.trainer.limit_val_batches = None # None to use the whole dataset + + data_module_none.setup() + + # Check that the two dataloaders have the same number of samples. + assert len(data_module_one.val_dataloader()) == len(data_module_none.val_dataloader()) + + +def test_create_esm_datamodule_valid_dataloaders_has_consistent_samples_per_epoch( + dummy_protein_dataset, dummy_parquet_train_val_inputs +): + """ + Test that the ESMDataModule dataloaders produce consistent samples per epoch. + + This test ensures that the ESMDataModule creates dataloaders that produce consistent + samples across epochs, even if the data is reshuffled (controlled by `is_ordered`). + + Parameters: + - dummy_protein_dataset: A dummy protein dataset used for testing. + - dummy_parquet_train_val_inputs: A tuple containing paths to dummy parquet files + for training and validation clusters. 
+ """ + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + micro_batch_size = 2 + is_ordered = False # allow random sampling to be independent between epoches + + # Initialize the data module. + data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=1, + micro_batch_size=micro_batch_size, + min_seq_length=36, + max_seq_length=36, + ) + assert data_module is not None + + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 1 + data_module.trainer.val_check_interval = 1 + data_module.trainer.limit_val_batches = 1.0 # use the whole validation dataset + + data_module.setup() + + # hash values from batches of the first epoch + batch_hashes1 = [tensor_dict_hash(batch) for batch in data_module.val_dataloader()] + + if is_ordered: # second epoch should have exactly the same output including order + for batch in data_module.val_dataloader(): + batch_hash = tensor_dict_hash(batch) + assert batch_hash == batch_hashes1.pop() + else: # second epoch should have the same output but can be reshuffled + batch_hashes1 = set(batch_hashes1) + batch_hashes2 = {tensor_dict_hash(batch) for batch in data_module.val_dataloader()} + assert batch_hashes1 == batch_hashes2 diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_dataset.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_dataset.py index 6bcfe9985..982affe84 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_dataset.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/data/test_dataset.py @@ -73,12 +73,12 @@ def test_ESMPreTrainingDataset_getitem_match_for_identical_seeds(dummy_protein_d dataset2 = ESMMaskedResidueDataset(protein_dataset=dataset, clusters=clusters, total_samples=10, seed=123) # Check that the datasets are equal. 
- for i in range(len(dataset)): + for i in range(len(dataset1)): sample1 = dataset1[i] sample2 = dataset2[i] for key in sample1: - assert torch.allclose(sample1[key], sample2[key]) + torch.testing.assert_close(sample1[key], sample2[key]) def test_ESMPreTrainingDataset_getitem_is_deterministic(dummy_protein_dataset): @@ -94,7 +94,7 @@ def test_ESMPreTrainingDataset_getitem_is_deterministic(dummy_protein_dataset): for _ in range(10): sample2 = dataset[8] for key in sample1: - assert torch.allclose(sample1[key], sample2[key]) + torch.testing.assert_close(sample1[key], sample2[key]) def test_ESMPreTrainingDataset_getitem_differs_with_different_seeds(dummy_protein_dataset): @@ -148,7 +148,7 @@ def test_ESMPreTrainingDataset_crops_out_start_and_end(dummy_protein_dataset, to clusters = [["UniRef90_A"]] dataset = ESMMaskedResidueDataset( - protein_dataset=prot_dataset, clusters=clusters, total_samples=10, seed=123, max_seq_length=1024 + protein_dataset=prot_dataset, clusters=clusters, seed=123, total_samples=10, max_seq_length=1024 ) assert len(dataset[0]["text"]) == len(prot_dataset["UniRef90_A"]) + 2 @@ -156,7 +156,7 @@ def test_ESMPreTrainingDataset_crops_out_start_and_end(dummy_protein_dataset, to assert dataset[0]["text"][-1] == tokenizer.eos_token_id dataset = ESMMaskedResidueDataset( - protein_dataset=prot_dataset, clusters=clusters, total_samples=10, seed=123, max_seq_length=3 + protein_dataset=prot_dataset, clusters=clusters, seed=123, total_samples=10, max_seq_length=3 ) assert len(dataset[0]["text"]) == 3 @@ -189,7 +189,9 @@ def test_create_train_dataset(dummy_protein_dataset, tmp_path): cluster_file.to_parquet(tmp_path / "train_clusters.parquet") - dataset = create_train_dataset(tmp_path / "train_clusters.parquet", dummy_protein_dataset, 10, 123) + dataset = create_train_dataset( + cluster_file=tmp_path / "train_clusters.parquet", db_path=dummy_protein_dataset, total_samples=10, seed=123 + ) assert len(dataset) == 10 dataset[6] # Make sure it doesn't crash. @@ -203,6 +205,8 @@ def test_create_valid_dataset(dummy_protein_dataset, tmp_path): cluster_file.to_parquet(tmp_path / "valid_clusters.parquet") - dataset = create_valid_dataset(tmp_path / "valid_clusters.parquet", dummy_protein_dataset, 10, 123) + dataset = create_valid_dataset( + clusters=tmp_path / "valid_clusters.parquet", db_path=dummy_protein_dataset, total_samples=10, seed=123 + ) assert len(dataset) == 10 dataset[6] # Make sure it doesn't crash. 
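For reference, a minimal sketch of how the reworked validation path in this patch is meant to compose, mirroring ESMDataModule.setup(): create_valid_clusters, then infer_num_samples, then create_valid_dataset, then PRNGResampleDataset. The parquet path, database path, and batch size below are placeholders, not values from the patch.

# Sketch only; file paths and sizes are illustrative placeholders.
from bionemo.core.data.resamplers import PRNGResampleDataset
from bionemo.esm2.data import dataset
from bionemo.llm.utils.datamodule_utils import infer_num_samples

seed = 42
# pd.Series where each element is a single-ID UniRef50 bucket ([ur50_id]).
val_clusters = dataset.create_valid_clusters("valid_clusters.parquet")
# Translate trainer.limit_val_batches (None is treated as 1.0) into a sample count.
num_val_samples = infer_num_samples(
    limit_batches=1.0,
    num_samples_in_dataset=len(val_clusters),
    global_batch_size=8,
    stage="val",
)
# create_valid_dataset now accepts either the pd.Series or the parquet path.
valid_ds = dataset.create_valid_dataset(
    clusters=val_clusters,
    db_path="proteins.db",
    total_samples=num_val_samples,
    seed=seed,
)
# Seeded resampling stands in for the shuffling the cyclic MegatronPretrainingRandomSampler used to do.
valid_ds = PRNGResampleDataset(valid_ds, num_samples=num_val_samples, seed=seed + len("val"))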
diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_embedding.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_embedding.py index 254b2f978..645d0b526 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_embedding.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_embedding.py @@ -16,17 +16,17 @@ import pytest import torch -from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from bionemo.esm2.api import ESM2Config +from bionemo.esm2.data.tokenizer import BioNeMoAutoTokenizer, get_tokenizer from bionemo.esm2.model.embedding import ESM2_MASK_RATIO_TRAIN, ESM2Embedding from bionemo.llm.lightning import get_dtype_device from bionemo.testing import megatron_parallel_state_utils @pytest.fixture(scope="module") -def tokenizer() -> AutoTokenizer: - yield AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") +def tokenizer() -> BioNeMoAutoTokenizer: + yield get_tokenizer() @pytest.fixture(scope="module") @@ -41,7 +41,7 @@ def test_init(embedding, tokenizer): assert isinstance(embedding, ESM2Embedding) assert embedding.token_dropout is True assert embedding.use_attention_mask is True - assert embedding.mask_token_id == tokenizer.mask_id + assert embedding.mask_token_id == tokenizer.mask_token_id def test_forward(embedding): diff --git a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py index 089a4caf0..b8d7c4184 100644 --- a/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py +++ b/sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py @@ -18,6 +18,7 @@ from copy import deepcopy from pathlib import Path from typing import List, Tuple +from unittest import mock import pytest import torch @@ -27,7 +28,10 @@ from bionemo import esm2 from bionemo.core.utils.dtypes import get_autocast_dtype +from bionemo.core.utils.random_utils import random_numpy_context from bionemo.esm2.api import ESM2Config, ESM2Model +from bionemo.esm2.data.datamodule import ESMDataModule +from bionemo.esm2.data.tokenizer import get_tokenizer from bionemo.esm2.model.embedding import ESM2Embedding from bionemo.llm.model.biobert.model import MegatronBioBertModel from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping @@ -82,7 +86,7 @@ def esm2_650M_config_w_ckpt() -> ESM2Config: @pytest.fixture(scope="module") def esm2_model(esm2_config) -> ESM2Model: with megatron_parallel_state_utils.distributed_model_parallel_state(): - tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") + tokenizer = get_tokenizer() model = esm2_config.configure_model(tokenizer) yield model @@ -111,6 +115,32 @@ def sample_data() -> List[Tuple[str, str]]: yield sample_data +def _compute_loss(model, dataloader, vocab_size=None): + loss = 0 + n = 0 + limit_batches = 10 + for i, batch in enumerate(dataloader): + assert isinstance(batch, dict) + result = model(input_ids=batch["text"].cuda(), attention_mask=batch["attention_mask"].cuda()) + + # bionemo ESM2 vocab_size + if vocab_size is not None: + logits = result["token_logits"][..., :vocab_size] + else: + logits = result.logits + + loss_mask = batch["loss_mask"].cuda() + target = batch["labels"].cuda() + + loss += torch.nn.functional.cross_entropy(logits[loss_mask].float(), target[loss_mask], reduction="sum") + n += loss_mask.sum() + + if limit_batches is not None and i + 1 >= limit_batches: + break + mean_loss: Tensor = loss / n + return mean_loss + + def 
test_esm2_model_initialized(esm2_model): assert isinstance(esm2_model, MegatronBioBertModel) assert isinstance(esm2_model, ESM2Model) @@ -147,17 +177,15 @@ def test_esm2_650m_checkpoint(esm2_model): def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): - device = "cuda" - tokenizer = AutoTokenizer(pretrained_model_name="facebook/esm2_t33_650M_UR50D") - tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to(device) + tokens = tokenizer.tokenizer([row[1] for row in sample_data], return_tensors="pt", padding=True).to("cuda") input_ids = tokens["input_ids"] attention_mask = tokens["attention_mask"] # HF 650M model - hf_model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32)).to( - device - ) + hf_model = EsmForMaskedLM.from_pretrained( + "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) + ).cuda() with torch.no_grad(): hf_output_all = hf_model(input_ids, attention_mask, output_hidden_states=True) @@ -170,7 +198,7 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): torch.cuda.empty_cache() # configure the model to return logits - model = esm2_650M_config_w_ckpt.configure_model(tokenizer).to(device) + model = esm2_650M_config_w_ckpt.configure_model(get_tokenizer()).cuda() model.eval() result = model(input_ids, attention_mask) logits = result["token_logits"][..., : tokenizer.vocab_size] @@ -184,10 +212,68 @@ def test_esm2_golden_values(esm2_650M_config_w_ckpt, sample_data): # configure the model to return hiddens esm2_650M_config_hiddens = deepcopy(esm2_650M_config_w_ckpt) esm2_650M_config_hiddens.return_only_hidden_states = True - model = esm2_650M_config_hiddens.configure_model(tokenizer).to(device) + model = esm2_650M_config_hiddens.configure_model(get_tokenizer()).cuda() model.eval() hiddens = model(input_ids, attention_mask) embeddings = reduce_hiddens(torch.transpose(hiddens, 0, 1).float(), attention_mask) torch.testing.assert_close(logits, hf_logits, atol=9e-2, rtol=0.0) torch.testing.assert_close(embeddings, hf_embeddings, atol=5e-3, rtol=0.0) + + +def test_esm2_loss(esm2_650M_config_w_ckpt, dummy_protein_dataset, dummy_parquet_train_val_inputs): + train_cluster_path, valid_cluster_path = dummy_parquet_train_val_inputs + + compute_hf_reference: bool = False + seed: int = 42 + + with ( + torch.inference_mode(), + megatron_parallel_state_utils.distributed_model_parallel_state(seed), + random_numpy_context(seed), + ): + tokenizer = get_tokenizer() + + # ESM2 model initialized with 650M params + model = esm2_650M_config_w_ckpt.configure_model(tokenizer).cuda() + + # Initialize the data module. 
+ data_module = ESMDataModule( + train_cluster_path=train_cluster_path, + train_database_path=dummy_protein_dataset, + valid_cluster_path=valid_cluster_path, + valid_database_path=dummy_protein_dataset, + global_batch_size=8, + micro_batch_size=4, + min_seq_length=None, + max_seq_length=1024, + seed=seed, + ) + assert data_module is not None + data_module.trainer = mock.Mock() + data_module.trainer.max_epochs = 1 + data_module.trainer.max_steps = 10 + data_module.trainer.val_check_interval = 2 + data_module.trainer.limit_val_batches = 1 + + data_module.setup() + + train_dataloader = data_module.train_dataloader() + assert isinstance(train_dataloader, torch.utils.data.DataLoader) + + val_dataloader = data_module.val_dataloader() + assert isinstance(val_dataloader, torch.utils.data.DataLoader) + + mean_loss = _compute_loss(model, train_dataloader, vocab_size=tokenizer.vocab_size) + + if compute_hf_reference: + # HF model initialized with 650M params + hf_model = EsmForMaskedLM.from_pretrained( + "facebook/esm2_t33_650M_UR50D", torch_dtype=get_autocast_dtype(32) + ).cuda() + hf_mean_loss = _compute_loss(hf_model, train_dataloader) + print(f"hf_mean_loss: {hf_mean_loss}") + else: + hf_mean_loss = torch.tensor(3.0298714637756348).cuda() + + torch.testing.assert_close(mean_loss, hf_mean_loss, atol=1e-4, rtol=0.0) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py index 7a6c6b801..7cbe957b9 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py @@ -26,11 +26,12 @@ from tokenizers import Tokenizer from torch.utils.data import DataLoader -from bionemo.core.data.resamplers import PRNGDatasetShuffler +from bionemo.core.data.resamplers import PRNGResampleDataset from bionemo.core.utils import random_utils from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer from bionemo.llm.data import collate +from bionemo.llm.utils.datamodule_utils import infer_num_samples __all__: Sequence[str] = ("SingleCellDataModule",) @@ -149,15 +150,20 @@ def setup(self, stage: str = "") -> None: # noqa: D102 "Trainer is set to run for multiple epochs. This is not recommended due to the same shuffle being used in each. Instead set max_epochs to 1 and increase the number of max_steps." 
) assert max_train_steps > 0, "Please specify trainer.max_steps" - eval_iters = int((max_train_steps // self.trainer.val_check_interval + 1) * self.trainer.limit_val_batches) - test_iters = self.trainer.limit_test_batches - num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size) - num_val_samples = int(eval_iters * self.data_sampler.global_batch_size) - num_test_samples = int(test_iters * self.data_sampler.global_batch_size) - if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - # This is to make sure we only have one epoch on every validation iteration - num_val_samples = 1 + num_train_samples = int(max_train_steps * self.data_sampler.global_batch_size) + num_val_samples = infer_num_samples( + limit_batches=self.trainer.limit_val_batches, + num_samples_in_dataset=len(self._val_dataset_ori), + global_batch_size=self.data_sampler.global_batch_size, + stage="val", + ) + num_test_samples = infer_num_samples( + limit_batches=self.trainer.limit_test_batches, + num_samples_in_dataset=len(self._test_dataset_ori), + global_batch_size=self.data_sampler.global_batch_size, + stage="test", + ) # This happens exactly once during setup. self._train_ds = self._sample_and_shuffle_dataset(self._train_dataset_ori, num_train_samples, "train") @@ -199,7 +205,7 @@ def _sample_and_shuffle_dataset(self, dataset: SingleCellDataset, num_samples: i """ # This is where re-sampling occurs. - return PRNGDatasetShuffler( + return PRNGResampleDataset( dataset, num_samples=num_samples, seed=self.seed + len(stage), diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py index a07e96e19..aa2e47eb3 100644 --- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py +++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py @@ -28,7 +28,7 @@ from tqdm import tqdm from bionemo import geneformer -from bionemo.core.data.resamplers import PRNGDatasetShuffler +from bionemo.core.data.resamplers import PRNGResampleDataset from bionemo.core.utils.batching_utils import pad_token_ids from bionemo.core.utils.dtypes import get_autocast_dtype from bionemo.core.utils.random_utils import random_numpy_context @@ -671,7 +671,7 @@ def _get_loss_from_model(model_config: GeneformerConfig, seed: int) -> float: prepend_cls_token=True, seed=42, ) - dss = PRNGDatasetShuffler( + dss = PRNGResampleDataset( ds, seed=seed, ) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py index 562bc98e0..dda9a77f3 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/lightning.py @@ -191,8 +191,6 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, outputs = outputs["loss"] # TODO verify that losses are already reduced across ranks # torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG) - # TODO verify that losses are already reduced across ranks - # torch.distributed.all_reduce(outputs, op=torch.distributed.ReduceOp.AVG) loss = outputs self.val_losses.append(loss) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py index 34922ae73..f692d6477 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py +++ 
b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/lightning.py @@ -88,7 +88,7 @@ def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]: # noqa: D102 return biobert_data_step(dataloader_iter) def forward_step(self, batch) -> torch.Tensor: # noqa: D102 - return bert_forward_step(self, batch) + return bert_forward_step(self, batch) # NOTE(@sichu) reduced to loss def training_step(self, batch, batch_idx=None) -> torch.Tensor: # noqa: D102 # In mcore the loss-function is part of the forward-pass (when labels are provided) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py index b0951cbc0..82f35f254 100644 --- a/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py +++ b/sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py @@ -82,7 +82,7 @@ def __init__( # noqa: D107 transformer_layer_spec: spec_utils.ModuleSpec, vocab_size: int, max_sequence_length: int, - tokrnizer: Optional[AutoTokenizer] = None, + tokenizer: Optional[AutoTokenizer] = None, pre_process: bool = True, post_process: bool = True, fp16_lm_cross_entropy: bool = False, @@ -395,7 +395,7 @@ def configure_model(self, tokenizer) -> "MegatronBioBertModel": # noqa: D102 num_tokentypes=2 if do_next_sentence else 0, vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), max_sequence_length=self.seq_length, - tokrnizer=tokenizer, + tokenizer=tokenizer, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, parallel_output=self.parallel_output, share_embeddings_and_output_weights=self.share_embeddings_and_output_weights, diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/utils/datamodule_utils.py b/sub-packages/bionemo-llm/src/bionemo/llm/utils/datamodule_utils.py new file mode 100644 index 000000000..89d972963 --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/utils/datamodule_utils.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import hashlib +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + + +def float_or_int_or_none(value: Union[str, float, int, None]) -> Union[float, int, None]: + """Converts a given value into a float, int, or None. + + Args: + value (Union[str, float, int, None]): A value that can be either a string, float, int, or None. + + Returns: + Union[float, int, None]: A float, int, or None based on the input value. + + If the input value is None or "None", it returns None. + If the input value is an int or float, it returns the same value. + If the input value is a string, it tries to convert it into an int if possible, otherwise into a float. 
+ """ + if value is None or value == "None": + return + if isinstance(value, (int, float)): + return value + if value.isdigit(): + return int(value) + return float(value) + + +def parse_kwargs_to_arglist(kwargs: Dict[str, Any]) -> List[str]: + """Converts a dictionary of keyword arguments into a list of command-line arguments. + + Args: + kwargs (Dict[str, Any]): A dictionary where keys are argument names and values are argument values. + + Returns: + A list of strings, where each string is a command-line argument in the format '--argument-name value'. + """ + arglist = [] + for k, v in kwargs.items(): + arglist.extend([f"--{k.replace('_', '-')}", str(v)]) + return arglist + + +def infer_global_batch_size( + micro_batch_size: int, + num_nodes: int, + devices: int, + accumulate_grad_batches: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, +) -> int: + """Infers the global batch size based on the micro batch size, number of nodes, devices, accumulation of gradient batches, and model parallel sizes. + + Args: + micro_batch_size (int): The micro batch size. + num_nodes (int): The number of nodes. + devices (int): The number of devices. + accumulate_grad_batches (int): The accumulation of gradient batches. Defaults to 1. + tensor_model_parallel_size (int): The tensor model parallel size. Defaults to 1. + pipeline_model_parallel_size (int): The pipeline model parallel size. Defaults to 1. + + Returns: + int: The global batch size. + """ + if not all( + isinstance(arg, int) + for arg in [ + micro_batch_size, + num_nodes, + devices, + accumulate_grad_batches, + tensor_model_parallel_size, + pipeline_model_parallel_size, + ] + ): + raise ValueError( + f"All arguments must be of type int, got {type(micro_batch_size)}, {type(num_nodes)}, {type(devices)}, " + f"{type(accumulate_grad_batches)}, {type(tensor_model_parallel_size)}, and {type(pipeline_model_parallel_size)}" + ) + if micro_batch_size <= 0: + raise ValueError(f"micro_batch_size must be greater than 0, got {micro_batch_size}") + if num_nodes <= 0: + raise ValueError(f"num_nodes must be greater than 0, got {num_nodes}") + if devices <= 0: + raise ValueError(f"devices must be greater than 0, got {devices}") + if accumulate_grad_batches <= 0: + raise ValueError(f"accumulate_grad_batches must be greater than 0, got {accumulate_grad_batches}") + if tensor_model_parallel_size <= 0: + raise ValueError(f"tensor_model_parallel_size must be greater than 0, got {tensor_model_parallel_size}") + if pipeline_model_parallel_size <= 0: + raise ValueError(f"pipeline_model_parallel_size must be greater than 0, got {pipeline_model_parallel_size}") + if devices % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: + raise ValueError( + f"devices must be divisible by tensor_model_parallel_size * pipeline_model_parallel_size, " + f"got {devices} and {tensor_model_parallel_size} * {pipeline_model_parallel_size}" + ) + + world_size = num_nodes * devices + model_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size + data_parallel_size = world_size // model_parallel_size + global_batch_size = micro_batch_size * data_parallel_size * accumulate_grad_batches + return global_batch_size + + +def tensor_hash(tensor: torch.Tensor, hash_func: Optional[Callable] = None) -> str: + """Generates a hash for the given tensor using the specified hash function. + + Args: + tensor (torch.Tensor): The input tensor to be hashed. + hash_func (Optional[Callable]): An optional hash function to use. 
If None, defaults to SHA-256. + + Returns: + str: The resulting hash string. + + If no hash function is provided, SHA-256 is used by default. The function first converts the tensor to + a contiguous array on the CPU and then to bytes before hashing. + """ + tensor_bytes = tensor.cpu().contiguous().numpy().tobytes() + if hash_func is None: + return hashlib.sha256(tensor_bytes).hexdigest() + else: + return hash_func(tensor_bytes) + + +def tensor_dict_hash(tensor_dict: Dict[str, torch.Tensor], hash_func: Optional[Callable] = None) -> str: + """Generates a hash for the given tensor dictionary using the specified hash function. + + Args: + tensor_dict (Dict[str, torch.Tensor]): The input tensor dictionary to be hashed. + hash_func (Optional[Callable]): An optional hash function to use. If None, defaults to SHA-256. + + Returns: + str: The resulting hash string. + + If no hash function is provided, SHA-256 is used by default. Each tensor in the dictionary is hashed with + tensor_hash and the per-tensor hashes are concatenated in sorted key order. + """ + hash_value = "" + for k in sorted(tensor_dict): + hash_value += tensor_hash(tensor_dict[k], hash_func) + return hash_value + + +def infer_num_samples( + limit_batches: Union[float, int, str, None], num_samples_in_dataset: int, global_batch_size: int, stage: str +): + """Infers the number of samples based on the limit_batches parameter, the length of the dataset, and the global batch size. + + Args: + limit_batches (Union[float, int, str, None]): The limit on the number of batches. Can be a float + between 0 and 1, an integer, a string, or None. If None, defaults to 1.0. + num_samples_in_dataset (int): The number of samples in the dataset. + global_batch_size (int): The global batch size. + stage (str): The stage of the training. + + Returns: + int: The number of samples from the limit. + + Raises: + ValueError: If the limited number of samples is less than the global batch size, or if the + limit_batches parameter is invalid. + + If limit_batches is a float between 0 and 1, the number of samples is inferred as a fraction of the number of samples + in the dataset. If limit_batches is an integer greater than or equal to 1, the number of limited samples is inferred + as the product of limit_batches and global batch size. If limit_batches is None, it defaults to 1.0, indicating that + all dataset samples should be used. + """ + limit_batches = 1.0 if limit_batches is None else limit_batches # validation data does not require upsampling + if 0 < limit_batches <= 1.0 and isinstance(limit_batches, float): + num_limited_samples = int(num_samples_in_dataset * limit_batches) + if num_limited_samples < global_batch_size: + raise ValueError( + "The limited number of %s samples %s is less than the global batch size %s" + % (stage, num_limited_samples, global_batch_size) + ) + elif limit_batches >= 1 and isinstance(limit_batches, int): + num_limited_samples = int(limit_batches * global_batch_size) + else: + raise ValueError("Invalid choice of limit_%s_batches size: %s" % (stage, limit_batches)) + + return num_limited_samples diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_datamodule_utils.py b/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_datamodule_utils.py new file mode 100644 index 000000000..c281fef09 --- /dev/null +++ b/sub-packages/bionemo-llm/tests/bionemo/llm/utils/test_datamodule_utils.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from bionemo.llm.utils.datamodule_utils import float_or_int_or_none, infer_global_batch_size + + +def test_float_or_int_or_none_type_float(): + """Test that float_or_int_or_none returns a float when given a float on edge case 1.0""" + assert isinstance(float_or_int_or_none(1.0), float) + assert isinstance(float_or_int_or_none("1.0"), float) + + +def test_float_or_int_or_none_type_int(): + """Test that float_or_int_or_none returns an int when given an int on edge case 1""" + assert isinstance(float_or_int_or_none(1), int) + assert isinstance(float_or_int_or_none("1"), int) + + +def test_float_or_int_or_none_type_none(): + """Test that float_or_int_or_none returns None when given None""" + assert float_or_int_or_none(None) is None + assert float_or_int_or_none("None") is None + + +def test_infer_global_batch_size(): + """Test that infer_global_batch_size returns the correct global batch size""" + assert infer_global_batch_size(micro_batch_size=1, num_nodes=1, devices=1) == 1 # single node, single device + assert infer_global_batch_size(micro_batch_size=1, num_nodes=1, devices=8) == 8 # single node, multi device + assert ( + infer_global_batch_size( + micro_batch_size=1, + num_nodes=2, + devices=8, + ) + == 16 + ) # multi node, multi device + assert ( + infer_global_batch_size(micro_batch_size=1, num_nodes=2, devices=8, pipeline_model_parallel_size=2) == 8 + ) # multi node, multi device with pipeline parallel + assert ( + infer_global_batch_size( + micro_batch_size=1, num_nodes=2, devices=8, pipeline_model_parallel_size=2, tensor_model_parallel_size=2 + ) + == 4 + ) # multi node, multi device with pipeline and tensor parallel + assert ( + infer_global_batch_size( + micro_batch_size=1, + num_nodes=2, + devices=8, + pipeline_model_parallel_size=2, + tensor_model_parallel_size=2, + accumulate_grad_batches=2, + ) + == 8 + ) # multi node, multi device with pipeline and tensor parallel, and accumulate grad batches
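To summarize the limit-batches handling added in datamodule_utils.py, a short sanity sketch of infer_num_samples under the semantics shown above; the dataset size and batch size are illustrative values, not taken from the patch.

# Fractional float limits scale the dataset size; integer limits count global batches;
# None falls back to 1.0 (use the whole dataset).
from bionemo.llm.utils.datamodule_utils import infer_num_samples

# 0.5 of a 100-sample validation set -> 50 samples (still >= one global batch of 8).
assert infer_num_samples(limit_batches=0.5, num_samples_in_dataset=100, global_batch_size=8, stage="val") == 50

# An integer limit of 4 batches -> 4 * 8 = 32 samples, regardless of dataset size.
assert infer_num_samples(limit_batches=4, num_samples_in_dataset=100, global_batch_size=8, stage="val") == 32

# None behaves like 1.0 and keeps every sample in the dataset.
assert infer_num_samples(limit_batches=None, num_samples_in_dataset=100, global_batch_size=8, stage="val") == 100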