Add CallbackGroup & Metadata factory function #13437

Draft — wants to merge 43 commits into base: main

Changes from 9 commits (43 commits total)

Commits
642e360
feat: add callback group definition & callback ABC
PytLab May 5, 2025
1badf29
Apply isort and black reformatting
PytLab May 5, 2025
3bf3367
feat: insert callback functions of CallbackGroup
PytLab May 6, 2025
2b51e12
Apply isort and black reformatting
PytLab May 6, 2025
249dad3
chore: PR test for jiashang
liquor233 May 7, 2025
db2b15d
feat: use __init_subclass__ to cover all ModelPT subclasses
PytLab May 12, 2025
d921d64
Apply isort and black reformatting
PytLab May 12, 2025
3e32f1a
feat: Adding metadata config manager poc
May 12, 2025
e1074f6
Apply isort and black reformatting
sajup-oss May 12, 2025
d79f4f1
feat: revert test changes.
liquor233 May 13, 2025
263f7e9
fix: Updating metadata attributes
sajup-oss May 21, 2025
81cd1d9
fix: Merging changes
sajup-oss May 21, 2025
4852936
Apply isort and black reformatting
sajup-oss May 21, 2025
48d6d87
fix: Adding OneloggerCallback
sajup-oss May 22, 2025
2ba6cc5
fix: Reverting changes in examples/multimodal/speech_llm/modular_audi…
sajup-oss May 22, 2025
c908b53
fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM…
sajup-oss May 23, 2025
bd39d8f
Apply isort and black reformatting
sajup-oss May 23, 2025
ba4e4a6
fix: update modular models and megatron GPT models
liquor233 May 26, 2025
515136c
Apply isort and black reformatting
liquor233 May 26, 2025
bc030f7
feat: add on_app_start and on_app_end
liquor233 May 26, 2025
2ed58f4
Apply isort and black reformatting
liquor233 May 26, 2025
35d2f2c
fix: Adding small test example for testing
sajup-oss May 26, 2025
ddc99fb
Apply isort and black reformatting
sajup-oss May 26, 2025
ca6ff4d
fix: Fixing review comments as discussed with Jiashang
May 26, 2025
9f11d01
Apply isort and black reformatting
sajup-oss May 26, 2025
64e0e03
fix: updating nemo code to v2
sajup-oss Jun 13, 2025
181bb3e
fix: updating code to v2
sajup-oss Jun 13, 2025
61d631c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
8eb4fc6
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
2900246
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
4acbc2c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
dffccfa
fix: fix som impl issue
liquor233 Jul 4, 2025
60eb727
Apply isort and black reformatting
liquor233 Jul 4, 2025
b97fbda
fix: fix issue for exp manager.
liquor233 Jul 7, 2025
5c144ed
feat: Merge branch 'zshao/add_callback_group' of https://github.com/N…
liquor233 Jul 7, 2025
041a32b
Apply isort and black reformatting
liquor233 Jul 7, 2025
b70f85b
feat: remove callback_group
liquor233 Jul 10, 2025
f473d1b
feat: fix timingtracker issue
liquor233 Jul 10, 2025
1705b19
Apply isort and black reformatting
liquor233 Jul 10, 2025
e6b4e64
feat: fix for startup callbcaks
liquor233 Jul 14, 2025
5b7bd1c
Apply isort and black reformatting
liquor233 Jul 14, 2025
c687003
feat: change to adapter
liquor233 Jul 14, 2025
42181c5
Apply isort and black reformatting
liquor233 Jul 14, 2025
3 changes: 3 additions & 0 deletions README.md
Expand Up @@ -530,3 +530,6 @@ branch](https://github.com/NVIDIA/NeMo/tree/gh-pages-src#readme).
AGREEMENT](https://www.nvidia.com/en-us/data-center/products/nvidia-ai-enterprise/eula/).
By pulling and using the container, you accept the terms and
conditions of this license.


PR test
6 changes: 6 additions & 0 deletions examples/multimodal/speech_llm/modular_audio_gpt_train.py
Expand Up @@ -14,12 +14,15 @@

import torch.multiprocessing as mp
from omegaconf.omegaconf import OmegaConf, open_dict
from one_logger_utils.nemo import OneLoggerNeMoCallback

from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder
from nemo.core.config import hydra_runner
from nemo.utils import logging, model_utils
from nemo.utils.callback_group import init_global_callback_group
from nemo.utils.exp_manager import exp_manager
from nemo.utils.meta_info_manager import MetaInfoManager

mp.set_start_method("spawn", force=True)

Expand Down Expand Up @@ -53,6 +56,9 @@ def main(cfg) -> None:
    with open_dict(cfg):
        cfg.model.precision = cfg.trainer.precision

    one_logger_cb = OneLoggerNeMoCallback(callback_config=MetaInfoManager(cfg).get_metadata())
    init_global_callback_group(callbacks=[one_logger_cb])

    precision = cfg.trainer.precision
    trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer()
    cfg.trainer.precision = precision
Expand Down
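
Reviewer note: init_global_callback_group and MetaInfoManager are imported from nemo.utils here, but their implementations are not part of this 9-commit view. A minimal sketch of what such a factory could look like, assuming it only registers the given callbacks with the CallbackGroup singleton added by this PR (the body below is illustrative, not the PR's actual implementation):

from typing import Iterable

from lightning.pytorch.callbacks import Callback

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup


def init_global_callback_group(callbacks: Iterable[Callback]) -> CallbackGroup:
    """Register the given callbacks with the process-wide CallbackGroup singleton."""
    group = CallbackGroup.get_instance()
    for callback in callbacks:
        group.register(callback)
    return group
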
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
)
from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches


__all__ = ["ModularAudioGPTModel", "CrossAttendModularAudioGPTModel"]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,6 @@ def setup_mcore_distributed_parallel(self):
        # by calling model_module.broadcast_params() if the model is randomly initialized.

    def configure_optimizers(self):

        if self.with_distributed_adam and not self.use_mcore_dist_optim:

            # Special handling for embedding grads
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
get_num_microbatches,
)


__all__ = ['MegatronGPTSFTModel']


Expand Down
3 changes: 3 additions & 0 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
Expand Up @@ -32,6 +32,7 @@
NLPFSDPStrategy,
PipelineMixedPrecisionPlugin,
)
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import (
AsyncFinalizableCheckpointIO,
Expand Down Expand Up @@ -199,6 +200,7 @@ def create_trainer(self, callbacks=None) -> Trainer:
        precision = self.cfg.trainer.precision
        strategy = self._training_strategy()
        plugins = self._plugins()
        callbacks = (callbacks or []) + list(CallbackGroup.get_instance().callbacks)
        callbacks = self._callbacks(callbacks)
        trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)
        # Restore the precision value after Trainer is built.
Expand Down Expand Up @@ -227,6 +229,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
    def create_trainer(self, callbacks=None) -> Trainer:
        strategy = self._training_strategy()
        plugins = self._plugins()
        callbacks = (callbacks or []) + list(CallbackGroup.get_instance().callbacks)
        callbacks = self._callbacks(callbacks)
        return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)

Expand Down
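
Because the NeMo Callback base added in this PR subclasses Lightning's Callback, everything registered with the group is also a valid Trainer callback, which is why create_trainer can simply extend its callback list with the group's callbacks. A short illustrative sketch of that flow (PrintingCallback is a made-up example, not part of the PR):

import lightning.pytorch as pl

from nemo.lightning.pytorch.callbacks.callback_group import Callback, CallbackGroup


class PrintingCallback(Callback):
    def on_model_init_start(self):
        # NeMo-specific hook, fired through CallbackGroup
        print("model init starting")

    def on_train_start(self, trainer, pl_module):
        # standard Lightning hook, fired by the Trainer
        print("training starting")


group = CallbackGroup.get_instance()
group.register(PrintingCallback())

# The same objects can be handed to a Lightning Trainer, which is effectively what
# create_trainer() does when it appends group.callbacks before building the Trainer.
trainer = pl.Trainer(callbacks=list(group.callbacks))
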
4 changes: 4 additions & 0 deletions nemo/core/classes/common.py
Expand Up @@ -15,6 +15,7 @@

"""Interfaces common to all Neural Modules and Models."""
from __future__ import annotations

import copy
import hashlib
import inspect
Expand Down Expand Up @@ -42,6 +43,7 @@
from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.data_utils import resolve_cache_dir
Expand Down Expand Up @@ -739,6 +741,7 @@ def from_pretrained(
        Returns:
            A model instance of a particular model class or its underlying config (if return_config is set).
        """
        CallbackGroup.get_instance().on_load_checkpoint_start()
        if save_restore_connector is None:
            save_restore_connector = SaveRestoreConnector()

Expand Down Expand Up @@ -772,6 +775,7 @@ def from_pretrained(
            trainer=trainer,
            save_restore_connector=save_restore_connector,
        )
        CallbackGroup.get_instance().on_load_checkpoint_end()
        return instance

@classmethod
Expand Down
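
The new hooks around from_pretrained make it possible to measure restore time without touching model code. A minimal sketch of such a callback, assuming the callback_group module path added in this PR (LoadTimer itself is illustrative, not part of the PR):

import time

from nemo.lightning.pytorch.callbacks.callback_group import Callback, CallbackGroup


class LoadTimer(Callback):
    """Records how long checkpoint/model restoration takes."""

    def on_load_checkpoint_start(self) -> None:
        self._start = time.perf_counter()

    def on_load_checkpoint_end(self) -> None:
        print(f"restore took {time.perf_counter() - self._start:.1f}s")


CallbackGroup.get_instance().register(LoadTimer())
# Any subsequent ModelPT.from_pretrained(...) call now reports its restore time.
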
5 changes: 5 additions & 0 deletions nemo/core/classes/modelPT.py
Expand Up @@ -47,6 +47,7 @@
from nemo.core.classes.common import Model
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup, wrap_methods_with_callbacks

Check notice — Code scanning / CodeQL: Unused import
Import of 'CallbackGroup' is not used in this file; the suggested fix is to import only wrap_methods_with_callbacks from nemo.lightning.pytorch.callbacks.callback_group.
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
Expand Down Expand Up @@ -224,6 +225,7 @@

    def __init_subclass__(cls) -> None:
        cls._save_restore_connector = SaveRestoreConnector()
        wrap_methods_with_callbacks(cls)

    def on_fit_start(self) -> None:
        if self.cfg.get("dump_debug_info", False):
Expand Down Expand Up @@ -2126,3 +2128,6 @@
            return copy.deepcopy(optim_config)
        else:
            return OmegaConf.create(optim_config)


ModelPT = wrap_setup_training_data(ModelPT)
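
Because __init_subclass__ runs when each ModelPT subclass is defined, the wrap rules are applied automatically at class-definition time. A standalone sketch of the mechanism using a plain class instead of a real ModelPT subclass (ToyModel and HookLogger are illustrative):

from nemo.lightning.pytorch.callbacks.callback_group import (
    Callback,
    CallbackGroup,
    wrap_methods_with_callbacks,
)


class HookLogger(Callback):
    def on_dataloader_init_start(self):
        print("dataloader init: start")

    def on_dataloader_init_end(self):
        print("dataloader init: end")


class ToyModel:
    """Stand-in for a ModelPT subclass; ModelPT.__init_subclass__ normally does the wrapping."""

    def setup_training_data(self, train_data_config=None):
        print("building training dataloader")


CallbackGroup.get_instance().register(HookLogger())
wrap_methods_with_callbacks(ToyModel)  # applied automatically for real ModelPT subclasses

ToyModel().setup_training_data()
# prints: dataloader init: start / building training dataloader / dataloader init: end
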
190 changes: 190 additions & 0 deletions nemo/lightning/pytorch/callbacks/callback_group.py
@@ -0,0 +1,190 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

from lightning.pytorch.callbacks import Callback


class CallbackGroup:
    """A class for hosting a collection of callback objects.

    It dispatches a single call to every registered callback: calling
    ``callback_group.func(*args, **kwargs)`` loops over ``self._callbacks`` and invokes
    ``func`` with the same arguments on each callback, in registration order. Every
    registered callback must therefore provide the method with a matching signature.

    Attributes:
        _callbacks (list[Callback]): List of callback objects.
    """

    _instance = None

    @classmethod
    def get_instance(cls) -> 'CallbackGroup':
        """Get the singleton instance of the CallbackGroup.

        Returns:
            CallbackGroup: The singleton instance of the CallbackGroup.
        """
        if cls._instance is None:
            cls._instance = CallbackGroup()
        return cls._instance

    def __init__(self) -> None:
        """Initializes the list of callback objects."""
        self._callbacks = []

    def register(self, callback: Callback) -> None:
        """Register a callback to the callback group.

        Args:
            callback (Callback): The callback to register.
        """
        self._callbacks.append(callback)

    def __getattr__(self, method_name: str) -> Callable:
        """Returns a wrapper that calls the named method on every registered callback.

        Args:
            method_name (str): Callback method name.
        """

        def multi_callback_wrapper(*args, **kwargs) -> None:
            for callback in self._callbacks:
                assert hasattr(callback, method_name)
                method = getattr(callback, method_name)
                assert callable(method)
                _ = method(*args, **kwargs)

        return multi_callback_wrapper

    @property
    def callbacks(self):
        """Return callbacks in registration order.

        Returns:
            list: callback objects
        """
        return self._callbacks


class Callback(Callback):
    """Base class for NeMo callbacks.

    It inherits from the PyTorch Lightning ``Callback``, so an instance can also be passed
    directly to a PTL trainer and reused there. The methods below are the extra
    NeMo-specific hooks.
    """

    def on_dataloader_init_start(self):
        """Called at the start of the data loading."""

    def on_dataloader_init_end(self):
        """Called at the end of the data loading."""

    def on_model_init_start(self):
        """Called at the start of the model initialization."""

    def on_model_init_end(self):
        """Called at the end of the model initialization."""

    def on_optimizer_init_start(self) -> None:
        """Called at the beginning of optimizer initialization."""

    def on_optimizer_init_end(self) -> None:
        """Called at the end of optimizer initialization."""

    def on_load_checkpoint_start(self) -> None:
        """Called at the beginning of loading a checkpoint."""

    def on_load_checkpoint_end(self) -> None:
        """Called at the end of loading a checkpoint."""

    def on_save_checkpoint_start(self, iteration: int = 0) -> None:
        """Called when saving a checkpoint starts."""

    def on_save_checkpoint_end(self, iteration: int = 0) -> None:
        """Called when the synchronous part of a checkpoint save ends."""

    def on_save_checkpoint_success(self, iteration: int = 0) -> None:
        """Called when a checkpoint is saved successfully."""


CB_WRAP_RULES = {
    # Maps the name of the method to wrap to the callback-group hooks that should fire
    # before ("start_hook") and after ("end_hook") the original method runs:
    #     method_name: {"start_hook": callback_method_name, "end_hook": callback_method_name}
    "setup_training_data": {"start_hook": "on_dataloader_init_start", "end_hook": "on_dataloader_init_end"},
    "setup_optimization": {"start_hook": "on_optimizer_init_start", "end_hook": "on_optimizer_init_end"},
    "restore_from_pretrained_models": {"start_hook": "on_load_checkpoint_start", "end_hook": "on_load_checkpoint_end"},
    "__init__": {"start_hook": "on_model_init_start", "end_hook": "on_model_init_end"},
    "configure_optimizers": {"start_hook": "on_optimizer_init_start", "end_hook": "on_optimizer_init_end"},
    "setup_training_dataloader": {"start_hook": "on_dataloader_init_start", "end_hook": "on_dataloader_init_end"},
}


def _make_callback_wrapped_method(original_method):
    """Wrap a method with the start and end hooks of the callback group.

    Args:
        original_method (Callable): The original method (or classmethod) to wrap.

    Returns:
        Callable: The wrapped method, or the original method if no wrap rule applies.
    """
    callback_group = CallbackGroup.get_instance()

    is_classmethod = isinstance(original_method, classmethod)
    # classmethod objects are not directly callable; unwrap to the underlying function.
    func = original_method.__func__ if is_classmethod else original_method

    hooks = CB_WRAP_RULES.get(func.__name__)
    if not hooks:
        return original_method

    @functools.wraps(func)
    def wrapped_method(*args, **kwargs):
        if hasattr(callback_group, hooks["start_hook"]):
            getattr(callback_group, hooks["start_hook"])()
        result = func(*args, **kwargs)
        if hasattr(callback_group, hooks["end_hook"]):
            getattr(callback_group, hooks["end_hook"])()
        return result

    return classmethod(wrapped_method) if is_classmethod else wrapped_method


def wrap_methods_with_callbacks(cls) -> None:
    """Wrap class/instance methods with the start and end hooks of the callback group.

    Args:
        cls (type): The class whose methods are wrapped in place.
    """
    for method_name in CB_WRAP_RULES:
        if method_name in cls.__dict__:
            original_method = cls.__dict__[method_name]
            # cls.__dict__ is a read-only mappingproxy, so assign through setattr.
            setattr(cls, method_name, _make_callback_wrapped_method(original_method))
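
A short usage sketch of the dispatch behavior described in the CallbackGroup docstring (the callback classes are illustrative): one call on the group fans out to every registered callback in registration order.

from nemo.lightning.pytorch.callbacks.callback_group import Callback, CallbackGroup


class ConsoleLogger(Callback):
    def on_model_init_start(self):
        print("[console] model init starting")


class MetricsLogger(Callback):
    def on_model_init_start(self):
        print("[metrics] model init starting")


group = CallbackGroup.get_instance()
group.register(ConsoleLogger())
group.register(MetricsLogger())

# CallbackGroup.__getattr__ returns a wrapper that calls the named method on each callback.
group.on_model_init_start()
# prints: [console] model init starting, then [metrics] model init starting

The singleton keeps the hooks reachable from deep call sites such as checkpointing and from_pretrained without threading a callback object through every API; the trade-off is process-global state, which tests may need to account for.
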
5 changes: 5 additions & 0 deletions nemo/lightning/pytorch/callbacks/model_checkpoint.py
Expand Up @@ -29,6 +29,7 @@

from nemo.lightning.ckpt_utils import ckpt_to_dir
from nemo.lightning.io.pl import TrainerContext
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.app_state import AppState

Expand Down Expand Up @@ -567,6 +568,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
ValueError: (mcore) async_save with EMA not supported
ValueError: (mcore) Async save requires async compatible CheckpointIO
"""
CallbackGroup.get_instance().on_save_checkpoint_start()

from nemo.utils.get_rank import is_global_rank_zero

Expand Down Expand Up @@ -598,6 +600,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
super()._save_checkpoint(trainer, filepath)
self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
CallbackGroup.get_instance().on_save_checkpoint_success(iteration=trainer.global_step)
else:
# Determine whether to include optimizer states in the checkpoint
# optimizer states are included when
Expand Down Expand Up @@ -632,6 +635,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
logging.info(f'Scheduled async checkpoint save for {filepath}')
else:
finalize_fn()
CallbackGroup.get_instance().on_save_checkpoint_end(iteration=trainer.global_step)

def _get_finalize_save_checkpoint_callback(
self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int
Expand All @@ -655,6 +659,7 @@ def _cb():
return

logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')
CallbackGroup.get_instance().on_save_checkpoint_success(iteration=trainer.global_step)

if str(filepath) in self.ckpts_to_link:
self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)
Expand Down
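
These hooks bracket both the synchronous and the async checkpoint paths, so a callback can measure end-to-end save latency per step. A minimal illustrative sketch (SaveTimer is not part of the PR; it follows the hook signatures declared in callback_group.py):

import time

from nemo.lightning.pytorch.callbacks.callback_group import Callback, CallbackGroup


class SaveTimer(Callback):
    """Logs how long each checkpoint save takes, including async finalization."""

    def on_save_checkpoint_start(self, iteration: int = 0) -> None:
        self._start = time.perf_counter()

    def on_save_checkpoint_success(self, iteration: int = 0) -> None:
        print(f"checkpoint for step {iteration} saved in {time.perf_counter() - self._start:.1f}s")


CallbackGroup.get_instance().register(SaveTimer())
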