Skip to content

Commit be3d138

Browse files
Add RankFilter to skip logging when the rank is not meeting criteria (#6243)
Partially fixes #6189 Fixes #6230 ### Description The RankFilter class is a convenient filter that extends the Filter class in the Python logging module. The purpose is to control which log records are processed based on the rank in a distributed environment. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Mingxin Zheng <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 8eceabf commit be3d138

File tree

4 files changed

+87
-2
lines changed

4 files changed

+87
-2
lines changed

monai/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .aliases import alias, resolve_name
1616
from .decorators import MethodReplacer, RestartGenerator
1717
from .deprecate_utils import DeprecatedError, deprecated, deprecated_arg, deprecated_arg_default
18-
from .dist import evenly_divisible_all_gather, get_dist_device, string_list_all_gather
18+
from .dist import RankFilter, evenly_divisible_all_gather, get_dist_device, string_list_all_gather
1919
from .enums import (
2020
Average,
2121
BlendMode,

monai/utils/dist.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from __future__ import annotations
1313

1414
import sys
15+
import warnings
16+
from collections.abc import Callable
17+
from logging import Filter
1518

1619
if sys.version_info >= (3, 8):
1720
from typing import Literal
@@ -26,7 +29,7 @@
2629

2730
idist, has_ignite = optional_import("ignite", IgniteInfo.OPT_IMPORT_VERSION, min_version, "distributed")
2831

29-
__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather"]
32+
__all__ = ["get_dist_device", "evenly_divisible_all_gather", "string_list_all_gather", "RankFilter"]
3033

3134

3235
def get_dist_device():
@@ -174,3 +177,31 @@ def string_list_all_gather(strings: list[str], delimiter: str = "\t") -> list[st
174177
_gathered = [bytearray(g.tolist()).decode("utf-8").split(delimiter) for g in gathered]
175178

176179
return [i for k in _gathered for i in k]
180+
181+
182+
class RankFilter(Filter):
183+
"""
184+
The RankFilter class is a convenient filter that extends the Filter class in the Python logging module.
185+
The purpose is to control which log records are processed based on the rank in a distributed environment.
186+
187+
Args:
188+
rank: the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank().
189+
filter_fn: an optional lambda function used as the filtering criteria.
190+
The default function logs only if the rank of the process is 0,
191+
but the user can define their own function to implement custom filtering logic.
192+
"""
193+
194+
def __init__(self, rank: int | None = None, filter_fn: Callable = lambda rank: rank == 0):
195+
super().__init__()
196+
self.filter_fn: Callable = filter_fn
197+
if dist.is_available() and dist.is_initialized():
198+
self.rank: int = rank if rank is not None else dist.get_rank()
199+
else:
200+
warnings.warn(
201+
"The torch.distributed is either unavailable and uninitiated when RankFilter is instiantiated. "
202+
"If torch.distributed is used, please ensure that the RankFilter() is called "
203+
"after torch.distributed.init_process_group() in the script."
204+
)
205+
206+
def filter(self, *_args):
207+
return self.filter_fn(self.rank)

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def run_testsuit():
156156
"test_rand_zoom",
157157
"test_rand_zoomd",
158158
"test_randtorchvisiond",
159+
"test_rankfilter_dist",
159160
"test_resample_backends",
160161
"test_resize",
161162
"test_resized",

tests/test_rankfilter_dist.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import logging
15+
import os
16+
import tempfile
17+
import unittest
18+
19+
import torch.distributed as dist
20+
21+
from monai.utils import RankFilter
22+
from tests.utils import DistCall, DistTestCase
23+
24+
25+
class DistributedRankFilterTest(DistTestCase):
26+
def setUp(self):
27+
self.log_dir = tempfile.TemporaryDirectory()
28+
29+
@DistCall(nnodes=1, nproc_per_node=2)
30+
def test_rankfilter(self):
31+
logger = logging.getLogger(__name__)
32+
log_filename = os.path.join(self.log_dir.name, "records.log")
33+
h1 = logging.FileHandler(filename=log_filename)
34+
h1.setLevel(logging.WARNING)
35+
36+
logger.addHandler(h1)
37+
38+
logger.addFilter(RankFilter())
39+
logger.warning("test_warnings")
40+
41+
dist.barrier()
42+
if dist.get_rank() == 0:
43+
with open(log_filename) as file:
44+
lines = [line.rstrip() for line in file]
45+
log_message = " ".join(lines)
46+
assert log_message.count("test_warnings") == 1
47+
48+
def tearDown(self) -> None:
49+
self.log_dir.cleanup()
50+
51+
52+
if __name__ == "__main__":
53+
unittest.main()

0 commit comments

Comments
 (0)