Skip to content

Commit f0d37de

Browse files
author
xmotli02
committed
Added basic file logger #1803
1 parent 3f2c102 commit f0d37de

File tree

5 files changed

+274
-0
lines changed

5 files changed

+274
-0
lines changed

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
99

1010
### Added
1111

12+
- Added FileLogger ([#2721](https://github.com/PyTorchLightning/pytorch-lightning/pull/2721))
1213
- Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))
1314
- Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))
1415

docs/source/loggers.rst

+6
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,10 @@ Test-tube
138138
^^^^^^^^^
139139

140140
.. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
141+
:noindex:
142+
143+
FileLogger
144+
^^^^^^^^^^
145+
146+
.. autoclass:: pytorch_lightning.loggers.file_logger.FileLogger
141147
:noindex:

pytorch_lightning/loggers/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
44
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
5+
from pytorch_lightning.loggers.file_logger import FileLogger
6+
57

68
__all__ = [
79
'LightningLoggerBase',
+180
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
"""
2+
File logger
3+
-----------
4+
"""
5+
import io
6+
import os
7+
import csv
8+
import torch
9+
10+
from argparse import Namespace
11+
from typing import Optional, Dict, Any, Union
12+
13+
14+
from pytorch_lightning import _logger as log
15+
from pytorch_lightning.core.saving import save_hparams_to_yaml
16+
from pytorch_lightning.loggers.base import LightningLoggerBase
17+
from pytorch_lightning.utilities.distributed import rank_zero_only
18+
19+
20+
class ExperimentWriter(object):
    r"""
    Experiment writer for FileLogger.

    Accumulates hyperparameters and metric rows in memory and persists them to
    ``log_dir`` as a YAML file (``hparams.yaml``) and a CSV file
    (``metrics.csv``) when :meth:`save` is called.

    Args:
        log_dir: Directory the output files are written to (created if missing).
    """

    NAME_HPARAMS_FILE = 'hparams.yaml'
    NAME_METRICS_FILE = 'metrics.csv'

    def __init__(self, log_dir):
        self.hparams = {}
        self.metrics = []
        # 'step' is always the first CSV column; further columns are appended
        # in the order their metric keys are first seen
        self.metrics_keys = ["step"]

        self.log_dir = log_dir
        # exist_ok avoids the check-then-create race of `if not exists: makedirs`
        os.makedirs(self.log_dir, exist_ok=True)

        self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

    def log_hparams(self, params):
        """Record hyperparameters; repeated calls merge into a single dict."""
        self.hparams.update(params)

    def log_metrics(self, metrics_dict, step=None):
        """Record one row of metrics.

        Args:
            metrics_dict: Mapping of metric name to scalar (plain number or
                0-dim :class:`torch.Tensor`).
            step: Row index; defaults to the number of rows logged so far.
        """
        def _handle_value(value):
            # store plain Python scalars so csv.DictWriter can serialize them
            if isinstance(value, torch.Tensor):
                return value.item()
            return value

        if step is None:
            step = len(self.metrics)

        # pre-fill all known columns with None so every row has the same keys
        new_row = dict.fromkeys(self.metrics_keys)
        new_row['step'] = step
        for k, v in metrics_dict.items():
            if k not in self.metrics_keys:
                self.metrics_keys.append(k)
            new_row[k] = _handle_value(v)
        self.metrics.append(new_row)

    def save(self):
        """Persist recorded hparams (YAML) and metrics (CSV) to ``log_dir``."""
        hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
        save_hparams_to_yaml(hparams_file, self.hparams)

        if self.metrics:
            with open(self.metrics_file_path, 'w', newline='') as f:
                # keep the writer local: storing it on the instance would only
                # leak a handle to an already-closed file
                writer = csv.DictWriter(f, fieldnames=self.metrics_keys)
                writer.writeheader()
                writer.writerows(self.metrics)
64+
65+
66+
class FileLogger(LightningLoggerBase):
    r"""
    Log to local file system in yaml and CSV format. Logs are saved to
    ``os.path.join(save_dir, name, version)``.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.loggers import FileLogger
        >>> logger = FileLogger("logs", name="my_exp_name")
        >>> trainer = Trainer(logger=logger)

    Args:
        save_dir: Save directory
        name: Experiment name. Defaults to ``'default'``.
        version: Experiment version. If version is not specified the logger inspects the save
            directory for existing versions, then automatically assigns the next available version.
    """

    def __init__(self,
                 save_dir: str,
                 name: Optional[str] = "default",
                 version: Optional[Union[int, str]] = None):

        super().__init__()
        self._save_dir = save_dir
        self._name = name or ''
        self._version = version
        # the ExperimentWriter is created lazily on first `experiment` access
        self._experiment = None

    @property
    def root_dir(self) -> str:
        """
        Parent directory for all checkpoint subdirectories.

        If the experiment name parameter is ``None`` or the empty string, no experiment
        subdirectory is used and the checkpoint will be saved in "save_dir/version_dir".
        """
        # `self._name` is normalized to a str in __init__, so truthiness covers both cases
        if self.name is None or len(self.name) == 0:
            return self._save_dir
        return os.path.join(self._save_dir, self.name)

    @property
    def log_dir(self) -> str:
        """
        The log directory for this run. By default, it is named
        ``'version_${self.version}'`` but it can be overridden by passing a string value
        for the constructor's version parameter instead of ``None`` or an int.
        """
        # create a pseudo standard path ala test-tube
        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
        return os.path.join(self.root_dir, version)

    @property
    def experiment(self) -> ExperimentWriter:
        r"""
        Actual ExperimentWriter object. To use ExperimentWriter features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_experiment_writer_function()

        """
        if self._experiment is not None:
            return self._experiment

        os.makedirs(self.root_dir, exist_ok=True)
        self._experiment = ExperimentWriter(log_dir=self.log_dir)
        return self._experiment

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Record hyperparameters (dict or Namespace) via the experiment writer."""
        params = self._convert_params(params)
        self.experiment.log_hparams(params)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Record a row of metrics via the experiment writer."""
        self.experiment.log_metrics(metrics, step)

    @rank_zero_only
    def save(self) -> None:
        """Write all accumulated hparams and metrics to disk."""
        super().save()
        self.experiment.save()

    @rank_zero_only
    def finalize(self, status: str) -> None:
        # flush everything regardless of run `status`
        self.save()

    @property
    def name(self) -> str:
        """The experiment name, or the empty string if none was given."""
        return self._name

    @property
    def version(self) -> Union[int, str]:
        # fixed annotation: a user-supplied string version is returned as-is,
        # so this is not guaranteed to be an int
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        """Scan ``save_dir/name`` for ``version_<n>`` dirs and return max(n) + 1."""
        root_dir = os.path.join(self._save_dir, self.name)

        if not os.path.isdir(root_dir):
            log.warning('Missing logger folder: %s', root_dir)
            return 0

        existing_versions = []
        for d in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                dir_ver = d.split("_")[1]
                # skip user-created folders such as `version_final` instead of
                # crashing with ValueError on int()
                if dir_ver.isdigit():
                    existing_versions.append(int(dir_ver))

        if len(existing_versions) == 0:
            return 0

        return max(existing_versions) + 1

tests/loggers/test_file_logger.py

+85
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from argparse import Namespace
2+
3+
import pytest
4+
import torch
5+
import os
6+
7+
from pytorch_lightning.loggers import FileLogger
8+
9+
10+
def test_file_logger_automatic_versioning(tmpdir):
    """Verify that automatic versioning works"""
    exp_dir = tmpdir.mkdir("exp")
    for existing_version in ("version_0", "version_1"):
        exp_dir.mkdir(existing_version)

    logger = FileLogger(save_dir=tmpdir, name="exp")

    # next free slot after the two pre-existing version dirs
    assert logger.version == 2
20+
21+
22+
def test_file_logger_manual_versioning(tmpdir):
    """Verify that manual versioning works"""
    exp_dir = tmpdir.mkdir("exp")
    for existing_version in ("version_0", "version_1", "version_2"):
        exp_dir.mkdir(existing_version)

    logger = FileLogger(save_dir=tmpdir, name="exp", version=1)

    # an explicitly supplied version wins over auto-detection
    assert logger.version == 1
33+
34+
35+
def test_file_logger_named_version(tmpdir):
    """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """
    exp_name = "exp"
    expected_version = "2020-02-05-162402"
    tmpdir.mkdir(exp_name)

    logger = FileLogger(save_dir=tmpdir, name=exp_name, version=expected_version)
    logger.log_hyperparams({"a": 1, "b": 2})
    logger.save()

    assert logger.version == expected_version
    # the string version becomes the directory name and is the only entry
    assert os.listdir(tmpdir / exp_name) == [expected_version]
    assert os.listdir(tmpdir / exp_name / expected_version)
48+
49+
50+
@pytest.mark.parametrize("name", ['', None])
def test_file_logger_no_name(tmpdir, name):
    """Verify that None or empty name works"""
    logger = FileLogger(save_dir=tmpdir, name=name)
    logger.save()

    # with no experiment name the version dir sits directly under save_dir
    assert logger.root_dir == tmpdir
    assert os.listdir(tmpdir / 'version_0')
57+
58+
59+
@pytest.mark.parametrize("step_idx", [10, None])
def test_file_logger_log_metrics(tmpdir, step_idx):
    """Scalar and tensor metrics are accepted with or without an explicit step."""
    logger = FileLogger(tmpdir)
    logger.log_metrics(
        {
            "float": 0.3,
            "int": 1,
            "FloatTensor": torch.tensor(0.1),
            "IntTensor": torch.tensor(1),
        },
        step_idx,
    )
    logger.save()
70+
71+
72+
def test_file_logger_log_hyperparams(tmpdir):
    """A mix of primitive, nested, and exotic hparam values can be logged and saved."""
    logger = FileLogger(tmpdir)
    hparams = dict(
        float=0.3,
        int=1,
        string="abc",
        bool=True,
        dict={'a': {'b': 'c'}},
        list=[1, 2, 3],
        namespace=Namespace(foo=Namespace(bar='buzz')),
        layer=torch.nn.BatchNorm1d,
    )
    logger.log_hyperparams(hparams)
    logger.save()

0 commit comments

Comments
 (0)