diff --git a/.vscode/settings.json b/.vscode/settings.json index ac8fe37ce..f099449db 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,8 +1,14 @@ { "cSpell.words": [ - // List of words to be added to the spell-checker dictionary. - // In vscode, use the "Add to workspace settings" in the quick-fix menu to add words to this list. - "bionemo" + "allclose", + "bionemo", + "dtype", + "nemo", + "pretraining", + "rampup", + "resamplers", + "singlecell", + "uniref" ], "editor.rulers": [ 120 diff --git a/scripts/singlecell/geneformer/pretrain.py b/scripts/singlecell/geneformer/pretrain.py index c966bcc19..6b8252aa0 100644 --- a/scripts/singlecell/geneformer/pretrain.py +++ b/scripts/singlecell/geneformer/pretrain.py @@ -169,7 +169,7 @@ def main( train_dataset_path=train_data_path, val_dataset_path=val_data_path, test_dataset_path=test_data_path, - random_token_prob=0.1, # this is the incorrect setting we originally used. + random_token_prob=0.02, # changed to 0.02 to reproduce the (incorrect) setting used in the original pretraining run. median_dict=median_dict, micro_batch_size=micro_batch_size, global_batch_size=micro_batch_size * int(num_nodes * devices / pipeline_model_parallel_size), diff --git a/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py b/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py index 1c29c0a83..0c35addfe 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py +++ b/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py @@ -44,3 +44,11 @@ def random_numpy_context(seed: int = 42) -> Iterator[None]: yield finally: np.random.set_state(state) + + +def get_seed_from_rng(rng: np.random.Generator) -> int: + """Generates a deterministic random seed from an existing random generator. + + Used to seed a torch random generator from a numpy random generator. + """ + return rng.integers(np.iinfo(np.int64).max) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py index 2a89a36c1..7a6c6b801 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/datamodule.py @@ -14,9 +14,11 @@ # limitations under the License.
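For context, a minimal usage sketch (not part of the patch) of how get_seed_from_rng bridges numpy and torch randomness; the variable names are illustrative, only the function and its import path come from the diff.

import numpy as np
import torch

from bionemo.core.utils.random_utils import get_seed_from_rng

# One numpy Generator owned by the datamodule; each dataset derives its own torch seed from it.
rng = np.random.default_rng(42)
derived_seed = get_seed_from_rng(rng)  # deterministic draw from [0, np.iinfo(np.int64).max)

# Seed a torch generator (or torch.manual_seed) with the derived value.
gen = torch.Generator().manual_seed(int(derived_seed))
print(torch.randint(0, 10, (3,), generator=gen))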
+import functools from pathlib import Path from typing import List, Optional, Sequence +import numpy as np import pytorch_lightning as pl from nemo.lightning.pytorch.plugins import MegatronDataSampler from nemo.utils import logging @@ -25,7 +27,10 @@ from torch.utils.data import DataLoader from bionemo.core.data.resamplers import PRNGDatasetShuffler +from bionemo.core.utils import random_utils from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset +from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer +from bionemo.llm.data import collate __all__: Sequence[str] = ("SingleCellDataModule",) @@ -90,6 +95,8 @@ def __init__( # noqa: D107 self.persistent_workers = persistent_workers self.pin_memory = pin_memory self.index_mapping_dir = index_mapping_dir or str(Path(self.data_path_train).parent) + + rng = np.random.default_rng(seed) self._train_dataset_ori = SingleCellDataset( self.data_path_train, self.tokenizer, @@ -98,6 +105,7 @@ def __init__( # noqa: D107 mask_prob=self.mask_prob, mask_token_prob=self.mask_token_prob, random_token_prob=self.random_token_prob, + seed=random_utils.get_seed_from_rng(rng), ) self._val_dataset_ori = SingleCellDataset( self.data_path_val, @@ -107,6 +115,7 @@ def __init__( # noqa: D107 mask_prob=self.mask_prob, mask_token_prob=self.mask_token_prob, random_token_prob=self.random_token_prob, + seed=random_utils.get_seed_from_rng(rng), ) self._test_dataset_ori = SingleCellDataset( self.data_path_test, @@ -116,6 +125,7 @@ def __init__( # noqa: D107 mask_prob=self.mask_prob, mask_token_prob=self.mask_token_prob, random_token_prob=self.random_token_prob, + seed=random_utils.get_seed_from_rng(rng), ) # This is needed here, or you need to specify it in the megatron adapter thing TODO name? @@ -169,7 +179,12 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, - # collate_fn=dataset.collate_fn, No special work happens in this dataloader outside of getitem + collate_fn=functools.partial( + collate.bert_padding_collate_fn, + padding_value=self.tokenizer.token_to_id(GeneTokenizer.pad_token), + min_length=None, + max_length=self.max_len, + ), **kwargs, ) diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py index af71aa860..e68981ee3 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/dataset.py @@ -16,32 +16,25 @@ import json from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, TypedDict +from typing import Any, Dict, Optional, Sequence, Tuple import numpy as np +import torch from nemo.utils import logging from torch.utils.data import Dataset -from bionemo.geneformer.data.singlecell.utils import sample_or_truncate_plus_pad +from bionemo.core.utils import random_utils +from bionemo.geneformer.data.singlecell.utils import sample_or_truncate from bionemo.geneformer.tokenizer.gene_tokenizer import GeneTokenizer +from bionemo.llm.data import masking, types __all__: Sequence[str] = ( "SingleCellDataset", - "Item", "process_item", ) -class Item(TypedDict): # noqa: D101 - text: np.ndarray - types: np.ndarray - padding_mask: np.ndarray - labels: np.ndarray - loss_mask: np.ndarray - is_random: np.ndarray - - class SingleCellDataset(Dataset): """A dataset class for 
single-cell pre-training. These can be generated using the sc_memmap.py script. Future updates will contain more comprehensive workflows for generating a Sparse Memmap from scRNA-seq. @@ -94,6 +87,7 @@ def __init__( # noqa: D107 random_token_prob: float = 0.1, prepend_cls_token: bool = True, assert_increasing_columns: bool = True, + seed: int = np.random.SeedSequence().entropy, # type: ignore ): super().__init__() self.data_path = data_path @@ -102,6 +96,7 @@ def __init__( # noqa: D107 self.mask_token_prob = mask_token_prob self.mask_prob = mask_prob self.prepend_cls_token = prepend_cls_token + self._seed = seed # check if column indices are increasing for looking up genes. This is a way of spotting if the sc_memmap.py # script produced properly strctured sparse files. self.assert_increasing_columns = assert_increasing_columns @@ -198,8 +193,10 @@ def lookup_cell_by_idx(self, idx) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: ) return gene_data, col_idxs, feature_ids - def __getitem__(self, idx: int) -> Item: - """Performs a lookup and the required transformation for the model""" # noqa: D415 + def __getitem__(self, idx: int) -> types.BertSample: + """Performs a lookup and the required transformation for the model""" # noqa: D415 + + rng = np.random.default_rng([self._seed, idx]) gene_data, col_idxs, feature_ids = self.lookup_cell_by_idx(idx) return process_item( gene_data, @@ -207,6 +204,7 @@ def __getitem__(self, idx: int) -> Item: feature_ids, self.tokenizer, gene_median=self.gene_medians, + rng=rng, max_len=self.max_len, mask_token_prob=self.mask_token_prob, mask_prob=self.mask_prob, @@ -221,6 +219,7 @@ def process_item( # noqa: D417 feature_ids: np.ndarray, tokenizer: GeneTokenizer, gene_median: dict, + rng: np.random.Generator, max_len: int = 1024, mask_prob: float = 0.15, mask_token_prob: float = 0.8, @@ -228,7 +227,7 @@ def process_item( # noqa: D417 target_sum: int = 10000, normalize: bool = True, prepend_cls_token: bool = True, -) -> Item: +) -> types.BertSample: """Process a single item in the dataset. Optionally performs median normalization and rank ordering. The tokenizers CLS token is added to the beginning @@ -240,6 +239,7 @@ def process_item( # noqa: D417 feature_ids (list): Feature ids for the full dataset. tokenizer (Tokenizer): Tokenizer object. gene_median (optional(dict)): Dictionary of gene medians. Defaults to None. Expects ensembl IDs to be keys. + rng: Random number generator to ensure deterministic results. max_len (int): Maximum length of the item. Defaults to 1024. Applies padding to any sequence shorter than max_len and truncates any sequence longer than max_len. mask_prob (float): Probability of masking a token. Defaults to 0.15. target_sum (int): Target sum for normalization. Defaults to 10000. @@ -259,14 +259,6 @@ def process_item( # noqa: D417 if max_len < 1: raise ValueError(f"max_len must be greater than 1, {max_len=}") - if random_token_prob + mask_token_prob > 1.0: - raise ValueError( - "Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0, identity_token_prob is any remainder less than 1.0." - ) - - identity_token_prob = 1.0 - (random_token_prob + mask_token_prob) - assert identity_token_prob >= 0.0 - if gene_median is None: raise ValueError("gene_median must be provided for this tokenizer") @@ -296,62 +288,34 @@ def process_item( # noqa: D417 token_ids = token_ids[idxs] # - select max_len subset, set sample to false so it doesnt permute the already rank ordered expression values.
- token_ids = sample_or_truncate_plus_pad( - token_ids, max_len, tokenizer.token_to_id(tokenizer.pad_token), sample=False + token_ids = sample_or_truncate(token_ids, max_len, sample=False) + + masked_tokens, labels, loss_mask = masking.apply_bert_pretraining_mask( + tokenized_sequence=torch.from_numpy(token_ids), + random_seed=random_utils.get_seed_from_rng(rng), + mask_config=masking.BertMaskConfig( + mask_token=tokenizer.token_to_id(tokenizer.mask_token), + random_tokens=range(5, len(tokenizer.vocab)), + mask_prob=mask_prob, + mask_token_prob=mask_token_prob, + random_token_prob=random_token_prob, + ), ) - mask = None - mask_tokens_positions = None - random_tokens_positions = None - - # - masked tokens - if mask_prob > 0.0: - probs = np.full(token_ids.shape[0], mask_prob) - probs[token_ids == tokenizer.token_to_id(tokenizer.pad_token)] = 0.0 - mask = np.random.binomial(1, probs).astype(bool) - mask_tokens_positions = mask & np.random.binomial(1, mask_token_prob, mask.shape).astype(bool) - random_tokens_positions = ( - mask & np.random.binomial(1, random_token_prob, mask.shape).astype(bool) & (~mask_tokens_positions) - ) - # - ensure [CLS] token is masked from the loss. Note that we're dealing with 1d arrays so flattening isn't a problem here. - if prepend_cls_token: - mask = np.insert(mask, 0, False) - mask_tokens_positions = np.insert(mask_tokens_positions, 0, False) - random_tokens_positions = np.insert(random_tokens_positions, 0, False) - - # - add [CLS] token, note that token_ids is a 1d array so flattening isn't a problem here. if prepend_cls_token: - token_ids = np.insert(token_ids, 0, tokenizer.token_to_id(tokenizer.cls_token)) - attention_mask = token_ids != tokenizer.token_to_id(tokenizer.pad_token) - - labels = np.ones(len(token_ids)) * -1 - - if mask is None: - # If prob is set to zero, we get None for our mask, which could have unintended side effects. - # We abuse the scenario where mask == None - labels[mask] = token_ids[mask] - mask = np.zeros(shape=token_ids.shape, dtype=bool) - else: - mask[~attention_mask] = False # make sure that we aren't doing MLM on [PAD] tokens - labels[mask] = token_ids[mask] - if mask_tokens_positions is None: - mask_tokens_positions = np.zeros_like(mask) - if random_tokens_positions is None: - random_tokens_positions = np.zeros_like(mask) - # identity_tokens = mask & (~mask_tokens_positions) & (~random_tokens_positions), not needed because - token_ids[mask_tokens_positions] = tokenizer.token_to_id(tokenizer.mask_token) - # There are 5 special tokens in the tokenizer, so we start from 5. TODO make this a parameter of the tokenizer. - if random_tokens_positions.sum() > 0: - token_ids[random_tokens_positions] = np.random.randint(5, len(tokenizer.vocab), random_tokens_positions.sum()) + masked_tokens, labels, loss_mask = masking.add_cls_and_eos_tokens( + sequence=masked_tokens, + labels=labels, + loss_mask=loss_mask, + cls_token=tokenizer.token_to_id(tokenizer.cls_token), + ) # NeMo megatron assumes this return structure. 
- item = { - "text": token_ids.astype(np.int64), - "types": np.zeros_like(token_ids).astype(np.int64), - "attention_mask": attention_mask.astype(np.int64), - "labels": labels.astype(np.int64), - "loss_mask": mask, - "is_random": np.zeros_like(token_ids).astype(np.int64), + return { + "text": masked_tokens, + "types": torch.zeros_like(masked_tokens, dtype=torch.int64), + "attention_mask": torch.ones_like(masked_tokens, dtype=torch.int64), + "labels": labels, + "loss_mask": loss_mask, + "is_random": torch.zeros_like(masked_tokens, dtype=torch.int64), } - - return item diff --git a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/utils.py b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/utils.py index aafe817f3..4dcd9b4f1 100644 --- a/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/utils.py +++ b/sub-packages/bionemo-geneformer/src/bionemo/geneformer/data/singlecell/utils.py @@ -12,41 +12,30 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence import numpy as np -__all__: Sequence[str] = ("sample_or_truncate_plus_pad",) - - -def sample_or_truncate_plus_pad( - gene_ids: np.array, +def sample_or_truncate( + gene_ids: np.ndarray, max_length: int, - pad_token_id: int, sample: bool = True, -) -> np.array: +) -> np.ndarray: - """Truncate and pad samples. + """Truncate or subsample gene IDs to at most max_length entries. Args: gene_ids (np.ndarray): Array of gene IDs. max_length (int): Maximum length of the samples. - pad_token_id (int): ID of the padding token. sample (bool, optional): Whether to sample or truncate the samples. Defaults to True. Returns: - np.array: Tuple containing the truncated or padded gene IDs. + np.ndarray: The truncated or subsampled gene IDs. """ - if len(gene_ids) == max_length: + if len(gene_ids) <= max_length: return gene_ids - if len(gene_ids) > max_length: # - sample or truncate - if sample: - indices = np.random.permutation(len(gene_ids))[:max_length] - return gene_ids[indices] - else: - return gene_ids[:max_length] - else: # - pad - pad_tokens = np.full((max_length - len(gene_ids)), pad_token_id, dtype=np.int32) - gene_ids = np.concatenate([gene_ids, pad_tokens]) - return gene_ids + if sample: + indices = np.random.permutation(len(gene_ids))[:max_length] + return gene_ids[indices] + else: + return gene_ids[:max_length] diff --git a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py index 6bf705497..8662453e1 100644 --- a/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py +++ b/sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License.
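To make the new truncation semantics concrete, a small hedged sketch of sample_or_truncate behaviour (import path taken from the diff; padding now happens in the collate function rather than here):

import numpy as np

from bionemo.geneformer.data.singlecell.utils import sample_or_truncate

gene_ids = np.arange(10)

# At or below max_length the input is returned unchanged (no padding is applied any more).
assert np.array_equal(sample_or_truncate(gene_ids, max_length=16), gene_ids)

# Above max_length with sample=False: plain truncation that preserves the rank ordering.
assert np.array_equal(sample_or_truncate(gene_ids, max_length=4, sample=False), gene_ids[:4])

# Above max_length with sample=True: a random subset of max_length ids (order is permuted).
subset = sample_or_truncate(gene_ids, max_length=4, sample=True)
assert len(subset) == 4 and set(subset.tolist()) <= set(gene_ids.tolist())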
+import functools import tarfile from copy import deepcopy from pathlib import Path @@ -34,6 +35,7 @@ from bionemo.geneformer.api import GeneformerConfig from bionemo.geneformer.data.singlecell.dataset import SingleCellDataset from bionemo.geneformer.data.singlecell.preprocess import GeneformerPreprocess +from bionemo.llm.data import collate from bionemo.llm.model.biobert.lightning import BioBertLightningModule from bionemo.llm.model.biobert.model import BiobertSpecOption from bionemo.llm.utils.weight_utils import nemo1_to_nemo2_biobert_key_mapping @@ -171,9 +173,11 @@ def test_nemo1_nemo2_weight_shapes_match(geneformer_config, seed: int = 42): if not train_data_path.exists(): raise FileNotFoundError(f"Could not find train data at {train_data_path}. {data_error_str}") - with tarfile.open( - nemo1_checkpoint_path, "r" - ) as old_ckpt, torch.no_grad(), megatron_parallel_state_utils.distributed_model_parallel_state(seed): + with ( + tarfile.open(nemo1_checkpoint_path, "r") as old_ckpt, + torch.no_grad(), + megatron_parallel_state_utils.distributed_model_parallel_state(seed), + ): ckpt_file = old_ckpt.extractfile("./model_weights.ckpt") old_weights = torch.load(ckpt_file) preprocessor = GeneformerPreprocess( @@ -442,9 +446,11 @@ def test_geneformer_inference_nemo1_v_nemo2_golden_values_by_layer( if not train_data_path.exists(): raise FileNotFoundError(f"Could not find train data at {train_data_path}. {data_error_str}") - with tarfile.open( - nemo1_checkpoint_path, "r" - ) as old_ckpt, torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state(seed): + with ( + tarfile.open(nemo1_checkpoint_path, "r") as old_ckpt, + torch.inference_mode(), + megatron_parallel_state_utils.distributed_model_parallel_state(seed), + ): ckpt_file = old_ckpt.extractfile("./model_weights.ckpt") old_weights = torch.load(ckpt_file) new_state_dict_from_old = {} @@ -626,9 +632,11 @@ def _get_loss_from_model(model_config: GeneformerConfig, seed: int) -> float: data_dir = Path(data_path) train_data_path = data_dir / "train" test_data_path = data_dir / "test" - with torch.inference_mode(), megatron_parallel_state_utils.distributed_model_parallel_state( - seed - ), random_numpy_context(seed): + with ( + torch.inference_mode(), + megatron_parallel_state_utils.distributed_model_parallel_state(seed), + random_numpy_context(seed), + ): preprocessor = GeneformerPreprocess( download_directory=train_data_path, medians_file_path=train_data_path / "medians.json", @@ -659,8 +667,9 @@ def _get_loss_from_model(model_config: GeneformerConfig, seed: int) -> float: max_len=2048, mask_prob=0.15, mask_token_prob=0.8, - random_token_prob=0.1, # TODO: once this is fixed, change to 0.02 to match the prior numbers. + random_token_prob=0.02, prepend_cls_token=True, + seed=42, ) dss = PRNGDatasetShuffler( ds, @@ -671,6 +680,12 @@ def _get_loss_from_model(model_config: GeneformerConfig, seed: int) -> float: batch_size=8, shuffle=False, num_workers=0, + collate_fn=functools.partial( + collate.bert_padding_collate_fn, + padding_value=tokenizer.token_to_id(tokenizer.pad_token), + min_length=None, + max_length=2048, + ), drop_last=False, ) loss = 0 diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/data/__init__.py b/sub-packages/bionemo-llm/src/bionemo/llm/data/__init__.py new file mode 100644 index 000000000..25e6abfbc --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/data/__init__.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py b/sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py new file mode 100644 index 000000000..e965288fc --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/data/collate.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Sequence, TypeVar + +import torch + +from bionemo.llm.data import types + + +_T = TypeVar("_T", bound=dict[str, torch.Tensor]) + + +def padding_collate_fn( + batch: Sequence[_T], + padding_values: dict[str, int], + min_length: int | None = None, + max_length: int | None = None, +) -> _T: + """Collate function with padding. + + Args: + batch: List of samples, each of which is a dictionary of tensors. + padding_values: A dictionary of padding values for each tensor key. + min_length: Minimum length of the output batch; tensors will be padded to this length. If not + provided, no extra padding beyond the max_length will be added. + max_length: Maximum length of the sequence. If not provided, tensors will be padded to the + longest sequence in the batch. + + Returns: + A collated batch with the same dictionary input structure. + """ + for entry in batch: + if entry.keys() != padding_values.keys(): + raise ValueError("All keys in inputs must match provided padding_values.") + + def _pad(tensors, padding_value): + if max_length is not None: + tensors = [t[:max_length] for t in tensors] + batched_tensors = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True, padding_value=padding_value) + if min_length is None or batched_tensors.size(1) >= min_length: + return batched_tensors + return torch.nn.functional.pad(batched_tensors, (0, min_length - batched_tensors.size(1)), value=padding_value) + + return {k: _pad([s[k] for s in batch], padding_values[k]) for k in batch[0].keys()} # type: ignore[return-value] + + +def bert_padding_collate_fn( + batch: Sequence[types.BertSample], + padding_value: int, + min_length: int | None = None, + max_length: int | None = None, +) -> types.BertSample: + """Padding collate function for BERT dataloaders. + + Args: + batch (list): List of samples. + padding_value (int): The tokenizer's pad token ID. + min_length: Minimum length of the output batch; tensors will be padded to this length.
If not + provided, no extra padding beyond the max_length will be added. + max_length: Maximum length of the sequence. If not provided, tensors will be padded to the + longest sequence in the batch. + """ + padding_values = { + "text": padding_value, + "types": 0, + "attention_mask": False, + "labels": -1, + "loss_mask": False, + "is_random": 0, + } + return padding_collate_fn( + batch=batch, # type: ignore[assignment] + padding_values=padding_values, + min_length=min_length, + max_length=max_length, + ) diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/data/masking.py b/sub-packages/bionemo-llm/src/bionemo/llm/data/masking.py new file mode 100644 index 000000000..06fe5b1b1 --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/data/masking.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass + +import torch + + +@dataclass(frozen=True) +class BertMaskConfig: + """Configuration for masking tokens in a BERT-style model. + + Attributes: + mask_token: The token ID used to replace positions selected for masking. + random_tokens: The range of token IDs that may be sampled as random replacement tokens. + mask_prob: Probability of masking a token. + mask_token_prob: Probability of replacing a masked token with the mask token. + random_token_prob: Probability of replacing a masked token with a random token. + """ + + mask_token: int + random_tokens: range + mask_prob: float = 0.15 + mask_token_prob: float = 0.8 + random_token_prob: float = 0.1 + + def __post_init__(self) -> None: + """Check that the sum of `mask_token_prob` and `random_token_prob` is less than or equal to 1.0. + + Raises: + ValueError: If the sum of `mask_token_prob` and `random_token_prob` is greater than 1.0. + """ + if self.random_token_prob + self.mask_token_prob > 1.0: + raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.") + + +def apply_bert_pretraining_mask( + tokenized_sequence: torch.Tensor, random_seed: int, mask_config: BertMaskConfig +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Applies the pretraining mask to a tokenized sequence. + + Args: + tokenized_sequence: Tokenized input sequence. + random_seed: Random seed for reproducibility. + mask_config: Configuration for masking tokens in a BERT-style model. + + Returns: + masked_sequence: + The tokenized sequence with some tokens masked. + labels: + A tensor the same shape as `masked_sequence` containing labels for the masked tokens, with -1 for non-masked + tokens. + loss_mask: + A boolean tensor the same shape as `masked_sequence`, where 'True' indicates which tokens should be included + in the loss. + """ + if mask_config.random_token_prob + mask_config.mask_token_prob > 1.0: + raise ValueError("Sum of random_token_prob and mask_token_prob must be less than or equal to 1.0.") + + # Set the seed so that __getitem__(idx) is always deterministic. + # This is required by Megatron-LM's parallel strategies.
+ torch.manual_seed(random_seed) + + mask_stop_1 = mask_config.mask_prob * mask_config.mask_token_prob + mask_stop_2 = mask_config.mask_prob * (mask_config.mask_token_prob + mask_config.random_token_prob) + + random_draws = torch.rand(tokenized_sequence.shape) # Random draws for each token in [0, 1). + + # Overall mask for a token being masked in some capacity - either mask token, random token, or left as-is (identity) + loss_mask = random_draws < mask_config.mask_prob + + # The first `mask_token_prob` fraction of the `mask_prob` tokens are replaced with the mask token. + mask_token_mask = random_draws < mask_stop_1 + + # The next `random_token_prob` fraction of the `mask_prob` tokens are replaced with a random token. + random_token_mask = (random_draws >= mask_stop_1) & (random_draws < mask_stop_2) + + # The remaining tokens are implicitly left as-is, representing an identity mask. + + # Mask the tokens. + masked_sequence = tokenized_sequence.clone() + masked_sequence[mask_token_mask] = mask_config.mask_token + num_random_tokens: int = random_token_mask.sum().item() # type: ignore[assignment] + masked_sequence[random_token_mask] = torch.randint( + low=mask_config.random_tokens.start, + high=mask_config.random_tokens.stop, + size=(num_random_tokens,), + dtype=masked_sequence.dtype, + ) + + # Create the labels for the masked tokens. + labels = tokenized_sequence.clone() + labels[~loss_mask] = -1 # Ignore loss for non-masked tokens. + + return masked_sequence, labels, loss_mask + + +def add_cls_and_eos_tokens( + sequence: torch.Tensor, + labels: torch.Tensor, + loss_mask: torch.Tensor, + cls_token: int | None = None, + eos_token: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Prepends the CLS token and appends the EOS token to the masked sequence, updating the loss mask and labels. + + These labels should never be masked, so this is done after the masking step. + + Args: + sequence: The input (likely masked) sequence. + labels: The true values of the input sequence at the mask positions. + loss_mask: A boolean tensor indicating which tokens should be included in the loss. + cls_token: The token to use for the CLS token. If None, no CLS token is added. + eos_token: The token to use for the EOS token. If None, no EOS token is added. + + Returns: + The same input tensors with the CLS and EOS tokens added, and the labels and loss_mask updated accordingly. + """ + # Prepend the CLS token and append the EOS token, and update the loss mask and labels accordingly. 
+ sequence = torch.cat( + [ + torch.tensor([cls_token], dtype=sequence.dtype) + if cls_token is not None + else torch.tensor([], dtype=sequence.dtype), + sequence, + torch.tensor([eos_token], dtype=sequence.dtype) + if eos_token is not None + else torch.tensor([], dtype=sequence.dtype), + ] + ) + + labels = torch.cat( + [ + torch.tensor([-1], dtype=labels.dtype) if cls_token is not None else torch.tensor([], dtype=labels.dtype), + labels, + torch.tensor([-1], dtype=labels.dtype) if eos_token is not None else torch.tensor([], dtype=labels.dtype), + ] + ) + + loss_mask = torch.cat( + [ + torch.tensor([False]) if cls_token is not None else torch.tensor([], dtype=loss_mask.dtype), + loss_mask, + torch.tensor([False]) if eos_token is not None else torch.tensor([], dtype=loss_mask.dtype), + ] + ) + + return sequence, labels, loss_mask diff --git a/sub-packages/bionemo-llm/src/bionemo/llm/data/types.py b/sub-packages/bionemo-llm/src/bionemo/llm/data/types.py new file mode 100644 index 000000000..ca84618e9 --- /dev/null +++ b/sub-packages/bionemo-llm/src/bionemo/llm/data/types.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import TypedDict + +import torch + + +class BertSample(TypedDict): + """The type expected by NeMo/Megatron for a single dataset item. + + Attributes: + text: The tokenized, masked input text. + types: The token type ids, if applicable. + attention_mask: A mask over all valid tokens, excluding padding. + labels: The true values of the masked tokens at each position covered by loss_mask. + loss_mask: The mask over the text indicating which tokens are masked and should be predicted. + is_random: Unused here; kept for compatibility with the NeMo/Megatron BERT batch format (filled with zeros). + """ + + text: torch.Tensor + types: torch.Tensor + attention_mask: torch.Tensor + labels: torch.Tensor + loss_mask: torch.Tensor + is_random: torch.Tensor diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_collate.py b/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_collate.py new file mode 100644 index 000000000..74dbef54c --- /dev/null +++ b/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_collate.py @@ -0,0 +1,187 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
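A hedged illustration of how the new masking helpers compose; the token IDs below are made up, while the function names and signatures come from masking.py above.

import torch

from bionemo.llm.data.masking import BertMaskConfig, add_cls_and_eos_tokens, apply_bert_pretraining_mask

config = BertMaskConfig(
    mask_token=32,  # illustrative [MASK] id
    random_tokens=range(5, 20),  # ids eligible as random replacements
    mask_prob=0.15,
    mask_token_prob=0.8,
    random_token_prob=0.02,
)

tokens = torch.randint(5, 20, (64,))
masked, labels, loss_mask = apply_bert_pretraining_mask(tokens, random_seed=123, mask_config=config)

# Positions outside the loss mask are unchanged and carry the ignore label -1.
assert torch.equal(masked[~loss_mask], tokens[~loss_mask])
assert torch.all(labels[~loss_mask] == -1)

# Prepend a [CLS] token (id 0 here, illustrative); the new position is excluded from the loss.
masked, labels, loss_mask = add_cls_and_eos_tokens(masked, labels, loss_mask, cls_token=0)
assert masked[0] == 0 and labels[0] == -1 and not loss_mask[0]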
+ + +import pytest +import torch + +from bionemo.llm.data.collate import bert_padding_collate_fn, padding_collate_fn + + +def test_padding_collate_fn(): + sample1 = { + "my_key": torch.tensor([1, 2, 3]), + } + sample2 = { + "my_key": torch.tensor([4, 5, 6, 7, 8]), + } + batch = [sample1, sample2] + collated_batch = padding_collate_fn(batch, padding_values={"my_key": -1}) + + assert torch.all(torch.eq(collated_batch["my_key"], torch.tensor([[1, 2, 3, -1, -1], [4, 5, 6, 7, 8]]))) + + +def test_padding_collate_with_missing_key_raises(): + sample1 = { + "my_key": torch.tensor([1, 2, 3]), + } + sample2 = { + "my_key": torch.tensor([4, 5, 6, 7, 8]), + "other_key": torch.tensor([1, 2, 3]), + } + batch = [sample1, sample2] + with pytest.raises(ValueError, match="All keys in inputs must match provided padding_values."): + padding_collate_fn(batch, padding_values={"my_key": -1, "other_key": -1}) + + +def test_bert_padding_collate_fn(): + # Create sample data + sample1 = { + "text": torch.tensor([1, 2, 3]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, False]), + "labels": torch.tensor([7, 8, 9]), + "loss_mask": torch.tensor([True, False, True]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + sample2 = { + "text": torch.tensor([10, 11, 12]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, False, True]), + "labels": torch.tensor([16, 17, 18]), + "loss_mask": torch.tensor([False, True, False]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + batch = [sample1, sample2] + + # Call the collate_fn + collated_batch = bert_padding_collate_fn(batch, padding_value=-1) + + # Assert the expected output + assert torch.all(torch.eq(collated_batch["text"], torch.tensor([[1, 2, 3], [10, 11, 12]]))) + assert torch.all(torch.eq(collated_batch["types"], torch.tensor([[0, 0, 0], [0, 0, 0]]))) + assert torch.all( + torch.eq(collated_batch["attention_mask"], torch.tensor([[True, True, False], [True, False, True]])) + ) + assert torch.all(torch.eq(collated_batch["labels"], torch.tensor([[7, 8, 9], [16, 17, 18]]))) + assert torch.all(torch.eq(collated_batch["loss_mask"], torch.tensor([[True, False, True], [False, True, False]]))) + assert torch.all(torch.eq(collated_batch["is_random"], torch.tensor([[0, 0, 0], [0, 0, 0]]))) + + +def test_bert_padding_collate_fn_with_padding(): + # Create sample data + sample1 = { + "text": torch.tensor([1, 2, 3]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, False]), + "labels": torch.tensor([7, 8, 9]), + "loss_mask": torch.tensor([True, False, True]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + sample2 = { + "text": torch.tensor([4, 5, 6, 7, 8]), + "types": torch.zeros((5,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, True, True, True]), + "labels": torch.tensor([-1, 5, -1, 7, 8]), + "loss_mask": torch.tensor([False, True, False, True, True]), + "is_random": torch.zeros((5,), dtype=torch.int64), + } + batch = [sample1, sample2] + + # Call the collate_fn + collated_batch = bert_padding_collate_fn(batch, padding_value=10) + + # Assert the expected output + assert torch.all(torch.eq(collated_batch["text"], torch.tensor([[1, 2, 3, 10, 10], [4, 5, 6, 7, 8]]))) + assert torch.all(torch.eq(collated_batch["types"], torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))) + assert torch.all( + torch.eq( + collated_batch["attention_mask"], + torch.tensor([[True, True, False, False, False], [True, True, True, 
True, True]]), + ) + ) + assert torch.all(torch.eq(collated_batch["labels"], torch.tensor([[7, 8, 9, -1, -1], [-1, 5, -1, 7, 8]]))) + assert torch.all( + torch.eq( + collated_batch["loss_mask"], + torch.tensor([[True, False, True, False, False], [False, True, False, True, True]]), + ) + ) + assert torch.all(torch.eq(collated_batch["is_random"], torch.tensor([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]))) + + +def test_bert_padding_collate_fn_with_max_length_truncates(): + # Create sample data + sample1 = { + "text": torch.tensor([1, 2, 3]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, False]), + "labels": torch.tensor([7, 8, 9]), + "loss_mask": torch.tensor([True, False, True]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + sample2 = { + "text": torch.tensor([4, 5, 6, 7, 8]), + "types": torch.zeros((5,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, True, True, True]), + "labels": torch.tensor([-1, 5, -1, 7, 8]), + "loss_mask": torch.tensor([False, True, False, True, True]), + "is_random": torch.zeros((5,), dtype=torch.int64), + } + batch = [sample1, sample2] + + # Call the collate_fn + collated_batch = bert_padding_collate_fn(batch, padding_value=10, max_length=4) + + # Assert the expected output + assert torch.all(torch.eq(collated_batch["text"], torch.tensor([[1, 2, 3, 10], [4, 5, 6, 7]]))) + assert torch.all(torch.eq(collated_batch["types"], torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0]]))) + assert torch.all( + torch.eq( + collated_batch["attention_mask"], torch.tensor([[True, True, False, False], [True, True, True, True]]) + ) + ) + assert torch.all(torch.eq(collated_batch["labels"], torch.tensor([[7, 8, 9, -1], [-1, 5, -1, 7]]))) + assert torch.all( + torch.eq(collated_batch["loss_mask"], torch.tensor([[True, False, True, False], [False, True, False, True]])) + ) + assert torch.all(torch.eq(collated_batch["is_random"], torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0]]))) + + +def test_bert_padding_collate_fn_with_min_length_pads_extra(): + # Create sample data + sample1 = { + "text": torch.tensor([1, 2, 3]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, True, False]), + "labels": torch.tensor([7, 8, 9]), + "loss_mask": torch.tensor([True, False, True]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + sample2 = { + "text": torch.tensor([10, 11, 12]), + "types": torch.zeros((3,), dtype=torch.int64), + "attention_mask": torch.tensor([True, False, True]), + "labels": torch.tensor([16, 17, 18]), + "loss_mask": torch.tensor([False, True, False]), + "is_random": torch.zeros((3,), dtype=torch.int64), + } + batch = [sample1, sample2] + + # Call the collate_fn + collated_batch = bert_padding_collate_fn(batch, padding_value=-1, min_length=5) + assert torch.all(torch.eq(collated_batch["text"], torch.tensor([[1, 2, 3, -1, -1], [10, 11, 12, -1, -1]]))) + for val in collated_batch.values(): + assert val.size(1) == 5 diff --git a/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_masking.py b/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_masking.py new file mode 100644 index 000000000..df0e26509 --- /dev/null +++ b/sub-packages/bionemo-llm/tests/bionemo/llm/data/test_masking.py @@ -0,0 +1,245 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch + +from bionemo.llm.data.masking import BertMaskConfig, add_cls_and_eos_tokens, apply_bert_pretraining_mask + + +def test_bert_mask_config_raises_with_invalid_probabilities(): + with pytest.raises(ValueError): + BertMaskConfig(mask_token=1, random_tokens=range(2, 4), mask_token_prob=0.9, random_token_prob=0.2) + + +def test_apply_bert_pretraining_mask(): + # fmt: off + tokenized_sequence = torch.tensor( + [20, 15, 11, 7, 10, 16, 9, 10, 4, 15, 8, 12, 7, 10, 12, 4, 9, + 10, 8, 15, 9, 14, 7, 8, 6, 5, 16, 4, 5, 9, 9, 4, 8, 7, + 8, 10, 16, 7, 12, 7, 16, 13, 12, 5, 19, 4, 10, 8, 4, 6, 19, + 17, 12, 7, 5, 11, 14, 10, 6, 19, 7, 4, 5, 6, 6]) + # fmt: on + + random_seed = 123 + + # Apply the function + masked_sequence, labels, loss_mask = apply_bert_pretraining_mask( + tokenized_sequence, + random_seed, + mask_config=BertMaskConfig(mask_token=32, random_tokens=range(4, 24)), + ) + + # Check the unmasked tokens are unchanged. + assert torch.allclose(masked_sequence[~loss_mask], tokenized_sequence[~loss_mask]) + + # Make sure the output labels are correct. + assert torch.allclose(labels[loss_mask], tokenized_sequence[loss_mask]) + + values, _ = torch.mode(masked_sequence[loss_mask]) + assert values.item() == 32 + + +def test_apply_bert_pretraining_mask_no_mask_token(): + # fmt: off + tokenized_sequence = torch.tensor( + [20, 15, 11, 7, 10, 16, 9, 10, 4, 15, 8, 12, 7, 10, 12, 4, 9, + 10, 8, 15, 9, 14, 7, 8, 6, 5, 16, 4, 5, 9, 9, 4, 8, 7, + 8, 10, 16, 7, 12, 7, 16, 13, 12, 5, 19, 4, 10, 8, 4, 6, 19, + 17, 12, 7, 5, 11, 14, 10, 6, 19, 7, 4, 5, 6, 6]) + # fmt: on + + random_seed = 123 + + # Apply the function + masked_sequence, labels, loss_mask = apply_bert_pretraining_mask( + tokenized_sequence, + random_seed, + mask_config=BertMaskConfig(mask_token_prob=0.0, mask_token=32, random_tokens=range(4, 24)), + ) + + # Check the unmasked tokens are unchanged. + assert torch.allclose(masked_sequence[~loss_mask], tokenized_sequence[~loss_mask]) + + # Make sure the output labels are correct. + assert torch.allclose(labels[loss_mask], tokenized_sequence[loss_mask]) + + # Make sure no mask tokens are in the output sequence + assert torch.all(masked_sequence != 32) + + +def test_apply_bert_pretraining_mask_changing_mask_prob(): + # fmt: off + tokenized_sequence = torch.tensor( + [20, 15, 11, 7, 10, 16, 9, 10, 4, 15, 8, 12, 7, 10, 12, 4, 9, + 10, 8, 15, 9, 14, 7, 8, 6, 5, 16, 4, 5, 9, 9, 4, 8, 7, + 8, 10, 16, 7, 12, 7, 16, 13, 12, 5, 19, 4, 10, 8, 4, 6, 19, + 17, 12, 7, 5, 11, 14, 10, 6, 19, 7, 4, 5, 6, 6]) + # fmt: on + + random_seed = 123 + + # Apply the function + masked_sequence, labels, loss_mask = apply_bert_pretraining_mask( + tokenized_sequence, + random_seed, + mask_config=BertMaskConfig(mask_prob=0.0, mask_token=32, random_tokens=range(4, 24)), + ) + + # All mask values should be False. 
+ assert torch.all(~loss_mask) + + +def test_apply_bert_pretraining_mask_converges_to_correct_probability(): + sequence = torch.ones(100_000, dtype=torch.long) + random_seed = 123 + + masked_sequence, _, loss_mask = apply_bert_pretraining_mask( + sequence, + random_seed, + mask_config=BertMaskConfig( + mask_token=2, random_tokens=range(3, 5), mask_prob=0.5, mask_token_prob=0.25, random_token_prob=0.12 + ), + ) + + # Check that overall masking probability is correct. + assert pytest.approx(loss_mask.float().mean(), abs=0.01) == 0.5 + + # Check that the distribution of masked tokens is correct. + assert pytest.approx((masked_sequence == 2).float().mean(), abs=0.01) == 0.5 * 0.25 + + # Check that the distribution of random tokens is correct. + assert ( + pytest.approx(torch.logical_or(masked_sequence == 3, masked_sequence == 4).float().mean(), abs=0.01) + == 0.5 * 0.12 + ) + + # Check that the distribution of unmasked tokens is correct. + assert pytest.approx((masked_sequence[loss_mask] == 1).float().mean(), abs=0.01) == 1.0 - (0.25 + 0.12) + + +def test_apply_bert_pretraining_mask_is_reproducible_with_same_seed(): + torch.manual_seed(42) + tokenized_sequence = torch.randint(0, 100, (1000,)) + + # Apply the function + masked_sequence, labels, loss_mask = apply_bert_pretraining_mask( + tokenized_sequence, + 123, + mask_config=BertMaskConfig(mask_prob=0.5, mask_token=32, random_tokens=range(4, 24)), + ) + + for _ in range(10): + new_seq, new_labels, new_mask = apply_bert_pretraining_mask( + tokenized_sequence, + 123, + mask_config=BertMaskConfig(mask_prob=0.5, mask_token=32, random_tokens=range(4, 24)), + ) + + assert torch.allclose(masked_sequence, new_seq) + assert torch.allclose(labels, new_labels) + assert torch.allclose(loss_mask, new_mask) + + +def test_apply_bert_pretraining_mask_changes_with_new_seed(): + torch.manual_seed(42) + tokenized_sequence = torch.randint(0, 100, (1000,)) + + # Apply the function + masked_sequence, labels, loss_mask = apply_bert_pretraining_mask( + tokenized_sequence, + 123, + mask_config=BertMaskConfig(mask_prob=0.5, mask_token=32, random_tokens=range(4, 24)), + ) + + new_seq, new_labels, new_mask = apply_bert_pretraining_mask( + tokenized_sequence, + 321, + mask_config=BertMaskConfig(mask_prob=0.5, mask_token=32, random_tokens=range(4, 24)), + ) + + assert not torch.allclose(masked_sequence, new_seq) + assert not torch.allclose(labels, new_labels) + assert not torch.allclose(loss_mask, new_mask) + + +def test_add_cls_and_eos_tokens_both_tokens(): + sequence = torch.tensor([1, 2, 3]) + loss_mask = torch.tensor([False, True, False]) + labels = torch.tensor([-1, 2, -1]) + + augmented_sequence, augmented_labels, augmented_loss_mask = add_cls_and_eos_tokens( + sequence, labels, loss_mask, cls_token=0, eos_token=4 + ) + + assert len(augmented_sequence) == len(sequence) + 2 + assert augmented_sequence[0] == 0 + assert torch.allclose(augmented_sequence[1:-1], sequence) + assert augmented_sequence[-1] == 4 + + assert len(augmented_loss_mask) == len(loss_mask) + 2 + assert not augmented_loss_mask[0] + assert torch.allclose(augmented_loss_mask[1:-1], loss_mask) + assert not augmented_loss_mask[-1] + + assert len(augmented_labels) == len(labels) + 2 + assert augmented_labels[0] == -1 + assert torch.allclose(augmented_labels[1:-1], labels) + assert augmented_labels[-1] == -1 + + +def test_add_cls_and_eos_tokens_only_cls(): + sequence = torch.tensor([1, 2, 3]) + loss_mask = torch.tensor([False, True, False]) + labels = torch.tensor([-1, 2, -1]) + + augmented_sequence, 
augmented_labels, augmented_loss_mask = add_cls_and_eos_tokens( + sequence, labels, loss_mask, cls_token=0, eos_token=None + ) + + assert len(augmented_sequence) == len(sequence) + 1 + assert augmented_sequence[0] == 0 + assert torch.allclose(augmented_sequence[1:], sequence) + + assert len(augmented_loss_mask) == len(loss_mask) + 1 + assert not augmented_loss_mask[0] + assert torch.allclose(augmented_loss_mask[1:], loss_mask) + + assert len(augmented_labels) == len(labels) + 1 + assert augmented_labels[0] == -1 + assert torch.allclose(augmented_labels[1:], labels) + + +def test_add_cls_and_eos_tokens_only_eos(): + sequence = torch.tensor([1, 2, 3]) + loss_mask = torch.tensor([False, True, False]) + labels = torch.tensor([-1, 2, -1]) + + augmented_sequence, augmented_labels, augmented_loss_mask = add_cls_and_eos_tokens( + sequence, labels, loss_mask, cls_token=None, eos_token=4 + ) + + assert len(augmented_sequence) == len(sequence) + 1 + assert torch.allclose(augmented_sequence[:-1], sequence) + assert augmented_sequence[-1] == 4 + + assert len(augmented_loss_mask) == len(loss_mask) + 1 + assert torch.allclose(augmented_loss_mask[:-1], loss_mask) + assert not augmented_loss_mask[-1] + + assert len(augmented_labels) == len(labels) + 1 + assert torch.allclose(augmented_labels[:-1], labels) + assert augmented_labels[-1] == -1
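Finally, a short sketch of how bert_padding_collate_fn is meant to be wired into a DataLoader, mirroring the datamodule change above; the toy dataset and the pad id are placeholders, not part of the patch.

import functools

import torch
from torch.utils.data import DataLoader, Dataset

from bionemo.llm.data.collate import bert_padding_collate_fn


class ToyBertDataset(Dataset):
    # Placeholder dataset yielding BertSample-style dicts of varying lengths.
    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int):
        n = 3 + idx
        return {
            "text": torch.randint(5, 20, (n,)),
            "types": torch.zeros(n, dtype=torch.int64),
            "attention_mask": torch.ones(n, dtype=torch.int64),
            "labels": torch.full((n,), -1),
            "loss_mask": torch.zeros(n, dtype=torch.bool),
            "is_random": torch.zeros(n, dtype=torch.int64),
        }


loader = DataLoader(
    ToyBertDataset(),
    batch_size=2,
    collate_fn=functools.partial(
        bert_padding_collate_fn,
        padding_value=0,  # placeholder pad token id
        min_length=None,
        max_length=8,
    ),
)
batch = next(iter(loader))
assert batch["text"].shape == (2, 4)  # first batch holds lengths 3 and 4, padded to the longest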