-
Notifications
You must be signed in to change notification settings - Fork 3k
Add CallbackGroup & Metadata factory function #13437
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
PytLab
wants to merge
33
commits into
main
Choose a base branch
from
zshao/add_callback_group
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 4 commits
Commits
Show all changes
33 commits
Select commit
Hold shift + click to select a range
642e360
feat: add callback group definition & callback ABC
PytLab 1badf29
Apply isort and black reformatting
PytLab 3bf3367
feat: insert callback functions of CallbackGroup
PytLab 2b51e12
Apply isort and black reformatting
PytLab 249dad3
chore: PR test for jiashang
liquor233 db2b15d
feat: use __init_subclass__ to cover all ModelPT subclasses
PytLab d921d64
Apply isort and black reformatting
PytLab 3e32f1a
feat: Adding metadata config manager poc
e1074f6
Apply isort and black reformatting
sajup-oss d79f4f1
feat: revert test changes.
liquor233 263f7e9
fix: Updating metadata attributes
sajup-oss 81cd1d9
fix: Merging changes
sajup-oss 4852936
Apply isort and black reformatting
sajup-oss 48d6d87
fix: Adding OneloggerCallback
sajup-oss 2ba6cc5
fix: Reverting changes in examples/multimodal/speech_llm/modular_audi…
sajup-oss c908b53
fix: Merge branch 'zshao/add_callback_group' of github.com:NVIDIA/NeM…
sajup-oss bd39d8f
Apply isort and black reformatting
sajup-oss ba4e4a6
fix: update modular models and megatron GPT models
liquor233 515136c
Apply isort and black reformatting
liquor233 bc030f7
feat: add on_app_start and on_app_end
liquor233 2ed58f4
Apply isort and black reformatting
liquor233 35d2f2c
fix: Adding small test example for testing
sajup-oss ddc99fb
Apply isort and black reformatting
sajup-oss ca6ff4d
fix: Fixing review comments as discussed with Jiashang
9f11d01
Apply isort and black reformatting
sajup-oss 64e0e03
fix: updating nemo code to v2
sajup-oss 181bb3e
fix: updating code to v2
sajup-oss 61d631c
Apply isort and black reformatting
sajup-oss 8eb4fc6
fix: updating wandb to get info from env
sajup-oss 2900246
fix: updating wandb to get info from env
sajup-oss 4acbc2c
Apply isort and black reformatting
sajup-oss dffccfa
fix: fix som impl issue
liquor233 60eb727
Apply isort and black reformatting
liquor233 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Callable | ||
|
||
from lightning.pytorch.callbacks import Callback | ||
|
||
|
||
class CallbackGroup: | ||
"""A class for hosting a collection of callback objects. | ||
|
||
It is used to execute callback functions of multiple callback objects with the same method name. | ||
When callbackgroup.func(args) is executed, internally it loops through the objects in | ||
self._callbacks and runs self._callbacks[0].func(args), self._callbacks[1].func(args), etc. | ||
The method name and arguments should match. | ||
|
||
Attributes: | ||
_callbacks (list[Callback]): List of callback objects. | ||
""" | ||
|
||
_instance = None | ||
|
||
@classmethod | ||
def get_instance(cls) -> 'CallbackGroup': | ||
"""Get the singleton instance of the CallbackGroup. | ||
Args: | ||
cls (CallbackGroup): The class of the CallbackGroup. | ||
Returns: | ||
CallbackGroup: The singleton instance of the CallbackGroup. | ||
""" | ||
if cls._instance is None: | ||
cls._instance = CallbackGroup() | ||
|
||
return cls._instance | ||
|
||
def __init__(self, callbacks: list[Callback] | None) -> None: | ||
"""Initializes the list of callback objects. | ||
|
||
Args: | ||
callbacks (list[Callback]): List of callbacks. | ||
""" | ||
self._callbacks = callbacks or [] | ||
|
||
def __getattr__(self, method_name: str) -> Callable: | ||
"""Loops through the callback objects to call the corresponding callback function. | ||
|
||
Args: | ||
method_name (str): Callback method name. | ||
""" | ||
|
||
def multi_callback_wrapper(*args, **kwargs) -> None: | ||
for callback in self._callbacks: | ||
assert hasattr(callback, method_name) | ||
method = getattr(callback, method_name) | ||
assert callable(method) | ||
_ = method(*args, **kwargs) | ||
|
||
return multi_callback_wrapper | ||
|
||
@property | ||
def callbacks(self): | ||
"""Return callbacks in order. | ||
|
||
Returns: | ||
list: callback objects | ||
""" | ||
return self._callbacks | ||
|
||
|
||
class Callback(Callback): | ||
"""The base class for all callbacks. It inherits the pytorch lightning callback so the callback can be also passed to PTL trainer to reuse. | ||
Below list extra callback functions in NeMo. | ||
""" | ||
|
||
def on_dataloader_init_start(self): | ||
"""Called at the start of the data loading.""" | ||
|
||
def on_dataloader_init_end(self): | ||
"""Called at the end of the data loading.""" | ||
|
||
def on_model_init_start(self): | ||
"""Called at the start of the model initialization.""" | ||
|
||
def on_model_init_end(self): | ||
"""Called at the end of the model initialization.""" | ||
|
||
def on_optimizer_init_start(self) -> None: | ||
"""Called at the beginning of optimizer initialization.""" | ||
|
||
def on_optimizer_init_end(self) -> None: | ||
"""Called at the end of optimizer initialization.""" | ||
|
||
def on_load_checkpoint_start(self) -> None: | ||
"""Called at the beginning of loading checkpoint.""" | ||
|
||
def on_load_checkpoint_end(self) -> None: | ||
"""Called at the end of loading checkpoint.""" | ||
|
||
def on_save_checkpoint_start(self, iteration: int = 0) -> None: | ||
"""Called when start saving a checkpoint.""" | ||
|
||
def on_save_checkpoint_end(self, iteration: int = 0) -> None: | ||
"""Called when saving checkpoint (sync part) call ends.""" | ||
|
||
def on_save_checkpoint_success(self, iteration: int = 0) -> None: | ||
"""Called when checkpoint is saved successfully.""" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.