15
15
16
16
import logging
17
17
import os
18
- from functools import wraps
19
- from typing import Callable , Optional , TypeVar , overload
18
+ from typing import Optional
20
19
21
20
import lightning_utilities .core .rank_zero as rank_zero_module
22
21
29
28
rank_zero_info ,
30
29
rank_zero_warn ,
31
30
)
32
- from typing_extensions import ParamSpec
33
-
34
- from lightning .fabric .utilities .imports import _UTILITIES_GREATER_EQUAL_0_10
35
31
36
32
rank_zero_module .log = logging .getLogger (__name__ )
37
33
@@ -48,33 +44,7 @@ def _get_rank() -> Optional[int]:
48
44
return None
49
45
50
46
51
- if not _UTILITIES_GREATER_EQUAL_0_10 :
52
- T = TypeVar ("T" )
53
- P = ParamSpec ("P" )
54
-
55
- @overload
56
- def rank_zero_only (fn : Callable [P , T ]) -> Callable [P , Optional [T ]]:
57
- """Rank zero only."""
58
-
59
- @overload
60
- def rank_zero_only (fn : Callable [P , T ], default : T ) -> Callable [P , T ]:
61
- """Rank zero only."""
62
-
63
- def rank_zero_only (fn : Callable [P , T ], default : Optional [T ] = None ) -> Callable [P , Optional [T ]]:
64
- @wraps (fn )
65
- def wrapped_fn (* args : P .args , ** kwargs : P .kwargs ) -> Optional [T ]:
66
- rank = getattr (rank_zero_only , "rank" , None )
67
- if rank is None :
68
- raise RuntimeError ("The `rank_zero_only.rank` needs to be set before use" )
69
- if rank == 0 :
70
- return fn (* args , ** kwargs )
71
- return default
72
-
73
- return wrapped_fn
74
-
75
- rank_zero_module .rank_zero_only .rank = getattr (rank_zero_module .rank_zero_only , "rank" , _get_rank () or 0 )
76
- else :
77
- rank_zero_only = rank_zero_module .rank_zero_only
47
+ rank_zero_only = rank_zero_module .rank_zero_only
78
48
79
49
# add the attribute to the function but don't overwrite in case Trainer has already set it
80
50
rank_zero_only .rank = getattr (rank_zero_only , "rank" , _get_rank () or 0 )
0 commit comments