
Commit 767c449

motlikm, xmotli02 and Borda authored
Added basic file logger (#2721)
* Added basic file logger #1803
* fixup! Added basic file logger #1803
* fixup! Added basic file logger #1803
* fixup! Added basic file logger #1803
* fixup! Added basic file logger #1803
* fixup! Added basic file logger #1803
* csv
* Apply suggestions from code review
* tests
* tests
* tests
* miss
* docs

Co-authored-by: xmotli02 <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent ac4a215 commit 767c449

File tree

7 files changed: +320 -1 lines changed


CHANGELOG.md

+2
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added SyncBN for DDP ([#2801](https://github.com/PyTorchLightning/pytorch-lightning/pull/2801))

+- Added basic `CSVLogger` ([#2721](https://github.com/PyTorchLightning/pytorch-lightning/pull/2721))
+
 - Added SSIM metrics ([#2671](https://github.com/PyTorchLightning/pytorch-lightning/pull/2671))

 - Added BLEU metrics ([#2535](https://github.com/PyTorchLightning/pytorch-lightning/pull/2535))

docs/source/loggers.rst

+6
@@ -339,4 +339,10 @@ Test-tube
 ^^^^^^^^^

 .. autoclass:: pytorch_lightning.loggers.test_tube.TestTubeLogger
+    :noindex:
+
+CSVLogger
+^^^^^^^^^
+
+.. autoclass:: pytorch_lightning.loggers.csv_logs.CSVLogger
     :noindex:

pytorch_lightning/core/saving.py

+1 -1
@@ -313,7 +313,7 @@ def load_hparams_from_yaml(config_yaml: str) -> Dict[str, Any]:
         return {}

     with open(config_yaml) as fp:
-        tags = yaml.load(fp, Loader=yaml.SafeLoader)
+        tags = yaml.load(fp)

     return tags
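
For orientation, a small round-trip sketch of the helpers referenced in this commit (save_hparams_to_yaml is used by the new CSVLogger below); the temporary path and values are illustrative:

import os
import tempfile

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

# write a plain dict of hyperparameters to YAML and read it back
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'hparams.yaml')
    save_hparams_to_yaml(path, {'lr': 0.01, 'batch_size': 32})
    print(load_hparams_from_yaml(path))  # expected: {'lr': 0.01, 'batch_size': 32}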

pytorch_lightning/loggers/__init__.py

+3
@@ -2,11 +2,14 @@

 from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
 from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
+from pytorch_lightning.loggers.csv_logs import CSVLogger
+

 __all__ = [
     'LightningLoggerBase',
     'LoggerCollection',
     'TensorBoardLogger',
+    'CSVLogger',
 ]

 try:

pytorch_lightning/loggers/csv_logs.py

+204
@@ -0,0 +1,204 @@
"""
CSV logger
----------

CSV logger for basic experiment logging that does not require opening ports

"""
import io
import os
import csv
import torch
from argparse import Namespace
from typing import Optional, Dict, Any, Union

from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only


class ExperimentWriter(object):
    r"""
    Experiment writer for CSVLogger.

    Currently supports to log hyperparameters and metrics in YAML and CSV
    format, respectively.

    Args:
        log_dir: Directory for the experiment logs
    """

    NAME_HPARAMS_FILE = 'hparams.yaml'
    NAME_METRICS_FILE = 'metrics.csv'

    def __init__(self, log_dir: str) -> None:
        self.hparams = {}
        self.metrics = []

        self.log_dir = log_dir
        if os.path.exists(self.log_dir):
            rank_zero_warn(
                f"Experiment logs directory {self.log_dir} exists and is not empty."
                " Previous log files in this directory will be deleted when the new ones are saved!"
            )
        os.makedirs(self.log_dir, exist_ok=True)

        self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

    def log_hparams(self, params: Dict[str, Any]) -> None:
        """Record hparams"""
        self.hparams.update(params)

    def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
        """Record metrics"""
        def _handle_value(value):
            if isinstance(value, torch.Tensor):
                return value.item()
            return value

        if step is None:
            step = len(self.metrics)

        metrics = {k: _handle_value(v) for k, v in metrics_dict.items()}
        metrics['step'] = step
        self.metrics.append(metrics)

    def save(self) -> None:
        """Save recorded hparams and metrics into files"""
        hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
        save_hparams_to_yaml(hparams_file, self.hparams)

        if not self.metrics:
            return

        last_m = {}
        for m in self.metrics:
            last_m.update(m)
        metrics_keys = list(last_m.keys())

        with io.open(self.metrics_file_path, 'w', newline='') as f:
            self.writer = csv.DictWriter(f, fieldnames=metrics_keys)
            self.writer.writeheader()
            self.writer.writerows(self.metrics)


class CSVLogger(LightningLoggerBase):
    r"""
    Log to local file system in yaml and CSV format. Logs are saved to
    ``os.path.join(save_dir, name, version)``.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.loggers import CSVLogger
        >>> logger = CSVLogger("logs", name="my_exp_name")
        >>> trainer = Trainer(logger=logger)

    Args:
        save_dir: Save directory
        name: Experiment name. Defaults to ``'default'``.
        version: Experiment version. If version is not specified the logger inspects the save
            directory for existing versions, then automatically assigns the next available version.
    """

    def __init__(self,
                 save_dir: str,
                 name: Optional[str] = "default",
                 version: Optional[Union[int, str]] = None):

        super().__init__()
        self._save_dir = save_dir
        self._name = name or ''
        self._version = version
        self._experiment = None

    @property
    def root_dir(self) -> str:
        """
        Parent directory for all checkpoint subdirectories.
        If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used
        and the checkpoint will be saved in "save_dir/version_dir"
        """
        if not self.name:
            return self.save_dir
        return os.path.join(self.save_dir, self.name)

    @property
    def log_dir(self) -> str:
        """
        The log directory for this run. By default, it is named
        ``'version_${self.version}'`` but it can be overridden by passing a string value
        for the constructor's version parameter instead of ``None`` or an int.
        """
        # create a pseudo standard path ala test-tube
        version = self.version if isinstance(self.version, str) else f"version_{self.version}"
        log_dir = os.path.join(self.root_dir, version)
        return log_dir

    @property
    def save_dir(self) -> Optional[str]:
        return self._save_dir

    @property
    def experiment(self) -> ExperimentWriter:
        r"""

        Actual ExperimentWriter object. To use ExperimentWriter features in your
        :class:`~pytorch_lightning.core.lightning.LightningModule` do the following.

        Example::

            self.logger.experiment.some_experiment_writer_function()

        """
        if self._experiment:
            return self._experiment

        os.makedirs(self.root_dir, exist_ok=True)
        self._experiment = ExperimentWriter(log_dir=self.log_dir)
        return self._experiment

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        params = self._convert_params(params)
        self.experiment.log_hparams(params)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.experiment.log_metrics(metrics, step)

    @rank_zero_only
    def save(self) -> None:
        super().save()
        self.experiment.save()

    @rank_zero_only
    def finalize(self, status: str) -> None:
        self.save()

    @property
    def name(self) -> str:
        return self._name

    @property
    def version(self) -> int:
        if self._version is None:
            self._version = self._get_next_version()
        return self._version

    def _get_next_version(self):
        root_dir = os.path.join(self._save_dir, self.name)

        if not os.path.isdir(root_dir):
            log.warning('Missing logger folder: %s', root_dir)
            return 0

        existing_versions = []
        for d in os.listdir(root_dir):
            if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
                existing_versions.append(int(d.split("_")[1]))

        if len(existing_versions) == 0:
            return 0

        return max(existing_versions) + 1
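
Not part of the diff, but as a quick orientation, a minimal usage sketch of the logger defined above (directory and metric names are illustrative):

from pytorch_lightning.loggers import CSVLogger

# log a few values by hand and flush them to disk
logger = CSVLogger(save_dir='logs', name='my_exp_name')
logger.log_hyperparams({'lr': 0.01, 'batch_size': 32})
logger.log_metrics({'train_loss': 0.42}, step=0)
logger.log_metrics({'train_loss': 0.40, 'val_loss': 0.45}, step=1)
logger.save()
# logs/my_exp_name/version_0/ should now hold hparams.yaml and metrics.csv;
# metrics.csv gets one row per log_metrics call and a column per metric key plus 'step'.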

tests/loggers/test_all.py

+7
@@ -5,11 +5,13 @@
 import platform
 from unittest import mock

+import cloudpickle
 import pytest

 import tests.base.develop_utils as tutils
 from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.loggers import (
+    CSVLogger,
     TensorBoardLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -34,6 +36,7 @@ def _get_logger_args(logger_class, save_dir):

 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
+    CSVLogger,
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -85,6 +88,7 @@ def log_metrics(self, metrics, step):


 @pytest.mark.parametrize("logger_class", [
+    CSVLogger,
     TensorBoardLogger,
     CometLogger,
     MLFlowLogger,
@@ -148,6 +152,7 @@ def name(self):

 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
+    CSVLogger,
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
@@ -170,6 +175,7 @@ def test_loggers_pickle(tmpdir, monkeypatch, logger_class):

     # test pickling loggers
     pickle.dumps(logger)
+    cloudpickle.dumps(logger)

     trainer = Trainer(
         max_epochs=1,
@@ -226,6 +232,7 @@ def on_train_batch_start(self, trainer, pl_module):
 @pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.parametrize("logger_class", [
     TensorBoardLogger,
+    # CSVLogger, # todo
     CometLogger,
     MLFlowLogger,
     NeptuneLogger,
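
The parametrized cases above run CSVLogger through the shared logger test suite; purely for illustration (this file is not part of the commit), a standalone smoke test of the new logger might look like:

import os

from pytorch_lightning.loggers.csv_logs import CSVLogger, ExperimentWriter


def test_csv_logger_writes_files(tmpdir):
    # log two steps, flush, and check the files the ExperimentWriter produced
    logger = CSVLogger(save_dir=str(tmpdir), name='exp')
    logger.log_hyperparams({'lr': 0.1})
    logger.log_metrics({'loss': 1.0}, step=0)
    logger.log_metrics({'loss': 0.5}, step=1)
    logger.save()

    log_dir = os.path.join(str(tmpdir), 'exp', 'version_0')
    assert os.path.isfile(os.path.join(log_dir, ExperimentWriter.NAME_HPARAMS_FILE))
    assert os.path.isfile(os.path.join(log_dir, ExperimentWriter.NAME_METRICS_FILE))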
