Skip to content

Commit d47fc12

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 4014c4a commit d47fc12

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

src/lightning/pytorch/callbacks/timer.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ def __init__(
111111
self._duration = duration.total_seconds() if duration is not None else None
112112
self._interval = interval
113113
self._verbose = verbose
114-
self._start_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
115-
self._end_time: dict[RunningStage, Optional[float]] = {stage: None for stage in RunningStage}
114+
self._start_time: dict[RunningStage, Optional[float]] = dict.fromkeys(RunningStage)
115+
self._end_time: dict[RunningStage, Optional[float]] = dict.fromkeys(RunningStage)
116116
self._offset = 0
117117

118118
def start_time(self, stage: str = RunningStage.TRAINING) -> Optional[float]:

tests/tests_pytorch/helpers/datasets.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def __init__(self, root: str, num_samples: int = 100, digits: Optional[Sequence]
148148

149149
@staticmethod
150150
def _prepare_subset(full_data: Tensor, full_targets: Tensor, num_samples: int, digits: Sequence):
151-
classes = {d: 0 for d in digits}
151+
classes = dict.fromkeys(digits, 0)
152152
indexes = []
153153
for idx, target in enumerate(full_targets):
154154
label = target.item()

tests/tests_pytorch/trainer/connectors/test_data_connector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,7 @@ def test_dataloader_source_request_from_module():
497497

498498

499499
@pytest.mark.parametrize(
500-
"hook_name", ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
500+
"hook_name", ["on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer"]
501501
)
502502
class TestDataHookSelector:
503503
def overridden_func(self, batch, *args, **kwargs):

tests/tests_pytorch/trainer/logging_/test_logger_connector.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def test_fx_validator_integration(tmp_path):
246246
})
247247
trainer.test(model, verbose=False)
248248

249-
not_supported.update({k: "result collection is not registered yet" for k in not_supported})
249+
not_supported.update(dict.fromkeys(not_supported, "result collection is not registered yet"))
250250
not_supported.update({
251251
"predict_dataloader": "result collection is not registered yet",
252252
"on_predict_model_eval": "result collection is not registered yet",

0 commit comments

Comments
 (0)