adding some additional docstrings #81

Merged: 3 commits, merged Aug 7, 2024
28 changes: 16 additions & 12 deletions sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py
@@ -21,23 +21,27 @@
from torch.utils.data import Dataset


class PRNGDatasetShuffler(Dataset): # noqa: D101
def __init__(self, dataset: Dataset, seed: int = 42, num_samples: Optional[int] = None):
"""Initializes the PRNGDatasetShuffler. PRNGDatasetShuffler shuffles a given dataset using a pseudo-random number generator (PRNG).
This allows for reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory.
It works by generating random indices assuming that the requesting function asks for them sequentially.
Although random lookups are supported, random lookups will involve recomputing state which is slow, and involves
linearly advancing from 0 if the last requested index was greater than or equal to this requested index.
This should work well with the megatron sampler which is sequential. It handles
skipped lookups as will happen with multiple workers by not generating those numbers.
class PRNGDatasetShuffler(Dataset):
"""A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset.

PRNGDatasetShuffler shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for
reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. It
works by generating random indices assuming that the requesting function asks for them sequentially. Although random
lookups are supported, random lookups will involve recomputing state which is slow, and involves linearly advancing
from 0 if the last requested index was greater than or equal to this requested index. This should work well with the
megatron sampler which is sequential. It handles skipped lookups as will happen with multiple workers by not
generating those numbers.
"""

def __init__(self, dataset: Dataset, seed: int = 42, num_samples: Optional[int] = None):
"""Initializes the PRNGDatasetShuffler.

Args:
dataset (Dataset): The dataset to be shuffled.
seed (int, optional): The seed value for the PRNG. Default is 42.
num_samples (Optional[int], optional): The number of samples to draw from the dataset.
If None, the length of the dataset is used. Default is None.
""" # noqa: D205
"""
self.initial_seed = seed
self.rng = random.Random(seed)
self.dataset_len = len(dataset)
@@ -54,11 +54,11 @@ def rand_idx(self) -> int:
"""Generates a random index within the range of the dataset size."""
return self.rng.randint(0, self.dataset_len - 1)

def advance_state(self, num_to_advance: int): # noqa: D417
def advance_state(self, num_to_advance: int):
"""Advances the PRNG state by generating n_to_advance random indices.

Args:
n_to_advance (int): The number of random state steps to advance.
num_to_advance: The number of random state steps to advance.
"""
for _ in range(num_to_advance):
self.rand_idx()
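For context on the behavior described in the new class docstring, here is a minimal usage sketch. It is not part of this diff; the toy RangeDataset and the sequential access pattern are illustrative assumptions, and it assumes PRNGDatasetShuffler exposes __len__ and __getitem__ as the surrounding code implies.

# Illustrative usage of PRNGDatasetShuffler (not part of this PR). Assumes the
# class exposes __len__ and __getitem__ as the surrounding code implies.
from torch.utils.data import Dataset

from bionemo.core.data.resamplers import PRNGDatasetShuffler


class RangeDataset(Dataset):
    """Toy stand-in dataset that simply returns its index."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        return idx


shuffled = PRNGDatasetShuffler(RangeDataset(10), seed=42, num_samples=20)
# Sequential access is the intended fast path; the same seed reproduces the same order.
epoch_a = [shuffled[i] for i in range(len(shuffled))]
rerun = PRNGDatasetShuffler(RangeDataset(10), seed=42, num_samples=20)
epoch_b = [rerun[i] for i in range(len(rerun))]
assert epoch_a == epoch_b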
14 changes: 10 additions & 4 deletions sub-packages/bionemo-core/src/bionemo/core/model/config.py
@@ -29,13 +29,19 @@
Model = TypeVar("Model")


class BionemoModelConfig(Generic[Model], ABC): # D101 # noqa: D101
class BionemoModelConfig(Generic[Model], ABC):
"""An abstract class for model configuration."""

@abstractmethod
def configure_model(self, *args, **kwargs) -> Model: # D101 # noqa: D102
def configure_model(self, *args, **kwargs) -> Model:
"""Configures the model."""
raise NotImplementedError()


class BionemoTrainableModelConfig(Generic[Model, Loss], BionemoModelConfig[Model]): # D101 # noqa: D101
class BionemoTrainableModelConfig(Generic[Model, Loss], BionemoModelConfig[Model]):
"""An abstract class for trainable model configuration."""

@abstractmethod
def get_loss_reduction_class(self) -> Type[Loss]: # D101 # noqa: D102
def get_loss_reduction_class(self) -> Type[Loss]:
"""Returns the loss reduction class."""
Comment on lines +32 to +46

Collaborator: Is this ChatGPT output @pstjohn ?

Collaborator (Author): Yeah copilot most likely

raise NotImplementedError()
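To illustrate how these abstract configs are meant to be subclassed, here is a hypothetical concrete config. LinearModelConfig and MSELossReduction are made-up names for this sketch, not bionemo classes.

# Hypothetical subclass sketch; LinearModelConfig and MSELossReduction are
# illustrative stand-ins, not part of bionemo.
from dataclasses import dataclass
from typing import Type

import torch

from bionemo.core.model.config import BionemoTrainableModelConfig


class MSELossReduction:
    """Toy loss reduction used only for this sketch."""

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.mse_loss(pred, target)


@dataclass
class LinearModelConfig(BionemoTrainableModelConfig[torch.nn.Linear, MSELossReduction]):
    """Concrete config: configure_model builds the model, get_loss_reduction_class names its loss."""

    in_features: int = 8
    out_features: int = 1

    def configure_model(self, *args, **kwargs) -> torch.nn.Linear:
        return torch.nn.Linear(self.in_features, self.out_features)

    def get_loss_reduction_class(self) -> Type[MSELossReduction]:
        return MSELossReduction


model = LinearModelConfig(in_features=4).configure_model()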
@@ -30,8 +30,7 @@ def pad_token_ids(
pad_size_divisible_by: int = 1,
**convert_to_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Pads token ids with padding value, and return the padded tokens and
the corresponding mask.
"""Pads token ids with padding value, and return the padded tokens and the corresponding mask.

Args:
token_ids: List of token ids or tensors
@@ -42,7 +41,7 @@

Returns:
Tuple[List[int], List[int]]: Padded token ids and mask
""" # noqa: D205
"""
lengths = torch.tensor([len(s) for s in token_ids])
if padding_len is None:
padding_len = lengths.max()
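Since the full signature and module path of pad_token_ids are not visible in this diff, here is a self-contained sketch of the padding-plus-mask behavior the docstring describes, written in plain torch rather than calling the bionemo function.

# Self-contained illustration of "pad token ids and return a mask"; this mirrors
# the documented behavior above but is not the bionemo implementation.
from typing import List, Tuple

import torch


def pad_with_mask(
    token_ids: List[torch.Tensor], padding_value: int = 0, pad_size_divisible_by: int = 1
) -> Tuple[torch.Tensor, torch.Tensor]:
    lengths = torch.tensor([len(s) for s in token_ids])
    padding_len = int(lengths.max())
    # Round the padded length up so it is divisible by pad_size_divisible_by.
    if padding_len % pad_size_divisible_by != 0:
        padding_len += pad_size_divisible_by - (padding_len % pad_size_divisible_by)
    padded = torch.full((len(token_ids), padding_len), padding_value, dtype=torch.long)
    mask = torch.zeros((len(token_ids), padding_len), dtype=torch.long)
    for i, seq in enumerate(token_ids):
        padded[i, : len(seq)] = seq
        mask[i, : len(seq)] = 1
    return padded, mask


padded, mask = pad_with_mask([torch.tensor([5, 6, 7]), torch.tensor([8, 9])], pad_size_divisible_by=4)
# padded -> [[5, 6, 7, 0], [8, 9, 0, 0]]; mask -> [[1, 1, 1, 0], [1, 1, 0, 0]]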
13 changes: 12 additions & 1 deletion sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py
@@ -27,7 +27,18 @@
PrecisionTypes = Literal["fp16", "bf16", "fp32", "bf16-mixed", "fp32-mixed", "16-mixed", "fp16-mixed", 16, 32]


def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype: # noqa: D103
def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype:
"""Returns the torch dtype corresponding to the given precision.

Args:
precision: The precision type.

Returns:
torch.dtype: The torch dtype corresponding to the given precision.

Raises:
ValueError: If the precision is not supported.
"""
# TODO move this to a utilities folder, or find/import the function that does this in NeMo
if precision == "fp16":
return torch.float16
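Only the "fp16" branch of get_autocast_dtype is visible above. The sketch below shows what the overall precision-to-dtype mapping plausibly looks like; every branch other than "fp16" is an assumption for illustration, not the actual bionemo code.

# Sketch of a precision-to-dtype mapping in the spirit of get_autocast_dtype; only
# the "fp16" branch appears in the diff, the remaining branches are assumed.
import torch


def autocast_dtype_sketch(precision) -> torch.dtype:
    if precision in ("fp16", "fp16-mixed", "16-mixed", 16):
        return torch.float16
    if precision in ("bf16", "bf16-mixed"):
        return torch.bfloat16
    if precision in ("fp32", "fp32-mixed", 32):
        return torch.float32
    raise ValueError(f"Precision {precision} is not supported.")


assert autocast_dtype_sketch("bf16") is torch.bfloat16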
12 changes: 6 additions & 6 deletions sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py
@@ -23,11 +23,11 @@

@contextmanager
def random_numpy_context(seed: int = 42) -> Iterator[None]:
"""Context manager for setting numpy random state, where the state is saved on entry
and restored on exit to what it was. This way you can run code that needs random
state in a `with` context using this function, and get back to whatever state was
there before. This is useful for testing where you don't want the random state from
one test to impact other tests.
"""Context manager for setting numpy random state.

The state is saved on entry and restored on exit to what it was. This way you can run code that needs random state
in a `with` context using this function, and get back to whatever state was there before. This is useful for testing
where you don't want the random state from one test to impact other tests.

Example:
>>> import numpy as np
@@ -37,7 +37,7 @@ def random_numpy_context(seed: int = 42) -> Iterator[None]:
np.random.randint(5) # this will change the state
>>> new_state = np.random.get_state()
>>> assert ori_state == new_state
""" # noqa: D205
"""
state = np.random.get_state() # just fail if this fails
try:
np.random.seed(seed)
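The function body is truncated above. Below is a sketch of the usual save/seed/restore shape of such a context manager together with a usage example; the finally-restore step is an assumption about how the truncated code ends, not a quote of it.

# Sketch of the save / seed / restore pattern the truncated function follows,
# plus a usage example; the `finally` restore is assumed, not shown in the diff.
from contextlib import contextmanager
from typing import Iterator

import numpy as np


@contextmanager
def numpy_seed_context(seed: int = 42) -> Iterator[None]:
    state = np.random.get_state()  # save whatever state is active on entry
    try:
        np.random.seed(seed)
        yield
    finally:
        np.random.set_state(state)  # put the caller's state back on exit


np.random.seed(0)
keys_before = np.random.get_state()[1].copy()
with numpy_seed_context(seed=42):
    np.random.randint(5)  # consumes random state only inside the context
keys_after = np.random.get_state()[1]
assert (keys_before == keys_after).all()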
31 changes: 27 additions & 4 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py
@@ -30,7 +30,7 @@


class ESM2DotProductAttention(DotProductAttention):
"""ESM2-Specific core attention
"""ESM2-Specific core attention.

Region where selective activation recomputation is applied.
This region is memory intensive but less compute intensive which
@@ -44,16 +44,25 @@ class ESM2DotProductAttention(DotProductAttention):
p: number of tensor model parallel partitions
b: batch size
s: sequence length
""" # noqa: D415
"""

def __init__( # noqa: D107
def __init__(
self,
config: TransformerConfig,
layer_number: int,
attn_mask_type: AttnMaskType,
attention_type: str,
attention_dropout: Optional[float] = None,
) -> None:
"""Initializes the Attention class.

Args:
config: The configuration object for the transformer.
layer_number: The layer number of the attention module.
attn_mask_type: The type of attention mask to be used.
attention_type: The type of attention mechanism.
attention_dropout: The dropout rate for attention weights. Defaults to None.
"""
super().__init__(
config=config,
layer_number=layer_number,
@@ -65,7 +74,7 @@ def __init__(  # noqa: D107
self.use_esm_attention = config.use_esm_attention
self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32

def forward( # noqa: D102
def forward(
self,
query: Tensor,
key: Tensor,
@@ -74,6 +83,20 @@
attn_mask_type: Optional[AttnMaskType] = None,
packed_seq_params: Optional[PackedSeqParams] = None,
):
"""Forward pass of the ESM2DotProductAttention module.

Args:
query: The query tensor of shape [sq, b, np, hn].
key: The key tensor of shape [sk, b, ng, hn].
value: The value tensor of shape [sk, b, ng, hn].
attention_mask: The attention mask tensor of shape [b, np, sq, sk].
attn_mask_type: The attention mask type, currently unused. Defaults to None.
packed_seq_params: The packed sequence parameters. These are used for context parallelism so will be needed
to be implemented if we want to support this. Defaults to None.

Returns:
Tensor: The context tensor of shape [sq, b, hp].
"""
if packed_seq_params is not None:
raise ValueError(
"Packed sequence is not supported by DotProductAttention. " "Please use TEDotProductAttention instead."
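To make the shape notation in the new forward docstring concrete, the snippet below just builds dummy tensors with the documented layouts; it does not construct or call the Megatron attention module, and the concrete sizes are arbitrary.

# Shape illustration only: dummy tensors laid out as the forward() docstring
# documents them; this does not run ESM2DotProductAttention.
import torch

sq, sk = 128, 128   # query / key sequence lengths
b = 2               # batch size
heads = 8           # attention heads per tensor-parallel partition ("np")
groups = 8          # query groups per partition ("ng")
hn = 64             # hidden size per attention head

query = torch.empty(sq, b, heads, hn)
key = torch.empty(sk, b, groups, hn)
value = torch.empty(sk, b, groups, hn)
attention_mask = torch.empty(b, heads, sq, sk, dtype=torch.bool)
# Documented output layout: [sq, b, hp], with hp the per-partition hidden size (heads * hn here).
expected_context_shape = (sq, b, heads * hn)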
9 changes: 5 additions & 4 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py
@@ -32,9 +32,9 @@


class ESM2Embedding(LanguageModelEmbedding):
"""ESM2 Embedding with custom logic for attention masking and token dropout""" # noqa: D415
"""ESM2 Embedding with custom logic for attention masking and token dropout."""

def __init__( # noqa: D107
def __init__(
self,
config: TransformerConfig,
vocab_size: int,
@@ -46,6 +46,7 @@ def __init__(  # noqa: D107
use_attention_mask: bool = True,
mask_token_id: Optional[int] = torch.nan,
) -> None:
"""Initialize the ESM2 Embedding module."""
super().__init__(
config=config,
vocab_size=vocab_size,
@@ -60,7 +61,7 @@ def _apply_esm2_customization(
def _apply_esm2_customization(
self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Tensor
) -> Tuple[Tensor, Tensor]:
"""ESM2 customization for attention masking and token dropout
"""ESM2 customization for attention masking and token dropout.

Args:
word_embeddings (Tensor[float]): The input tokens. Shape: [b, s, h]
@@ -69,7 +70,7 @@ def _apply_esm2_customization(

Returns:
Tuple[Tensor, Tensor]: (Updated embeddings, embedding mask) Shape: ([b, s, h], [b, s])
""" # noqa: D415
"""
embeddings_mask = None
if attention_mask is not None and (self.token_dropout or self.use_attention_mask):
embeddings_mask = attention_mask
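As a generic illustration of the attention-mask half of the customization described above (the ESM2 token-dropout math is truncated in this diff and is not reproduced here), this is how a [b, s] embedding mask is typically broadcast over [b, s, h] embeddings:

# Generic masking illustration; the ESM2-specific token-dropout rescaling is not
# visible in this diff and is intentionally omitted.
import torch

b, s, h = 2, 6, 16
word_embeddings = torch.randn(b, s, h)
attention_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0],
     [1, 1, 1, 0, 0, 0]], dtype=torch.bool
)

# Broadcasting the [b, s] mask over the hidden dimension zeroes padded positions.
masked = word_embeddings * attention_mask.unsqueeze(-1)
assert masked.shape == (b, s, h)
assert torch.all(masked[0, 4:] == 0)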
25 changes: 18 additions & 7 deletions sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py
@@ -30,14 +30,18 @@
)


class SchedulerOutput(TypedDict): # noqa: D101
class SchedulerOutput(TypedDict):
"""Output of the scheduler method."""

optimizer: MegatronOptimizerModule
lr_scheduler: dict
monitor: str


class WarmupAnnealDecayHold(_LRScheduler): # noqa: D101
def __init__( # noqa: D417
class WarmupAnnealDecayHold(_LRScheduler):
"""Warmup Anneal Decay Hold learning rate scheduler."""

def __init__(
self,
optimizer: MegatronOptimizerModule,
*,
@@ -51,9 +55,13 @@ def __init__(  # noqa: D417
"""Initializes the WarmupAnnealDecayHold learning rate scheduler.

Args:
max_steps (int): Total number of training steps.
optimizer: Optimizer to apply the learning rate scheduler.
warmup_steps (int): Number of steps for the linear warm-up.
max_steps (int): Total number of training steps.
max_lr (float): Peak learning rate to be achieved after warm-up.
min_lr (float): Minimum learning rate.
anneal_percentage (float): Percentage of the max_lr to hold after decay.
last_epoch (int): The index of the last epoch.
"""
self.warmup_steps = warmup_steps
self.max_steps = max_steps
@@ -67,7 +75,7 @@

super(WarmupAnnealDecayHold, self).__init__(optimizer, last_epoch)

def get_lr(self) -> List[float]: # noqa: D102
def get_lr(self) -> List[float]:
"""Get the learning rate at the current step."""
step_num = self.last_epoch
if step_num < self.warmup_steps:
lr = self.min_lr + (self.max_lr - self.min_lr) * step_num / self.warmup_steps
@@ -82,7 +91,7 @@ def get_lr(self) -> List[float]:  # noqa: D102
class WarmupAnnealDecayHoldScheduler(LRSchedulerModule):
"""Warmup Policy Learning Rate Scheduler."""

def __init__( # noqa: D107
def __init__(
self,
warmup_steps: int = 2000,
max_steps: int = 500_000,
@@ -93,6 +102,7 @@ def __init__(  # noqa: D107
frequency: int = 1,
monitor: str = "val_loss",
) -> None:
"""Initializes the WarmupAnnealDecayHoldScheduler."""
super().__init__()
self.warmup_steps = warmup_steps
self.max_steps = max_steps
@@ -103,7 +113,8 @@ def __init__(  # noqa: D107
self.frequency = frequency
self.monitor = monitor

def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput: # noqa: D102
def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput:
"""Returns the scheduler output."""
lr_scheduler = WarmupAnnealDecayHold(
optimizer,
warmup_steps=self.warmup_steps,
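Only the linear warm-up branch of get_lr appears in the diff. The standalone sketch below shows the overall warm-up, decay, and hold shape that the class and argument names suggest; the post-warm-up decay formula and the numeric values are assumptions for illustration, not the bionemo implementation.

# Standalone sketch of a warmup -> decay -> hold schedule. The warm-up branch
# matches the line shown above; the decay after warm-up is assumed.
def sketch_lr(step: int, *, warmup_steps: int, max_steps: int, max_lr: float,
              min_lr: float, anneal_percentage: float) -> float:
    hold_lr = max_lr * anneal_percentage
    if step < warmup_steps:
        # Linear warm-up, as in the visible branch of get_lr.
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    if step < max_steps:
        # Assumed: linear decay from max_lr down to the hold value.
        frac = (step - warmup_steps) / max(1, max_steps - warmup_steps)
        return max_lr - (max_lr - hold_lr) * frac
    return hold_lr  # hold after max_steps


points = [sketch_lr(s, warmup_steps=2000, max_steps=500_000, max_lr=4e-4,
                    min_lr=4e-5, anneal_percentage=0.10) for s in (0, 2_000, 500_000)]
# points -> [4e-05, 0.0004, 4e-05]: start of warm-up, peak after warm-up, held value.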