Add CallbackGroup & Metadata factory function #13437
Draft: PytLab wants to merge 43 commits into main from zshao/add_callback_group (+808 −2).

Changes shown are from 9 of the 43 commits.

Commits (43):
- 642e360 feat: add callback group definition & callback ABC (PytLab)
- 1badf29 Apply isort and black reformatting (PytLab)
- 3bf3367 feat: insert callback functions of CallbackGroup (PytLab)
- 2b51e12 Apply isort and black reformatting (PytLab)
- 249dad3 chore: PR test for jiashang (liquor233)
- db2b15d feat: use __init_subclass__ to cover all ModelPT subclasses (PytLab)
- d921d64 Apply isort and black reformatting (PytLab)
- 3e32f1a feat: Adding metadata config manager poc
- e1074f6 Apply isort and black reformatting (sajup-oss)
- d79f4f1 feat: revert test changes. (liquor233)
- 263f7e9 fix: Updating metadata attributes (sajup-oss)
- 81cd1d9 fix: Merging changes (sajup-oss)
- 4852936 Apply isort and black reformatting (sajup-oss)
- 48d6d87 fix: Adding OneloggerCallback (sajup-oss)
- 2ba6cc5 fix: Reverting changes in examples/multimodal/speech_llm/modular_audi… (sajup-oss)
- c908b53 fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM… (sajup-oss)
- bd39d8f Apply isort and black reformatting (sajup-oss)
- ba4e4a6 fix: update modular models and megatron GPT models (liquor233)
- 515136c Apply isort and black reformatting (liquor233)
- bc030f7 feat: add on_app_start and on_app_end (liquor233)
- 2ed58f4 Apply isort and black reformatting (liquor233)
- 35d2f2c fix: Adding small test example for testing (sajup-oss)
- ddc99fb Apply isort and black reformatting (sajup-oss)
- ca6ff4d fix: Fixing review comments as discussed with Jiashang
- 9f11d01 Apply isort and black reformatting (sajup-oss)
- 64e0e03 fix: updating nemo code to v2 (sajup-oss)
- 181bb3e fix: updating code to v2 (sajup-oss)
- 61d631c Apply isort and black reformatting (sajup-oss)
- 8eb4fc6 fix: updating wandb to get info from env (sajup-oss)
- 2900246 fix: updating wandb to get info from env (sajup-oss)
- 4acbc2c Apply isort and black reformatting (sajup-oss)
- dffccfa fix: fix som impl issue (liquor233)
- 60eb727 Apply isort and black reformatting (liquor233)
- b97fbda fix: fix issue for exp manager. (liquor233)
- 5c144ed feat: Merge branch 'zshao/add_callback_group' of https://github.com/N… (liquor233)
- 041a32b Apply isort and black reformatting (liquor233)
- b70f85b feat: remove callback_group (liquor233)
- f473d1b feat: fix timingtracker issue (liquor233)
- 1705b19 Apply isort and black reformatting (liquor233)
- e6b4e64 feat: fix for startup callbcaks (liquor233)
- 5b7bd1c Apply isort and black reformatting (liquor233)
- c687003 feat: change to adapter (liquor233)
- 42181c5 Apply isort and black reformatting (liquor233)
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Changes to the module defining `MegatronGPTSFTModel`:

```diff
@@ -68,7 +68,6 @@
     get_num_microbatches,
 )

 __all__ = ['MegatronGPTSFTModel']
```
New file (+190 lines): `nemo/lightning/pytorch/callbacks/callback_group.py`, imported elsewhere in this PR as `nemo.lightning.pytorch.callbacks.callback_group`:
```python
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
from typing import Callable

# Alias the import so the NeMo Callback ABC below can use the name `Callback`
# without shadowing its own base class.
from lightning.pytorch.callbacks import Callback as PTLCallback


class CallbackGroup:
    """A class for hosting a collection of callback objects.

    It executes the callback functions of multiple callback objects that share a method
    name. When callback_group.func(args) is executed, it loops through the objects in
    self._callbacks and runs self._callbacks[0].func(args), self._callbacks[1].func(args),
    and so on. The method name and arguments must match.

    Attributes:
        _callbacks (list[Callback]): List of callback objects.
    """

    _instance = None

    @classmethod
    def get_instance(cls) -> 'CallbackGroup':
        """Get the singleton instance of the CallbackGroup.

        Returns:
            CallbackGroup: The singleton instance of the CallbackGroup.
        """
        if cls._instance is None:
            cls._instance = CallbackGroup()
        return cls._instance

    def __init__(self) -> None:
        """Initializes the list of callback objects."""
        self._callbacks = []

    def register(self, callback: 'Callback') -> None:
        """Register a callback with the callback group.

        Args:
            callback (Callback): The callback to register.
        """
        self._callbacks.append(callback)

    def __getattr__(self, method_name: str) -> Callable:
        """Returns a wrapper that calls the named method on every registered callback.

        Note that __getattr__ is only invoked for attributes not found by normal lookup,
        so every unknown attribute resolves to a dispatching wrapper.

        Args:
            method_name (str): Callback method name.
        """

        def multi_callback_wrapper(*args, **kwargs) -> None:
            for callback in self._callbacks:
                assert hasattr(callback, method_name)
                method = getattr(callback, method_name)
                assert callable(method)
                _ = method(*args, **kwargs)

        return multi_callback_wrapper

    @property
    def callbacks(self):
        """Return callbacks in registration order.

        Returns:
            list: callback objects
        """
        return self._callbacks


class Callback(PTLCallback):
    """The base class for all NeMo callbacks.

    It inherits the PyTorch Lightning Callback, so a callback can also be passed to a PTL
    Trainer and reused there. The methods below are the extra callback hooks NeMo adds.
    """

    def on_dataloader_init_start(self):
        """Called at the start of the data loading."""

    def on_dataloader_init_end(self):
        """Called at the end of the data loading."""

    def on_model_init_start(self):
        """Called at the start of the model initialization."""

    def on_model_init_end(self):
        """Called at the end of the model initialization."""

    def on_optimizer_init_start(self) -> None:
        """Called at the beginning of optimizer initialization."""

    def on_optimizer_init_end(self) -> None:
        """Called at the end of optimizer initialization."""

    def on_load_checkpoint_start(self) -> None:
        """Called at the beginning of checkpoint loading."""

    def on_load_checkpoint_end(self) -> None:
        """Called at the end of checkpoint loading."""

    # The save hooks take `global_step` so the call sites in this PR, which pass
    # on_save_checkpoint_*(global_step=trainer.global_step), match the signature.
    def on_save_checkpoint_start(self, global_step: int = 0) -> None:
        """Called when a checkpoint save starts."""

    def on_save_checkpoint_end(self, global_step: int = 0) -> None:
        """Called when the synchronous part of a checkpoint save ends."""

    def on_save_checkpoint_success(self, global_step: int = 0) -> None:
        """Called when a checkpoint is saved successfully."""


# Maps a method name to the callback-group hooks fired before ("start_hook") and
# after ("end_hook") the original method runs:
#   method_name: {"start_hook": callback_method_name, "end_hook": callback_method_name}
CB_WRAP_RULES = {
    "setup_training_data": {"start_hook": "on_dataloader_init_start", "end_hook": "on_dataloader_init_end"},
    "setup_optimization": {"start_hook": "on_optimizer_init_start", "end_hook": "on_optimizer_init_end"},
    "restore_from_pretrained_models": {"start_hook": "on_load_checkpoint_start", "end_hook": "on_load_checkpoint_end"},
    "__init__": {"start_hook": "on_model_init_start", "end_hook": "on_model_init_end"},
    "configure_optimizers": {"start_hook": "on_optimizer_init_start", "end_hook": "on_optimizer_init_end"},
    "setup_training_dataloader": {"start_hook": "on_dataloader_init_start", "end_hook": "on_dataloader_init_end"},
}


def _make_callback_wrapped_method(original_method):
    """Wrap a method with the start and end hooks of the callback group.

    Args:
        original_method (Callable | classmethod): The original method to wrap.
    """
    callback_group = CallbackGroup.get_instance()

    # classmethod descriptors are not directly callable and carry their metadata on
    # __func__, so unwrap them before looking up the wrap rules.
    is_classmethod = isinstance(original_method, classmethod)
    func = original_method.__func__ if is_classmethod else original_method

    hooks = CB_WRAP_RULES.get(func.__name__)
    if not hooks:
        return original_method

    @functools.wraps(func)
    def wrapped_method(*args, **kwargs):
        # CallbackGroup.__getattr__ resolves any hook name to a dispatching wrapper,
        # so these lookups always succeed and fan out to all registered callbacks.
        getattr(callback_group, hooks["start_hook"])()
        result = func(*args, **kwargs)
        getattr(callback_group, hooks["end_hook"])()
        return result

    return classmethod(wrapped_method) if is_classmethod else wrapped_method


def wrap_methods_with_callbacks(cls) -> None:
    """Wrap class/instance methods with the start and end hooks of the callback group.

    Args:
        cls (type): The class whose methods are wrapped.
    """
    for method_name in CB_WRAP_RULES:
        if method_name in cls.__dict__:
            original_method = cls.__dict__[method_name]
            # cls.__dict__ is a read-only mappingproxy; assign via setattr instead.
            setattr(cls, method_name, _make_callback_wrapped_method(original_method))
```
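To make the dispatch and wrapping mechanics concrete, here is a minimal usage sketch against the file above. `TimingCallback`, `ToyModel`, and `TrackedBase` are hypothetical illustrations, not part of the PR; the `__init_subclass__` pattern at the end mirrors what commit db2b15d describes for covering all ModelPT subclasses.

```python
import time

# Hypothetical callback: times model construction via the NeMo hooks above.
class TimingCallback(Callback):
    def on_model_init_start(self):
        self._t0 = time.perf_counter()

    def on_model_init_end(self):
        print(f"model init took {time.perf_counter() - self._t0:.3f}s")


CallbackGroup.get_instance().register(TimingCallback())


# Hypothetical class: "__init__" is in CB_WRAP_RULES, so wrapping the class fires
# on_model_init_start/on_model_init_end around construction.
class ToyModel:
    def __init__(self):
        time.sleep(0.1)  # stand-in for expensive setup


wrap_methods_with_callbacks(ToyModel)
ToyModel()  # prints the init duration via TimingCallback


# To cover every subclass automatically (the approach commit db2b15d describes
# for ModelPT), a base class can apply the wrapping from __init_subclass__:
class TrackedBase:
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        wrap_methods_with_callbacks(cls)
```

Because the `Callback` ABC provides no-op defaults for every hook, a callback that implements only a subset of hooks still satisfies the `hasattr` assertion in `CallbackGroup.__getattr__`.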
Changes to the checkpoint-save callback, wiring `CallbackGroup` into `_save_checkpoint`:

```diff
@@ -29,6 +29,7 @@
 from nemo.lightning.ckpt_utils import ckpt_to_dir
 from nemo.lightning.io.pl import TrainerContext
+from nemo.lightning.pytorch.callbacks.callback_group import CallbackGroup
 from nemo.utils import logging
 from nemo.utils.app_state import AppState

@@ -567,6 +568,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
             ValueError: (mcore) async_save with EMA not supported
             ValueError: (mcore) Async save requires async compatible CheckpointIO
         """
+        CallbackGroup.get_instance().on_save_checkpoint_start()

         from nemo.utils.get_rank import is_global_rank_zero

@@ -598,6 +600,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
                 rank_zero_info(f"Saving EMA weights to separate checkpoint {filepath}")
                 super()._save_checkpoint(trainer, filepath)
                 self.remove_checkpoint_unfinished_marker(filepath, barrier_before=True)
+                CallbackGroup.get_instance().on_save_checkpoint_success(global_step=trainer.global_step)
         else:
             # Determine whether to include optimizer states in the checkpoint
             # optimizer states are included when

@@ -632,6 +635,7 @@ def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str)
                 logging.info(f'Scheduled async checkpoint save for {filepath}')
             else:
                 finalize_fn()
+        CallbackGroup.get_instance().on_save_checkpoint_end(global_step=trainer.global_step)

     def _get_finalize_save_checkpoint_callback(
         self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int

@@ -655,6 +659,7 @@ def _cb():
                 return

             logging.info(f'Async checkpoint save for step {global_step} ({filepath}) finalized successfully.')
+            CallbackGroup.get_instance().on_save_checkpoint_success(global_step=trainer.global_step)

             if str(filepath) in self.ckpts_to_link:
                 self._link_checkpoint(trainer, filepath, self.ckpts_to_link.pop(filepath), override_async=True)
```
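These hooks let a callback observe checkpoint save latency end to end, including asynchronous finalization, without touching the checkpointing code itself. A minimal sketch, assuming the `callback_group` module above; the `CheckpointTimer` class is a hypothetical illustration, not part of this PR:

```python
import time

from nemo.lightning.pytorch.callbacks.callback_group import Callback, CallbackGroup


# Hypothetical callback: records how long each checkpoint save takes per step.
class CheckpointTimer(Callback):
    def __init__(self):
        self._start = None

    def on_save_checkpoint_start(self, global_step: int = 0):
        self._start = time.perf_counter()

    def on_save_checkpoint_end(self, global_step: int = 0):
        # Fires after the synchronous save phase (or after scheduling an async save).
        print(f"step {global_step}: sync save phase done in {time.perf_counter() - self._start:.2f}s")

    def on_save_checkpoint_success(self, global_step: int = 0):
        # Fires once the checkpoint is fully finalized, sync or async.
        print(f"step {global_step}: checkpoint finalized after {time.perf_counter() - self._start:.2f}s")


# Registered once at startup; _save_checkpoint then drives the hooks shown in the diff.
CallbackGroup.get_instance().register(CheckpointTimer())
```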