Skip to content

Commit e242290

Browse files
committed
tests
1 parent 4af6496 commit e242290

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

pytorch_lightning/core/saving.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
313313
return {}
314314

315315
with open(config_yaml) as fp:
316-
tags = yaml.load(fp, Loader=yaml.SafeLoader)
316+
tags = yaml.load(fp)
317317

318318
return tags
319319

tests/loggers/test_csv.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,11 @@ def test_file_logger_log_metrics(tmpdir, step_idx):
7070
logger.log_metrics(metrics, step_idx)
7171
logger.save()
7272

73-
path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
74-
params = load_hparams_from_yaml(path_yaml)
75-
assert all([n in params for n in metrics])
73+
path_csv = os.path.join(logger.log_dir, ExperimentWriter.NAME_METRICS_FILE)
74+
with open(path_csv, 'r') as fp:
75+
lines = fp.readlines()
76+
assert len(lines) == 2
77+
assert all([n in lines[0] for n in metrics])
7678

7779

7880
def test_file_logger_log_hyperparams(tmpdir):
@@ -89,3 +91,7 @@ def test_file_logger_log_hyperparams(tmpdir):
8991
}
9092
logger.log_hyperparams(hparams)
9193
logger.save()
94+
95+
path_yaml = os.path.join(logger.log_dir, ExperimentWriter.NAME_HPARAMS_FILE)
96+
params = load_hparams_from_yaml(path_yaml)
97+
assert all([n in params for n in hparams])

0 commit comments

Comments
 (0)