Add CallbackGroup & Metadata factory function #13437

Draft: wants to merge 33 commits into base: main.

Changes shown below are from 4 of the 33 commits.

Commits (33)
642e360
feat: add callback group definition & callback ABC
PytLab May 5, 2025
1badf29
Apply isort and black reformatting
PytLab May 5, 2025
3bf3367
feat: insert callback functions of CallbackGroup
PytLab May 6, 2025
2b51e12
Apply isort and black reformatting
PytLab May 6, 2025
249dad3
chore: PR test for jiashang
liquor233 May 7, 2025
db2b15d
feat: use __init_subclass__ to cover all ModelPT subclasses
PytLab May 12, 2025
d921d64
Apply isort and black reformatting
PytLab May 12, 2025
3e32f1a
feat: Adding metadata config manager poc
May 12, 2025
e1074f6
Apply isort and black reformatting
sajup-oss May 12, 2025
d79f4f1
feat: revert test changes.
liquor233 May 13, 2025
263f7e9
fix: Updating metadata attributes
sajup-oss May 21, 2025
81cd1d9
fix: Merging changes
sajup-oss May 21, 2025
4852936
Apply isort and black reformatting
sajup-oss May 21, 2025
48d6d87
fix: Adding OneloggerCallback
sajup-oss May 22, 2025
2ba6cc5
fix: Reverting changes in examples/multimodal/speech_llm/modular_audi…
sajup-oss May 22, 2025
c908b53
fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM…
sajup-oss May 23, 2025
bd39d8f
Apply isort and black reformatting
sajup-oss May 23, 2025
ba4e4a6
fix: update modular models and megatron GPT models
liquor233 May 26, 2025
515136c
Apply isort and black reformatting
liquor233 May 26, 2025
bc030f7
feat: add on_app_start and on_app_end
liquor233 May 26, 2025
2ed58f4
Apply isort and black reformatting
liquor233 May 26, 2025
35d2f2c
fix: Adding small test example for testing
sajup-oss May 26, 2025
ddc99fb
Apply isort and black reformatting
sajup-oss May 26, 2025
ca6ff4d
fix: Fixing review comments as discussed with Jiashang
May 26, 2025
9f11d01
Apply isort and black reformatting
sajup-oss May 26, 2025
64e0e03
fix: updating nemo code to v2
sajup-oss Jun 13, 2025
181bb3e
fix: updating code to v2
sajup-oss Jun 13, 2025
61d631c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
8eb4fc6
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
2900246
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
4acbc2c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
dffccfa
fix: fix som impl issue
liquor233 Jul 4, 2025
60eb727
Apply isort and black reformatting
liquor233 Jul 4, 2025
3 changes: 3 additions & 0 deletions nemo/collections/asr/models/ssl_models.py
@@ -41,6 +41,7 @@
NeuralType,
SpectrogramType,
)
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging

__all__ = ['SpeechEncDecSelfSupervisedModel', 'EncDecMaskedTokenPredModel', 'EncDecDenoiseMaskedTokenPredModel']
@@ -245,6 +246,7 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict
- :class:`~nemo.collections.asr.data.audio_to_text.TarredAudioToBPEDataset`
- :class:`~nemo.collections.asr.data.audio_to_text_dali.AudioToCharDALIDataset`
"""
CallbackGroup.get_instance().on_dataloader_init_start()
if 'shuffle' not in train_data_config:
train_data_config['shuffle'] = True

@@ -274,6 +276,7 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict
"Model Trainer was not set before constructing the dataset, incorrect number of "
"training batches will be used. Please set the trainer and rebuild the dataset."
)
CallbackGroup.get_instance().on_dataloader_init_end()

def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]):
"""
@@ -85,6 +85,8 @@
)
from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

__all__ = ["ModularAudioGPTModel", "CrossAttendModularAudioGPTModel"]


