
Add CallbackGroup & Metadata factory function #13437


Draft: wants to merge 33 commits into base: main
Changes from all commits (33 commits)
642e360
feat: add callback group definition & callback ABC
PytLab May 5, 2025
1badf29
Apply isort and black reformatting
PytLab May 5, 2025
3bf3367
feat: insert callback functions of CallbackGroup
PytLab May 6, 2025
2b51e12
Apply isort and black reformatting
PytLab May 6, 2025
249dad3
chore: PR test for jiashang
liquor233 May 7, 2025
db2b15d
feat: use __init_subclass__ to cover all ModelPT subclasses
PytLab May 12, 2025
d921d64
Apply isort and black reformatting
PytLab May 12, 2025
3e32f1a
feat: Adding metadata config manager poc
May 12, 2025
e1074f6
Apply isort and black reformatting
sajup-oss May 12, 2025
d79f4f1
feat: revert test changes.
liquor233 May 13, 2025
263f7e9
fix: Updating metadata attributes
sajup-oss May 21, 2025
81cd1d9
fix: Merging changes
sajup-oss May 21, 2025
4852936
Apply isort and black reformatting
sajup-oss May 21, 2025
48d6d87
fix: Adding OneloggerCallback
sajup-oss May 22, 2025
2ba6cc5
fix: Reverting changes in examples/multimodal/speech_llm/modular_audi…
sajup-oss May 22, 2025
c908b53
fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM…
sajup-oss May 23, 2025
bd39d8f
Apply isort and black reformatting
sajup-oss May 23, 2025
ba4e4a6
fix: update modular models and megatron GPT models
liquor233 May 26, 2025
515136c
Apply isort and black reformatting
liquor233 May 26, 2025
bc030f7
feat: add on_app_start and on_app_end
liquor233 May 26, 2025
2ed58f4
Apply isort and black reformatting
liquor233 May 26, 2025
35d2f2c
fix: Adding small test example for testing
sajup-oss May 26, 2025
ddc99fb
Apply isort and black reformatting
sajup-oss May 26, 2025
ca6ff4d
fix: Fixing review comments as discussed with Jiashang
May 26, 2025
9f11d01
Apply isort and black reformatting
sajup-oss May 26, 2025
64e0e03
fix: updating nemo code to v2
sajup-oss Jun 13, 2025
181bb3e
fix: updating code to v2
sajup-oss Jun 13, 2025
61d631c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
8eb4fc6
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
2900246
fix: updating wandb to get info from env
sajup-oss Jun 13, 2025
4acbc2c
Apply isort and black reformatting
sajup-oss Jun 13, 2025
dffccfa
fix: fix som impl issue
liquor233 Jul 4, 2025
60eb727
Apply isort and black reformatting
liquor233 Jul 4, 2025
5 changes: 5 additions & 0 deletions nemo/collections/nlp/parts/megatron_trainer_builder.py
@@ -32,6 +32,7 @@
NLPFSDPStrategy,
PipelineMixedPrecisionPlugin,
)
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.callbacks.dist_ckpt_io import (
AsyncFinalizableCheckpointIO,
@@ -199,6 +200,9 @@ def create_trainer(self, callbacks=None) -> Trainer:
precision = self.cfg.trainer.precision
strategy = self._training_strategy()
plugins = self._plugins()
if callbacks is None:
callbacks = []
callbacks.extend(CallbackGroup.get_instance().callbacks)
callbacks = self._callbacks(callbacks)
trainer = Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)
# Restore the precision value after Trainer is built.
@@ -227,6 +231,7 @@ def _callbacks(self, callbacks: Optional[list]) -> list:
def create_trainer(self, callbacks=None) -> Trainer:
strategy = self._training_strategy()
plugins = self._plugins()
callbacks.extend(CallbackGroup.get_instance().callbacks)
callbacks = self._callbacks(callbacks)
return Trainer(plugins=plugins, strategy=strategy, **self.cfg.trainer, callbacks=callbacks)

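For context, the create_trainer() changes above pull extra callbacks from a process-wide CallbackGroup singleton. A minimal sketch of what such a singleton accessor might look like (hypothetical; the actual class lives in nemo/lightning/pytorch/callbacks/callback_group.py and may differ in detail):

# Hypothetical sketch of the CallbackGroup singleton referenced in the diff above;
# the real implementation may differ.
from typing import List, Optional


class CallbackGroup:
    """Process-wide registry of callbacks attached to every Trainer that is built."""

    _instance: Optional["CallbackGroup"] = None

    def __init__(self) -> None:
        self._callbacks: List[object] = []

    @classmethod
    def get_instance(cls) -> "CallbackGroup":
        # Lazily create the singleton on first access.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    @property
    def callbacks(self) -> List[object]:
        return self._callbacks

    def register(self, callback: object) -> None:
        self._callbacks.append(callback)
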
4 changes: 4 additions & 0 deletions nemo/core/classes/common.py
@@ -15,6 +15,7 @@

"""Interfaces common to all Neural Modules and Models."""
from __future__ import annotations

import copy
import hashlib
import inspect
@@ -42,6 +43,7 @@
from nemo.core.config.templates.model_card import NEMO_DEFAULT_MODEL_CARD_TEMPLATE
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.neural_types import NeuralType, NeuralTypeComparisonResult
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
from nemo.utils import logging
from nemo.utils.cloud import maybe_download_from_cloud
from nemo.utils.data_utils import resolve_cache_dir
@@ -739,6 +741,7 @@ def from_pretrained(
Returns:
A model instance of a particular model class or its underlying config (if return_config is set).
"""
CallbackGroup.get_instance().on_load_checkpoint_start()
if save_restore_connector is None:
save_restore_connector = SaveRestoreConnector()

Expand Down Expand Up @@ -772,6 +775,7 @@ def from_pretrained(
trainer=trainer,
save_restore_connector=save_restore_connector,
)
CallbackGroup.get_instance().on_load_checkpoint_end()
return instance

@classmethod
2 changes: 2 additions & 0 deletions nemo/core/classes/modelPT.py
@@ -47,6 +47,7 @@
from nemo.core.classes.common import Model
from nemo.core.connectors.save_restore_connector import SaveRestoreConnector
from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup, wrap_methods_with_callbacks

Check notice (Code scanning / CodeQL): Unused import. Import of 'CallbackGroup' is not used.

Copilot Autofix (AI, 2 days ago):

To fix the issue, the unused import of CallbackGroup should be removed from the file. This will eliminate unnecessary clutter and improve code readability. The removal should be limited to the specific import statement flagged by CodeQL, ensuring no unintended changes to other parts of the code.

Suggested changeset 1: nemo/core/classes/modelPT.py
Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py
--- a/nemo/core/classes/modelPT.py
+++ b/nemo/core/classes/modelPT.py
@@ -50,3 +50,3 @@
 from nemo.core.optim import McoreDistributedOptimizer, prepare_lr_scheduler
-from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup, wrap_methods_with_callbacks
+from nemo.lightning.pytorch.callbacks.callback_group import wrap_methods_with_callbacks
 from nemo.utils import logging, model_utils
EOF
from nemo.utils import logging, model_utils
from nemo.utils.app_state import AppState
from nemo.utils.debug_hook import register_debug_hooks
@@ -224,6 +225,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

def __init_subclass__(cls) -> None:
cls._save_restore_connector = SaveRestoreConnector()
wrap_methods_with_callbacks(cls)

def on_fit_start(self) -> None:
if self.cfg.get("dump_debug_info", False):
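The __init_subclass__ hook above calls wrap_methods_with_callbacks on every ModelPT subclass. A rough sketch of what such a helper could do (hypothetical; the wrapped method names and hook names are assumptions, the real helper is defined alongside CallbackGroup):

# Hypothetical sketch of a method-wrapping helper like the one called from
# __init_subclass__ above; the wrapped method names and hook names are assumptions.
import functools

from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

_WRAPPED_METHODS = ("setup_training_data", "setup_validation_data")


def wrap_methods_with_callbacks(cls) -> None:
    """Wrap selected methods of a ModelPT subclass with CallbackGroup start/end hooks."""
    for name in _WRAPPED_METHODS:
        original = getattr(cls, name, None)
        if original is None:
            continue

        def _make_wrapper(method, method_name):
            @functools.wraps(method)
            def wrapper(self, *args, **kwargs):
                group = CallbackGroup.get_instance()
                # Fire on_<method>_start / on_<method>_end if the group defines them.
                getattr(group, f"on_{method_name}_start", lambda: None)()
                result = method(self, *args, **kwargs)
                getattr(group, f"on_{method_name}_end", lambda: None)()
                return result

            return wrapper

        setattr(cls, name, _make_wrapper(original, name))
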
2 changes: 2 additions & 0 deletions nemo/lightning/__init__.py
@@ -27,6 +27,7 @@
from nemo.lightning.fabric.plugins import FabricMegatronMixedPrecision
from nemo.lightning.fabric.strategies import FabricMegatronStrategy
from nemo.lightning.nemo_logger import NeMoLogger
from nemo.lightning.one_logger_callback import OneLoggerNeMoCallback
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from nemo.lightning.pytorch.optim import (
LRSchedulerModule,
@@ -72,6 +73,7 @@ def _is_slurm_interactive_mode():
"lr_scheduler",
"NeMoLogger",
"ModelCheckpoint",
"OneLoggerNeMoCallback",
"OptimizerModule",
"Trainer",
"configure_no_restart_validation_training_loop",
25 changes: 25 additions & 0 deletions nemo/lightning/nemo_logger.py
@@ -165,6 +165,31 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
self._setup_trainer_loggers(trainer, _dir, version)
self._setup_trainer_model_checkpoint(trainer, log_dir=log_dir, ckpt=self.ckpt)

# Configure OneLogger callback
try:
from omegaconf import OmegaConf

from nemo.utils.exp_manager import configure_onelogger

# Create a minimal config for OneLogger
cfg = OmegaConf.create(
{
"exp_manager": {
"wandb_logger_kwargs": {
"project": "nemo_experiments",
"name": self.name,
"id": version or None,
}
}
}
)

# Configure OneLogger
configure_onelogger(cfg, trainer)
logging.info("OneLogger configured successfully")
except Exception as e:
logging.warning(f"Failed to configure OneLogger: {e}")

self._setup_files_to_move(log_dir, app_state)
self._setup_file_logging(log_dir)

150 changes: 150 additions & 0 deletions nemo/lightning/one_logger_callback.py
@@ -0,0 +1,150 @@
"""
OneLogger callback for NeMo training.

This module provides a callback that integrates OneLogger telemetry with NeMo training.
"""

import functools
import logging

Check notice (Code scanning / CodeQL): Unused import. Import of 'logging' is not used.

Copilot Autofix (AI, 2 days ago):

To fix the issue, the unused logging import on line 8 should be removed. This will eliminate the unnecessary dependency and make the code cleaner. No other changes are required, as the removal of this import does not affect the functionality of the code.

Suggested changeset 1: nemo/lightning/one_logger_callback.py
Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/nemo/lightning/one_logger_callback.py b/nemo/lightning/one_logger_callback.py
--- a/nemo/lightning/one_logger_callback.py
+++ b/nemo/lightning/one_logger_callback.py
@@ -7,3 +7,3 @@
 import functools
-import logging
+
 from typing import Any, Dict, List, Optional, Type
EOF
from typing import Any, Dict, List, Optional, Type

import nv_one_logger.training_telemetry.api.callbacks as CB


Review comment (Eric): it is better to publish this package to a public index; if that is difficult, make sure the import won't fail by adding an import guard, and maybe reference the open-sourced one-logger repo.
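
A minimal sketch of the import guard suggested above (hypothetical; not part of the current diff, which imports the package unconditionally):

# Hypothetical import guard for the optional nv_one_logger dependency,
# following the review comment above; not part of this PR's diff.
try:
    import nv_one_logger.training_telemetry.api.callbacks as CB

    HAVE_ONE_LOGGER = True
except ImportError:
    CB = None
    HAVE_ONE_LOGGER = False
# Callers can then check HAVE_ONE_LOGGER before using the telemetry hooks.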

import pytorch_lightning as pl

Check notice (Code scanning / CodeQL): Unused import. Import of 'pl' is not used.

Copilot Autofix (AI, 2 days ago):

To fix the issue, the unused import statement import pytorch_lightning as pl should be removed from the file. This will eliminate the unnecessary dependency and improve code clarity. No other changes are required, as the rest of the code already uses specific imports from PyTorch Lightning directly.


Suggested changeset 1: nemo/lightning/one_logger_callback.py
Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/nemo/lightning/one_logger_callback.py b/nemo/lightning/one_logger_callback.py
--- a/nemo/lightning/one_logger_callback.py
+++ b/nemo/lightning/one_logger_callback.py
@@ -11,3 +11,3 @@
 import nv_one_logger.training_telemetry.api.callbacks as CB
-import pytorch_lightning as pl
+
 import torch
EOF
import torch

Check notice (Code scanning / CodeQL): Unused import. Import of 'torch' is not used.

Copilot Autofix (AI, 2 days ago):

To fix the problem, the unused torch import on line 13 should be removed. This will eliminate the unnecessary dependency and make the code cleaner. No other changes are required since the removal of this import does not affect the functionality of the code.

Suggested changeset 1: nemo/lightning/one_logger_callback.py
Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/nemo/lightning/one_logger_callback.py b/nemo/lightning/one_logger_callback.py
--- a/nemo/lightning/one_logger_callback.py
+++ b/nemo/lightning/one_logger_callback.py
@@ -12,3 +12,3 @@
 import pytorch_lightning as pl
-import torch
+
 from pytorch_lightning import Trainer
EOF
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core import LightningModule
from pytorch_lightning.plugins.io import AsyncCheckpointIO

Check notice (Code scanning / CodeQL): Unused import. Import of 'AsyncCheckpointIO' is not used.

Copilot Autofix (AI, 2 days ago):

To fix the problem, we will remove the unused import statement from pytorch_lightning.plugins.io import AsyncCheckpointIO on line 17. This will clean up the code and eliminate the unnecessary dependency without affecting the functionality of the program.


Suggested changeset 1: nemo/lightning/one_logger_callback.py
Run the following command in your local git repository to apply this patch:
cat << 'EOF' | git apply
diff --git a/nemo/lightning/one_logger_callback.py b/nemo/lightning/one_logger_callback.py
--- a/nemo/lightning/one_logger_callback.py
+++ b/nemo/lightning/one_logger_callback.py
@@ -16,3 +16,3 @@
 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.plugins.io import AsyncCheckpointIO
+
 from pytorch_lightning.utilities import rank_zero_only
EOF
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.types import STEP_OUTPUT


class OneLoggerNeMoCallback(Callback):
"""
NeMo callback that integrates with OneLogger v2 for tracking metrics.

This callback implements NeMo's callback group API and internally
uses OneLogger's training telemetry functionality to track metrics.
"""

def __init__(
self,
callback_config: Optional[Dict[str, Any]] = None,
log_interval: int = 1,
async_io_checkpoint_classes: List[Type[Any]] | None = None,
):
"""
Initialize the OneLogger NeMo callback.

Args:
callback_config (dict): Configuration dictionary with metadata
from MetaInfoManager(cfg).get_metadata()
log_interval (int): How often to log metrics
async_io_checkpoint_classes (List[Type]): Additional classes to identify as async checkpoints
"""
super().__init__()
self.log_interval = log_interval
self.async_io_checkpoint_classes = async_io_checkpoint_classes or []
self.state = {
"is_async_checkpoint": None,
}

# Extract configuration values
if callback_config is not None:
self.app_name = callback_config.get("app_name", "")
self.perf_tag = callback_config.get("perf_tag", "")
self.session_tag = callback_config.get("session_tag", "")
self.global_batch_size = callback_config.get("global_batch_size", 0)
else:
self.app_name = ""
self.perf_tag = ""
self.session_tag = ""
self.global_batch_size = 0

def __getattr__(self, name: str) -> Any:
"""Automatically forward any undefined method calls to the OneLogger v2 callbacks mainly for non-trainer methods.

This eliminates the need for manually writing pass-through methods for each OneLogger API.
Only methods that need custom logic (like those interacting with the trainer) need to be
explicitly defined in this class.

Args:
name: The name of the method being called
Returns:
The method from the OneLogger v2 callbacks
Raises:
AttributeError: If the method is not found in the OneLogger callbacks
"""
# Check if the method exists in the OneLogger callbacks module
if hasattr(CB, name):
# Get the original method
original_method = getattr(CB, name)

# Create a wrapper that adds rank_zero_only decorator
@functools.wraps(original_method)
def wrapper(*args, **kwargs):
return rank_zero_only(original_method)(*args, **kwargs)

return wrapper

# If not found, raise AttributeError as normal
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

@rank_zero_only
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Called when training begins."""
# Extract necessary information from the trainer
current_step = trainer.global_step
max_steps = trainer.max_steps if hasattr(trainer, 'max_steps') else 0

CB.on_train_start(train_iterations_start=current_step, train_iterations_target_or_fn=max_steps)

@rank_zero_only
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
CB.on_train_end()

@rank_zero_only
def on_train_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int) -> None:
CB.on_training_single_iteration_start()

@rank_zero_only
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
) -> None:
CB.on_training_single_iteration_end()

@rank_zero_only
def on_validation_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
CB.on_validation_start()

@rank_zero_only
def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
CB.on_validation_end()

@rank_zero_only
def on_validation_batch_start(
self,
trainer: Trainer,
pl_module: LightningModule,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
CB.on_validation_single_iteration_start()

@rank_zero_only
def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs: STEP_OUTPUT,
batch: Any,
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
CB.on_validation_single_iteration_end()
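
Taken together, a hypothetical end-to-end wiring of the pieces added in this PR might look as follows (the register() call is an assumption, not confirmed API; the callback_config keys follow the __init__ shown above):

# Hypothetical usage sketch combining the pieces introduced in this PR;
# the register() call is an assumption about the CallbackGroup API.
from nemo.lightning import OneLoggerNeMoCallback
from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup

telemetry_cb = OneLoggerNeMoCallback(
    callback_config={
        "app_name": "nemo_example",
        "perf_tag": "baseline",
        "session_tag": "run-001",
        "global_batch_size": 256,
    }
)
# Once registered, MegatronTrainerBuilder.create_trainer() would pick the callback up
# via CallbackGroup.get_instance().callbacks (see the first diff in this PR).
CallbackGroup.get_instance().register(telemetry_cb)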