Bug description
I get a logging error when my LightningModule is wrapped with torch.compile. The full traceback is:
Traceback (most recent call last):
File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 586, in <module>
trainer.fit(modele, train_loader, validation_loader, ckpt_path=last_checkpoint)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 545, in fit
call._call_and_handle_interrupt(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 581, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 990, in _run
results = self._run_stage()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1034, in _run_stage
self._run_sanity_check()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1063, in _run_sanity_check
val_loop.run()
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 181, in _decorator
return loop_run(self, *args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
output = call._call_strategy_hook(trainer, hook_name, *step_args)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
output = fn(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/strategies/strategy.py", line 403, in validation_step
return self.lightning_module.validation_step(*args, **kwargs)
File "/opt/conda/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 296, in validation_step
) = st.preprocessing_batch_real(
File "/app/data/src_stardust/prediction/training/train_gnn_realdata.py", line 344, in <resume in validation_step>
self.log(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 412, in log
apply_to_collection(value, dict, self.__check_not_nested, name)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 413, in <resume in log>
apply_to_collection(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/core/module.py", line 500, in <resume in log>
results.log(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 390, in log
meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, _group=sync_dist_group, rank_zero_only=rank_zero_only)
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 398, in <resume in log>
raise MisconfigurationException(
File "/opt/conda/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 398, in <resume in log>
raise MisconfigurationException(
lightning.fabric.utilities.exceptions.MisconfigurationException: You called `self.log(val_aff_db, ...)` twice in `validation_step` with different arguments. This is not allowed
In my validation_step method, I only have one self.log statement:
self.log(
name="validation_metric",
value=myvalue,
on_step=True,
on_epoch=True,
logger=True,
batch_size=1,
prog_bar=False,
reduce_fx="mean",
enable_graph=False,
add_dataloader_idx=False,
metric_attribute=None,
rank_zero_only=True,
sync_dist=False,
)
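For context, the error message reads like a consistency check on the arguments stored for each logged key. Below is a simplified, self-contained model of that kind of check, not Lightning's actual code: `ToyResultCollection`, `LogMeta`, and the field names are hypothetical. It only illustrates how logging the same key twice with differing arguments would produce such an error; it does not explain why torch.compile would trigger a second call.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LogMeta:
    """Hypothetical stand-in for the per-key metadata of a log call."""
    on_step: bool = True
    on_epoch: bool = True
    reduce_fx: str = "mean"
    sync_dist: bool = False


class ToyResultCollection(dict):
    """Toy model: remembers the metadata first seen for each key."""

    def log(self, fx: str, name: str, meta: LogMeta) -> None:
        key = f"{fx}.{name}"
        if key not in self:
            self[key] = meta  # first call registers the metadata
        elif self[key] != meta:
            # same key, different arguments: rejected
            raise RuntimeError(
                f"You called `self.log({name}, ...)` twice in `{fx}` "
                "with different arguments. This is not allowed"
            )


results = ToyResultCollection()
results.log("validation_step", "validation_metric", LogMeta())
# Repeating the call with identical arguments is accepted.
results.log("validation_step", "validation_metric", LogMeta())

try:
    # Same key, but one argument differs from the first call.
    results.log("validation_step", "validation_metric", LogMeta(sync_dist=True))
    raised = False
except RuntimeError:
    raised = True
```

In this toy model a repeated call is harmless as long as the arguments compare equal, so the question for the real case is what makes the stored and the new arguments differ under compilation.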
This issue only happens when torch.compile() is applied to the whole LightningModule.
I am on the latest Lightning version, 2.1 (with torch 2.1).
What version are you seeing the problem on?
master
How to reproduce the bug
I will try to reproduce the bug with the simple BoringModel from the tests and post the result in a comment later.
Environment
Current environment
- Lightning Component (e.g. Trainer, LightningModule):
- PyTorch Lightning Version (e.g., 2.1):
- PyTorch Version (e.g., 2.1):
- I installed lightning with pip
More info
This bug only happens when I compile my module:
myligthningmodule = torch.compile(myligthningmodule)
cc @carmocca