@@ -931,6 +933,7 @@ def restore_from_pretrained_models(
cfg: input yaml config, with trainer, model, exp_manager, etc.
trainer: trainer object
"""
CallbackGroup.get_instance().on_load_checkpoint_start()
if (
cfg.model.get("pretrained_audio_model", None) is None
and cfg.model.perception.get("encoders", None) is None
@@ -977,6 +980,7 @@
if 'inference' in cfg:
inference_cfg = OmegaConf.to_container(cfg.inference, resolve=True)
model.set_inference_config(inference_cfg)
CallbackGroup.get_instance().on_load_checkpoint_end()
return model

@classmethod
@@ -120,6 +120,8 @@
update_num_microbatches,
)

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

transformer_engine, HAVE_TE = safe_import("transformer_engine")
te_module, HAVE_TE_MODULE = safe_import_from("transformer_engine.pytorch", "module")
get_gpt_layer_with_te_and_hyena_spec, HAVE_HYENA_SPEC = safe_import_from(
@@ -625,7 +627,7 @@ def setup_mcore_distributed_parallel(self):
# by calling model_module.broadcast_params() if the model is randomly initialized.

def configure_optimizers(self):

CallbackGroup.get_instance().on_optimizer_init_start()
if self.with_distributed_adam and not self.use_mcore_dist_optim:

# Special handling for embedding grads
@@ -708,6 +710,7 @@ def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]:
used_params = set(itertools.chain.from_iterable(buckets))
buckets[-1].extend(p for p in self.parameters() if p not in used_params and p.requires_grad)
self.distributed_adam_buckets = buckets
CallbackGroup.get_instance().on_optimizer_init_end()

return super().configure_optimizers()

@@ -68,6 +68,7 @@
get_num_microbatches,
)

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

__all__ = ['MegatronGPTSFTModel']

@@ -877,13 +878,15 @@ def build_data_loader(self, dataset, data_cfg, consumed_samples=0):
)

def setup_training_dataloader(self):
CallbackGroup.get_instance().on_dataloader_init_start()
if hasattr(self, '_train_ds'):
consumed_samples = self.compute_consumed_samples(0)
self._train_dl = self.build_data_loader(
dataset=self._train_ds,
data_cfg=self.cfg.data.train_ds,
consumed_samples=consumed_samples,
)
CallbackGroup.get_instance().on_dataloader_init_end()

def setup_eval_dataloader(self, datasets, data_cfg):
dataloaders = []
2 changes: 2 additions & 0 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -32,6 +32,7 @@
NLPFSDPStrategy,
PipelineMixedPrecisionPlugin,
)
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import (
AsyncFinalizableCheckpointIO,
@@ -227,6 +228,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
def create_trainer(self, callbacks=None) -> Trainer:
strategy = self._training_strategy()
plugins = self._plugins()
        callbacks = list(callbacks or [])
        callbacks.extend(CallbackGroup.get_instance().callbacks)
callbacks = self._callbacks(callbacks)
return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)

2 changes: 2 additions & 0 deletions nemo/core/classes/common.py
@@ -15,6 +15,7 @@

"""Interfaces common to all Neural Modules and Models."""
from __future__ import annotations

import copy
import hashlib
import inspect
@@ -42,6 +43,7 @@
from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.data_utils import resolve_cache_dir
5 changes: 5 additions & 0 deletions nemo/core/classes/modelPT.py
@@ -47,6 +47,7 @@
from nemo.core.classes.common import Model
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
@@ -80,6 +81,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

trainer (Optional): Pytorch Lightning Trainer instance
"""
CallbackGroup.get_instance().on_model_init_start()
if trainer is not None and not isinstance(trainer, Trainer):
raise ValueError(
f"trainer constructor argument must be either None or lightning.pytorch.Trainer. "
@@ -221,6 +223,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# A flag for the profile generation
self._chakra_profile_in_progress = False
CallbackGroup.get_instance().on_model_init_end()

def __init_subclass__(cls) -> None:
cls._save_restore_connector = SaveRestoreConnector()
@@ -658,6 +661,7 @@ def setup_optimization(
compatible with OmegaConf.

"""
CallbackGroup.get_instance().on_optimizer_init_start()
# Setup the optimizer parameter groups (by default use all parameters that are trainable)
self.setup_optimizer_param_groups()

@@ -806,6 +810,7 @@ def setup_optimization(
self._scheduler = prepare_lr_scheduler(
optimizer=self._optimizer, scheduler_config=scheduler_config, train_dataloader=self._train_dl
)
CallbackGroup.get_instance().on_optimizer_init_end()

# Return the optimizer with/without scheduler
# This return allows multiple optimizers or schedulers to be created
116 changes: 116 additions & 0 deletions nemo/lightning/pytorch/callbacks/callback_group.py
@@ -0,0 +1,116 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable

from lightning.pytorch.callbacks import Callback


class CallbackGroup:
"""A class for hosting a collection of callback objects.

    It dispatches a single call to every callback object it holds: when callback_group.func(*args)
    is invoked, the group loops through self._callbacks and runs self._callbacks[0].func(*args),
    self._callbacks[1].func(*args), and so on. Every callback in the group must therefore define a
    method with that name and a compatible signature.

Attributes:
_callbacks (list[Callback]): List of callback objects.
"""

_instance = None

@classmethod
def get_instance(cls) -> 'CallbackGroup':
"""Get the singleton instance of the CallbackGroup.
Args:
cls (CallbackGroup): The class of the CallbackGroup.
Returns:
CallbackGroup: The singleton instance of the CallbackGroup.
"""
if cls._instance is None:
cls._instance = CallbackGroup()
return cls._instance

    def __init__(self, callbacks: list[Callback] | None = None) -> None:
"""Initializes the list of callback objects.

Args:
callbacks (list[Callback]): List of callbacks.
"""
self._callbacks = callbacks or []

def __getattr__(self, method_name: str) -> Callable:
"""Loops through the callback objects to call the corresponding callback function.

Args:
method_name (str): Callback method name.
"""

def multi_callback_wrapper(*args, **kwargs) -> None:
for callback in self._callbacks:
assert hasattr(callback, method_name)
method = getattr(callback, method_name)
assert callable(method)
_ = method(*args, **kwargs)

return multi_callback_wrapper

@property
def callbacks(self):
"""Return callbacks in order.

Returns:
list: callback objects
"""
return self._callbacks


class Callback(Callback):
"""The base class for all callbacks. It inherits the pytorch lightning callback so the callback can be also passed to PTL trainer to reuse.
Below list extra callback functions in NeMo.
"""

def on_dataloader_init_start(self):
"""Called at the start of the data loading."""

def on_dataloader_init_end(self):
"""Called at the end of the data loading."""

def on_model_init_start(self):
"""Called at the start of the model initialization."""

def on_model_init_end(self):
"""Called at the end of the model initialization."""

def on_optimizer_init_start(self) -> None:
"""Called at the beginning of optimizer initialization."""

def on_optimizer_init_end(self) -> None:
"""Called at the end of optimizer initialization."""

def on_load_checkpoint_start(self) -> None:
"""Called at the beginning of loading checkpoint."""

def on_load_checkpoint_end(self) -> None:
"""Called at the end of loading checkpoint."""

    def on_save_checkpoint_start(self, global_step: int = 0) -> None:
        """Called at the start of saving a checkpoint."""

    def on_save_checkpoint_end(self, global_step: int = 0) -> None:
        """Called when the synchronous part of checkpoint saving ends."""

    def on_save_checkpoint_success(self, global_step: int = 0) -> None:
        """Called when a checkpoint has been saved successfully."""
5 changes: 5 additions & 0 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -29,6 +29,7 @@

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.io.pl import TrainerContext
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.app_state import AppState

@@ -567,6 +568,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
ValueError: (mcore) async_save with EMA not supported
ValueError: (mcore) Async save requires async compatible CheckpointIO
"""
CallbackGroup.get_instance().on_save_checkpoint_start()

from nemo.utils.get_rank import is_global_rank_zero

@@ -598,6 +600,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
CallbackGroup.get_instance().on_save_checkpoint_success(global_step=trainer.global_step)
else:
# Determine whether to include optimizer states in the checkpoint
# optimizer states are included when
@@ -632,6 +635,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
finalize_fn()
CallbackGroup.get_instance().on_save_checkpoint_end(global_step=trainer.global_step)

def _get_finalize_save_checkpoint_callback(
self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
@@ -655,6 +659,7 @@ def _cb():
return

logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')
CallbackGroup.get_instance().on_save_checkpoint_success(global_step=trainer.global_step)

if str(filepath) in self.ckpts_to_link:
self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)