diff --git a/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py b/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py index 10f46ff1b..e9848eb9d 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py +++ b/sub-packages/bionemo-core/src/bionemo/core/data/resamplers.py @@ -21,23 +21,27 @@ from torch.utils.data import Dataset -class PRNGDatasetShuffler(Dataset): # noqa: D101 - def __init__(self, dataset: Dataset, seed: int = 42, num_samples: Optional[int] = None): - """Initializes the PRNGDatasetShuffler. PRNGDatasetShuffler shuffles a given dataset using a pseudo-random number generator (PRNG). - This allows for reproducible shuffling by controlling the random seed, while not ever storing the list of indices in memory. - It works by generating random indices assuming that the requesting function asks for them sequentially. - Although random lookups are supported, random lookups will involve recomputing state which is slow, and involves - linearly advancing from 0 if the last requested index was greater than or equal to this requested index. - This should work well with the megatron sampler which is sequential. It handles - skipped lookups as will happen with multiple workers by not generating those numbers. +class PRNGDatasetShuffler(Dataset): + """A thread-safe dataset shuffler that uses a pseudo-random number generator (PRNG) to shuffle the dataset. + + PRNGDatasetShuffler shuffles a given dataset using a pseudo-random number generator (PRNG). This allows for + reproducible shuffling by controlling the random seed, without ever storing the list of indices in memory. It + works by generating random indices assuming that the requesting function asks for them sequentially. Although random + lookups are supported, they involve recomputing state, which is slow and involves linearly advancing + from 0 if the last requested index was greater than or equal to this requested index. This should work well with the + megatron sampler, which is sequential. It handles skipped lookups, as will happen with multiple workers, by not + generating those numbers. + """ + def __init__(self, dataset: Dataset, seed: int = 42, num_samples: Optional[int] = None): + """Initializes the PRNGDatasetShuffler. Args: dataset (Dataset): The dataset to be shuffled. seed (int, optional): The seed value for the PRNG. Default is 42. num_samples (Optional[int], optional): The number of samples to draw from the dataset. If None, the length of the dataset is used. Default is None. - """ # noqa: D205 + """ self.initial_seed = seed self.rng = random.Random(seed) self.dataset_len = len(dataset) @@ -54,11 +58,11 @@ def rand_idx(self) -> int: """Generates a random index within the range of the dataset size.""" return self.rng.randint(0, self.dataset_len - 1) - def advance_state(self, num_to_advance: int): # noqa: D417 + def advance_state(self, num_to_advance: int): """Advances the PRNG state by generating n_to_advance random indices. Args: - n_to_advance (int): The number of random state steps to advance. + num_to_advance: The number of random state steps to advance.
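To make the access pattern documented above concrete, the sketch below re-implements the idea in plain Python. It is illustrative only (the name TinyPRNGShuffler and its details are not part of this diff, and the real class's handling of skipped indices may differ): sequential reads cost one PRNG draw per requested index, and a lookup at or before the last served index forces a re-seed and a replay from zero.

```python
import random
from typing import Optional, Sequence


class TinyPRNGShuffler:
    """Minimal sketch of the sequential-PRNG shuffling idea (illustrative, not the bionemo class)."""

    def __init__(self, data: Sequence, seed: int = 42, num_samples: Optional[int] = None):
        self.data = data
        self.seed = seed
        self.rng = random.Random(seed)
        self.num_samples = num_samples if num_samples is not None else len(data)
        self.last_index = -1  # index of the last sample whose random draw was generated

    def _rand_idx(self) -> int:
        # One PRNG draw maps one logical index to one position in the underlying data.
        return self.rng.randint(0, len(self.data) - 1)

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, index: int):
        if index <= self.last_index:
            # Backward (or repeated) lookup: rebuild the PRNG and replay from zero (slow path).
            self.rng = random.Random(self.seed)
            self.last_index = -1
        # Advance over any skipped indices so the draw for `index` stays deterministic.
        for _ in range(index - self.last_index - 1):
            self._rand_idx()
        self.last_index = index
        return self.data[self._rand_idx()]


shuffled = TinyPRNGShuffler(list(range(100)), seed=7)
print([shuffled[i] for i in range(5)])  # cheap sequential access
print(shuffled[2])                      # random lookup: re-seeds and replays from 0
```

The trade-off is constant memory in exchange for a worst-case linear replay on backward lookups, which suits the sequential megatron sampler mentioned above.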
""" for _ in range(num_to_advance): self.rand_idx() diff --git a/sub-packages/bionemo-core/src/bionemo/core/model/config.py b/sub-packages/bionemo-core/src/bionemo/core/model/config.py index 3267b72ad..46d5b1336 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/model/config.py +++ b/sub-packages/bionemo-core/src/bionemo/core/model/config.py @@ -29,13 +29,19 @@ Model = TypeVar("Model") -class BionemoModelConfig(Generic[Model], ABC): # D101 # noqa: D101 +class BionemoModelConfig(Generic[Model], ABC): + """An abstract class for model configuration.""" + @abstractmethod - def configure_model(self, *args, **kwargs) -> Model: # D101 # noqa: D102 + def configure_model(self, *args, **kwargs) -> Model: + """Configures the model.""" raise NotImplementedError() -class BionemoTrainableModelConfig(Generic[Model, Loss], BionemoModelConfig[Model]): # D101 # noqa: D101 +class BionemoTrainableModelConfig(Generic[Model, Loss], BionemoModelConfig[Model]): + """An abstract class for trainable model configuration.""" + @abstractmethod - def get_loss_reduction_class(self) -> Type[Loss]: # D101 # noqa: D102 + def get_loss_reduction_class(self) -> Type[Loss]: + """Returns the loss reduction class.""" raise NotImplementedError() diff --git a/sub-packages/bionemo-core/src/bionemo/core/utils/batching_utils.py b/sub-packages/bionemo-core/src/bionemo/core/utils/batching_utils.py index 004a160a8..08325ea94 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/utils/batching_utils.py +++ b/sub-packages/bionemo-core/src/bionemo/core/utils/batching_utils.py @@ -30,8 +30,7 @@ def pad_token_ids( pad_size_divisible_by: int = 1, **convert_to_kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - """Pads token ids with padding value, and return the padded tokens and - the corresponding mask. + """Pads token ids with padding value, and return the padded tokens and the corresponding mask. Args: token_ids: List of token ids or tensors @@ -42,7 +41,7 @@ def pad_token_ids( Returns: Tuple[List[int], List[int]]: Padded token ids and mask - """ # noqa: D205 + """ lengths = torch.tensor([len(s) for s in token_ids]) if padding_len is None: padding_len = lengths.max() diff --git a/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py b/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py index 44faeea73..9520e62f3 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py +++ b/sub-packages/bionemo-core/src/bionemo/core/utils/dtypes.py @@ -27,7 +27,18 @@ PrecisionTypes = Literal["fp16", "bf16", "fp32", "bf16-mixed", "fp32-mixed", "16-mixed", "fp16-mixed", 16, 32] -def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype: # noqa: D103 +def get_autocast_dtype(precision: PrecisionTypes) -> torch.dtype: + """Returns the torch dtype corresponding to the given precision. + + Args: + precision: The precision type. + + Returns: + torch.dtype: The torch dtype corresponding to the given precision. + + Raises: + ValueError: If the precision is not supported. 
+ """ # TODO move this to a utilities folder, or find/import the function that does this in NeMo if precision == "fp16": return torch.float16 diff --git a/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py b/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py index 8392ca0a0..1c29c0a83 100644 --- a/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py +++ b/sub-packages/bionemo-core/src/bionemo/core/utils/random_utils.py @@ -23,11 +23,11 @@ @contextmanager def random_numpy_context(seed: int = 42) -> Iterator[None]: - """Context manager for setting numpy random state, where the state is saved on entry - and restored on exit to what it was. This way you can run code that needs random - state in a `with` context using this function, and get back to whatever state was - there before. This is useful for testing where you don't want the random state from - one test to impact other tests. + """Context manager for setting numpy random state. + + The state is saved on entry and restored on exit to what it was. This way you can run code that needs random state + in a `with` context using this function, and get back to whatever state was there before. This is useful for testing + where you don't want the random state from one test to impact other tests. Example: >>> import numpy as np @@ -37,7 +37,7 @@ def random_numpy_context(seed: int = 42) -> Iterator[None]: np.random.randint(5) # this will change the state >>> new_state = np.random.get_state() >>> assert ori_state == new_state - """ # noqa: D205 + """ state = np.random.get_state() # just fail if this fails try: np.random.seed(seed) diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py index 6b7f5ce36..3422547d6 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/attention.py @@ -30,7 +30,7 @@ class ESM2DotProductAttention(DotProductAttention): - """ESM2-Specific core attention + """ESM2-Specific core attention. Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive which @@ -44,9 +44,9 @@ class ESM2DotProductAttention(DotProductAttention): p: number of tensor model parallel partitions b: batch size s: sequence length - """ # noqa: D415 + """ - def __init__( # noqa: D107 + def __init__( self, config: TransformerConfig, layer_number: int, @@ -54,6 +54,15 @@ def __init__( # noqa: D107 attention_type: str, attention_dropout: Optional[float] = None, ) -> None: + """Initializes the Attention class. + + Args: + config: The configuration object for the transformer. + layer_number: The layer number of the attention module. + attn_mask_type: The type of attention mask to be used. + attention_type: The type of attention mechanism. + attention_dropout: The dropout rate for attention weights. Defaults to None. + """ super().__init__( config=config, layer_number=layer_number, @@ -65,7 +74,7 @@ def __init__( # noqa: D107 self.use_esm_attention = config.use_esm_attention self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - def forward( # noqa: D102 + def forward( self, query: Tensor, key: Tensor, @@ -74,6 +83,20 @@ def forward( # noqa: D102 attn_mask_type: Optional[AttnMaskType] = None, packed_seq_params: Optional[PackedSeqParams] = None, ): + """Forward pass of the ESM2DotProductAttention module. + + Args: + query: The query tensor of shape [sq, b, np, hn]. 
+ key: The key tensor of shape [sk, b, ng, hn]. + value: The value tensor of shape [sk, b, ng, hn]. + attention_mask: The attention mask tensor of shape [b, np, sq, sk]. + attn_mask_type: The attention mask type, currently unused. Defaults to None. + packed_seq_params: The packed sequence parameters. These are used for context parallelism so will be needed + to be implemented if we want to support this. Defaults to None. + + Returns: + Tensor: The context tensor of shape [sq, b, hp]. + """ if packed_seq_params is not None: raise ValueError( "Packed sequence is not supported by DotProductAttention. " "Please use TEDotProductAttention instead." diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py index dc6f50ffa..c2bfa211d 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/embedding.py @@ -32,9 +32,9 @@ class ESM2Embedding(LanguageModelEmbedding): - """ESM2 Embedding with custom logic for attention masking and token dropout""" # noqa: D415 + """ESM2 Embedding with custom logic for attention masking and token dropout.""" - def __init__( # noqa: D107 + def __init__( self, config: TransformerConfig, vocab_size: int, @@ -46,6 +46,7 @@ def __init__( # noqa: D107 use_attention_mask: bool = True, mask_token_id: Optional[int] = torch.nan, ) -> None: + """Initialize the ESM2 Embedding module.""" super().__init__( config=config, vocab_size=vocab_size, @@ -60,7 +61,7 @@ def __init__( # noqa: D107 def _apply_esm2_customization( self, word_embeddings: Tensor, input_ids: Tensor, attention_mask: Tensor ) -> Tuple[Tensor, Tensor]: - """ESM2 customization for attention masking and token dropout + """ESM2 customization for attention masking and token dropout. Args: word_embeddings (Tensor[float]): The input tokens. Shape: [b, s, h] @@ -69,7 +70,7 @@ def _apply_esm2_customization( Returns: Tuple[Tensor, Tensor]: (Updated embeddings, embedding mask) Shape: ([b, s, h], [b, s]) - """ # noqa: D415 + """ embeddings_mask = None if attention_mask is not None and (self.token_dropout or self.use_attention_mask): embeddings_mask = attention_mask diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py index 59a2fc851..29e7081a9 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/lr_scheduler.py @@ -30,14 +30,18 @@ ) -class SchedulerOutput(TypedDict): # noqa: D101 +class SchedulerOutput(TypedDict): + """Output of the scheduler method.""" + optimizer: MegatronOptimizerModule lr_scheduler: dict monitor: str -class WarmupAnnealDecayHold(_LRScheduler): # noqa: D101 - def __init__( # noqa: D417 +class WarmupAnnealDecayHold(_LRScheduler): + """Warmup Anneal Decay Hold learning rate scheduler.""" + + def __init__( self, optimizer: MegatronOptimizerModule, *, @@ -51,9 +55,13 @@ def __init__( # noqa: D417 """Initializes the WarmupAnnealDecayHold learning rate scheduler. Args: - max_steps (int): Total number of training steps. + optimizer: Optimizer to apply the learning rate scheduler. warmup_steps (int): Number of steps for the linear warm-up. + max_steps (int): Total number of training steps. max_lr (float): Peak learning rate to be achieved after warm-up. + min_lr (float): Minimum learning rate. + anneal_percentage (float): Percentage of the max_lr to hold after decay. 
+ last_epoch (int): The index of the last epoch. """ self.warmup_steps = warmup_steps self.max_steps = max_steps @@ -67,7 +75,8 @@ def __init__( # noqa: D417 super(WarmupAnnealDecayHold, self).__init__(optimizer, last_epoch) - def get_lr(self) -> List[float]: # noqa: D102 + def get_lr(self) -> List[float]: + """Get the learning rate at the current step.""" step_num = self.last_epoch if step_num < self.warmup_steps: lr = self.min_lr + (self.max_lr - self.min_lr) * step_num / self.warmup_steps @@ -82,7 +91,7 @@ def get_lr(self) -> List[float]: # noqa: D102 class WarmupAnnealDecayHoldScheduler(LRSchedulerModule): """Warmup Policy Learning Rate Scheduler.""" - def __init__( # noqa: D107 + def __init__( self, warmup_steps: int = 2000, max_steps: int = 500_000, @@ -93,6 +102,7 @@ def __init__( # noqa: D107 frequency: int = 1, monitor: str = "val_loss", ) -> None: + """Initializes the WarmupAnnealDecayHoldScheduler.""" super().__init__() self.warmup_steps = warmup_steps self.max_steps = max_steps @@ -103,7 +113,8 @@ def __init__( # noqa: D107 self.frequency = frequency self.monitor = monitor - def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput: # noqa: D102 + def scheduler(self, model: MegatronBioBertModel, optimizer: MegatronOptimizerModule) -> SchedulerOutput: + """Returns the scheduler output.""" lr_scheduler = WarmupAnnealDecayHold( optimizer, warmup_steps=self.warmup_steps, diff --git a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py index 3a90ba178..5c944bbe4 100644 --- a/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py +++ b/sub-packages/bionemo-esm2/src/bionemo/esm2/model/model.py @@ -49,26 +49,9 @@ class ESM2Model(MegatronBioBertModel): - """ESM2 Transformer language model. + """ESM2 Transformer language model.""" - Args: - config (TransformerConfig): transformer config - num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. - transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers - vocab_size (int): vocabulary size - max_sequence_length (int): maximum size of sequence. This is used for positional embedding - tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model) - pre_process (bool): Include embedding layer (used with pipeline parallelism) - post_process (bool): Include an output layer (used with pipeline parallelism) - parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks - share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False. - position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. - Defaults is 'learned_absolute'. - rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. - Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. - """ - - def __init__( # noqa: D107 + def __init__( self, config: TransformerConfig, num_tokentypes: int, @@ -88,6 +71,29 @@ def __init__( # noqa: D107 return_embeddings=False, use_full_attention_mask=False, ) -> None: + """Initialize the ESM2 model. + + Args: + config (TransformerConfig): transformer config + num_tokentypes (int): Set to 2 when args.bert_binary_head is True, and 0 otherwise. Defaults to 0. 
+ transformer_layer_spec (ModuleSpec): Specifies module to use for transformer layers + vocab_size (int): vocabulary size + max_sequence_length (int): maximum size of sequence. This is used for positional embedding + tokenizer (AutoTokenizer): optional tokenizer object (currently only used in the constructor of ESM2Model) + pre_process (bool): Include embedding layer (used with pipeline parallelism) + post_process (bool): Include an output layer (used with pipeline parallelism) + fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + parallel_output (bool): Do not gather the outputs, keep them split across tensor parallel ranks + share_embeddings_and_output_weights (bool): When True, input embeddings and output logit weights are shared. Defaults to False. + position_embedding_type (string): Position embedding type. Options ['learned_absolute', 'rope']. + Defaults is 'learned_absolute'. + rotary_percent (float): Percent of rotary dimension to use for rotary position embeddings. + Defaults to 1.0 (100%). Ignored unless position_embedding_type is 'rope'. + seq_len_interpolation_factor (Optional[float]): Interpolation factor for sequence length. Defaults to None. + add_binary_head (bool): Whether to add a binary head. Defaults to True. + return_embeddings (bool): Whether to return embeddings. Defaults to False. + use_full_attention_mask (bool): Whether to use full attention mask. Defaults to False. + """ super(MegatronBioBertModel, self).__init__(config=config) self.post_process = post_process self.add_binary_head = add_binary_head @@ -174,9 +180,20 @@ def __init__( # noqa: D107 if self.pre_process or self.post_process: self.setup_embeddings_and_output_layer() - def embedding_forward( # noqa: D102 + def embedding_forward( self, input_ids: Tensor, position_ids: Tensor, tokentype_ids: Tensor = None, attention_mask: Tensor = None ): + """Forward pass of the embedding layer. + + Args: + input_ids: The input tensor of shape (batch_size, sequence_length) containing the input IDs. + position_ids: The tensor of shape (batch_size, sequence_length) containing the position IDs. + tokentype_ids: The tensor of shape (batch_size, sequence_length) containing the token type IDs. Defaults to None. + attention_mask: The tensor of shape (batch_size, sequence_length) containing the attention mask. Defaults to None. + + Returns: + Tensor: The output tensor of shape (batch_size, sequence_length, hidden_size) containing the embedded representations. + """ # ESM2 Customization: ESM2Embedding forward takes attention_mask # in addition to the args required by LanguageModelEmbedding return self.embedding( @@ -184,18 +201,58 @@ def embedding_forward( # noqa: D102 ) -def esm_gelu_func(x: Tensor) -> Tensor: # D205 # D205 +def esm_gelu_func(x: Tensor) -> Tensor: """ESM2-specific gelu implementation from the original ESM repo. - Using F.gelu yields subtly wrong results. + + !!! warning + + Using F.gelu yields subtly wrong results. Args: x: input tensor of any given dimension - """ # noqa: D205 + """ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) @dataclass -class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig): # noqa: D101 +class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig): + """Configuration class for ESM2 model. + + Attributes: + num_layers: Number of layers in the model. + hidden_size: Hidden size of the model. + num_attention_heads: Number of attention heads in the model. + ffn_hidden_size: Hidden size of the feed-forward network. 
+ hidden_dropout: Dropout rate for hidden layers. + attention_dropout: Dropout rate for attention layers. + apply_residual_connection_post_layernorm: Whether to apply residual connection after layer normalization. + layernorm_epsilon: Epsilon value for layer normalization. + layernorm_zero_centered_gamma: Whether to zero-center the gamma parameter in layer normalization. + activation_func: Activation function used in the model. + init_method_std: Standard deviation for weight initialization. + apply_query_key_layer_scaling: Whether to apply scaling to query and key layers. + masked_softmax_fusion: Whether to use a kernel that fuses attention softmax with its mask. + fp16_lm_cross_entropy: Whether to move the cross entropy unreduced loss calculation for lm head to fp16. + share_embeddings_and_output_weights: Whether to share embeddings and output weights. + enable_autocast: Whether to enable autocast for mixed precision. + biobert_spec_option: BiobertSpecOption for the model. + position_embedding_type: Type of position embedding used in the model. + seq_length: Length of the input sequence. + make_vocab_size_divisible_by: Make the vocabulary size divisible by this value. + token_dropout: Whether to apply token dropout. + use_attention_mask: Whether to use attention mask. + use_esm_attention: Whether to use ESM attention. + attention_softmax_in_fp32: Whether to use fp32 for attention softmax. + optimizer_fn: Optional optimizer function for the model. + parallel_output: Whether to use parallel output. + rotary_base: Base value for rotary positional encoding. + rotary_percent: Percentage of rotary positional encoding. + seq_len_interpolation_factor: Interpolation factor for sequence length. + get_attention_mask_from_fusion: Whether to get attention mask from fusion. + nemo1_ckpt_path: Path to NEMO1 checkpoint. + return_only_hidden_states: Whether to return only hidden states. + """ + num_layers: int = 33 # 650M hidden_size: int = 1280 # 650M num_attention_heads: int = 20 @@ -246,7 +303,15 @@ class ESM2Config(BionemoModelConfig[ESM2Model], TransformerConfig): # noqa: D10 return_only_hidden_states: bool = False - def configure_model(self, tokenizer) -> ESM2Model: # noqa: D102 + def configure_model(self, tokenizer) -> ESM2Model: + """Configures the ESM2Model with the given tokenizer. + + Args: + tokenizer: The tokenizer to be used. + + Returns: + An instance of ESM2Model configured with the specified parameters. + """ vp_size = self.virtual_pipeline_model_parallel_size if vp_size: p_size = self.pipeline_model_parallel_size diff --git a/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning_basic.py b/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning_basic.py index 9c9809418..59b4ab8c9 100644 --- a/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning_basic.py +++ b/sub-packages/bionemo-example_model/src/bionemo/example_model/lightning_basic.py @@ -45,9 +45,10 @@ @dataclass class ExampleConfig(ModelParallelConfig): - """Timers from ModelParallelConfig are required (apparently). - For megatron forward compatibility. - """ # noqa: D205 + """ExampleConfig is a dataclass that is used to configure the model. + + Timers from ModelParallelConfig are required for megatron forward compatibility. 
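Returning to the WarmupAnnealDecayHold scheduler documented earlier in this diff, the sketch below traces one plausible shape for the schedule. The warmup branch mirrors the get_lr code shown in that hunk; the post-warmup decay toward anneal_percentage * max_lr and the example hyperparameter values are assumptions for illustration, not the bionemo implementation.

```python
# Example values (max_lr, min_lr, anneal_percentage) are arbitrary and chosen for illustration.
def warmup_anneal_decay_hold_lr(
    step: int,
    *,
    warmup_steps: int = 2000,
    max_steps: int = 500_000,
    max_lr: float = 4e-4,
    min_lr: float = 4e-5,
    anneal_percentage: float = 0.10,
) -> float:
    hold_lr = anneal_percentage * max_lr  # "percentage of the max_lr to hold after decay"
    if step < warmup_steps:
        # Linear warmup from min_lr to max_lr; this branch mirrors the get_lr code shown in the diff.
        return min_lr + (max_lr - min_lr) * step / warmup_steps
    if step < max_steps:
        # Simplified post-warmup phase: decay linearly from max_lr toward the hold value.
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return max_lr - (max_lr - hold_lr) * progress
    # Hold phase: keep the annealed learning rate for the rest of training.
    return hold_lr


for s in (0, 1_000, 2_000, 250_000, 600_000):
    print(f"step={s:>7} lr={warmup_anneal_decay_hold_lr(s):.6f}")
```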
+ """ calculate_per_token_loss: bool = False @@ -100,14 +101,14 @@ def reduce(self, losses_reduced_per_micro_batch: Sequence[ReductionT]) -> Tensor class LitAutoEncoder(pl.LightningModule): - """A very basic lightning module example that is used for testing the megatron strategy and the megatron-nemo2-bionemo - contract. - """ # noqa: D205 + """A very basic lightning module for testing the megatron strategy and the megatron-nemo2-bionemo contract.""" def __init__(self, config): - """Args: - config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters). - """ # noqa: D205 + """Initializes the model. + + Args: + config: a Config object necessary to construct the actual nn.Module (the thing that has the parameters). + """ super().__init__() self.config = config self.optim = MegatronOptimizerModule( @@ -118,7 +119,10 @@ def __init__(self, config): def forward(self, batch: Dict, batch_idx: int): """This forward will be called by the megatron scheduler and it will be wrapped. - Note: The `training_step` defines the training loop and is independent of the `forward` method here. + + !!! note + + The `training_step` defines the training loop and is independent of the `forward` method here. Args: batch: A dictionary of data. @@ -126,12 +130,14 @@ def forward(self, batch: Dict, batch_idx: int): Returns: The output of the model. - """ # noqa: D205 + """ x = batch["data"] return self.module(x) - def training_step(self, batch, batch_idx: Optional[int] = None): # noqa: D417 - """Background: + def training_step(self, batch, batch_idx: Optional[int] = None): + """The training step is where the loss is calculated and the backpropagation is done. + + Background: - NeMo's Strategy overrides this method. - The strategies' training step will call the forward method of the model. - That forward method then calls the wrapped forward step of MegatronParallel which wraps the forward method of the model. @@ -142,9 +148,9 @@ def training_step(self, batch, batch_idx: Optional[int] = None): # noqa: D417 In this particular use case, we simply call the forward method of this class, the lightning module. Args: - batch: A dictionary of data. - requires `batch_idx` as default None. - """ # noqa: D205 + batch: A dictionary of data. requires `batch_idx` as default None. + batch_idx: The index of the batch. + """ return self(batch, batch_idx) def training_loss_reduction(self) -> MegatronLossReduction: # noqa: D102