@@ -34,7 +34,7 @@ class ExperimentWriter(object):
34
34
NAME_HPARAMS_FILE = 'hparams.yaml'
35
35
NAME_METRICS_FILE = 'metrics.csv'
36
36
37
- def __init__ (self , log_dir ) :
37
+ def __init__ (self , log_dir : str ) -> None :
38
38
self .hparams = {}
39
39
self .metrics = []
40
40
self .metrics_keys = ["step" ]
@@ -49,11 +49,11 @@ def __init__(self, log_dir):
49
49
50
50
self .metrics_file_path = os .path .join (self .log_dir , self .NAME_METRICS_FILE )
51
51
52
- def log_hparams (self , params ) :
52
+ def log_hparams (self , params : Dict [ str , Any ]) -> None :
53
53
"""Record hparams"""
54
54
self .hparams .update (params )
55
55
56
- def log_metrics (self , metrics_dict , step = None ):
56
+ def log_metrics (self , metrics_dict : Dict [ str , float ], step : Optional [ int ] = None ) -> None :
57
57
"""Record metrics"""
58
58
def _handle_value (value ):
59
59
if isinstance (value , torch .Tensor ):
@@ -71,7 +71,7 @@ def _handle_value(value):
71
71
new_row [k ] = _handle_value (v )
72
72
self .metrics .append (new_row )
73
73
74
- def save (self ):
74
+ def save (self ) -> None :
75
75
"""Save recorded hparams and metrics into files"""
76
76
hparams_file = os .path .join (self .log_dir , self .NAME_HPARAMS_FILE )
77
77
save_hparams_to_yaml (hparams_file , self .hparams )
@@ -135,6 +135,10 @@ def log_dir(self) -> str:
135
135
log_dir = os .path .join (self .root_dir , version )
136
136
return log_dir
137
137
138
+ @property
139
+ def save_dir (self ) -> Optional [str ]:
140
+ return self ._save_dir
141
+
138
142
@property
139
143
def experiment (self ) -> ExperimentWriter :
140
144
r"""
0 commit comments