Description
Description & Motivation
Intro
This is a description of a method I used offline in my own project and thought might be useful to other pytorch-lightning users.
This might be too specific a use case, but I thought it was worth suggesting, and I can submit a PR as well.
In many ML applications, the datasets used in the training procedure (train, val, and test), or their metadata features, should follow the same distribution. To visualize how closely the datasets' distributions agree, one may use a Q-Q plot (quantile-quantile plot), which plots the quantiles of one distribution against the quantiles of the other on the XY plane (QQ plot wiki).
Logging these plots, which describe the datasets, may help users at the beginning of the training routine determine how well their datasets 'match'.
I believe that doing so using a lightning.pytorch.callbacks.Callback
may be a clean way to log the plots.
This feature should:
- Let users validate that their dataset splits are statistically aligned
Pitch
A small example of how it might look:
import matplotlib.pyplot as plt
import numpy as np
import lightning as L
from typing import Optional, Callable, List, Any
from torch.utils.data import Dataset
class QQPlotLoggerCallback(L.Callback):
    """Log Q-Q plots comparing one scalar feature across dataset splits.

    At the start of training the callback pulls up to ``num_samples`` scalar
    values from each of the datamodule's train/val/test datasets, computes
    matching quantiles for every pair of splits, and logs one scatter plot per
    pair to each supported logger (TensorBoard or Weights & Biases).

    Args:
        target_key: Key (or index) used to read the value from a sample via
            ``sample[target_key]``. Ignored when ``extract_fn`` is given.
        extract_fn: Callable mapping a sample to a value convertible to
            ``float``. Takes precedence over ``target_key``.
        num_samples: Maximum number of samples read from each dataset. The
            *first* ``num_samples`` items are used, not a random subset.
        title: Prefix for the figure title; omitted when ``None`` or empty.
        quantiles: Number of quantile levels, evenly spaced in [0.01, 0.99].
    """

    # Splits inspected on the datamodule, and the pairs plotted against each other.
    _SPLITS = ("train", "val", "test")
    _PAIRS = (("train", "val"), ("train", "test"), ("val", "test"))

    def __init__(
        self,
        target_key: Optional[str] = None,
        extract_fn: Optional[Callable[[Any], float]] = None,
        num_samples: int = 1000,
        title: Optional[str] = "QQ Plot Data Comparison",
        quantiles: int = 100,
    ) -> None:
        super().__init__()
        self.target_key = target_key
        self.extract_fn = extract_fn
        self.num_samples = num_samples
        self.title = title
        self.quantiles = quantiles

    @staticmethod
    def _warn(message: str) -> None:
        """Emit ``message`` as a warning on the rank-zero process only."""
        L.pytorch.utilities.rank_zero_warn(message)

    def _extract_data(self, dataset: Dataset) -> List[float]:
        """Collect finite scalar values from the first samples of ``dataset``.

        Args:
            dataset: Map-style dataset supporting ``len()`` and integer indexing.

        Returns:
            The extracted values as floats. Samples that fail to load or
            convert, or that yield NaN/inf, are left out.

        Raises:
            ValueError: If neither ``extract_fn`` nor ``target_key`` was configured.
        """
        # Fail once, up front, on a configuration error instead of reporting
        # the same problem for every sample.
        if self.extract_fn is None and self.target_key is None:
            raise ValueError("Either target_key or extract_fn must be provided.")

        values: List[float] = []
        skipped = 0
        first_problem: Optional[str] = None

        for index in range(min(len(dataset), self.num_samples)):
            try:
                sample = dataset[index]
                if self.extract_fn is not None:
                    raw = self.extract_fn(sample)
                else:
                    # `is not None` (not truthiness) so a key such as 0 works.
                    raw = sample[self.target_key]
                value = float(raw)
            except Exception as err:
                # Best-effort: user datasets and extract functions can raise
                # anything; one bad sample must not abort the diagnostics.
                skipped += 1
                if first_problem is None:
                    first_problem = f"sample {index}: {err!r}"
                continue

            if np.isfinite(value):
                values.append(value)
            else:
                # NaN/inf would turn every quantile into NaN.
                skipped += 1
                if first_problem is None:
                    first_problem = f"sample {index}: non-finite value {value}"

        if skipped:
            # One aggregated message rather than one line per skipped sample.
            self._warn(
                f"QQPlotLoggerCallback skipped {skipped} sample(s); first problem: {first_problem}"
            )
        return values

    def _log_qq_plot(
        self,
        data_a: List[float],
        data_b: List[float],
        label_a: str,
        label_b: str,
        logger,
        tag: str,
    ) -> None:
        """Build a Q-Q scatter plot of two samples and send it to ``logger``.

        Args:
            data_a: Values plotted on the x axis.
            data_b: Values plotted on the y axis.
            label_a: Display name for ``data_a``.
            label_b: Display name for ``data_b``.
            logger: A ``TensorBoardLogger`` or ``WandbLogger``; any other
                logger only produces a warning.
            tag: Name under which the figure is logged.
        """
        x = np.asarray(data_a, dtype=float)
        y = np.asarray(data_b, dtype=float)
        if x.size == 0 or y.size == 0:
            # np.quantile raises on empty input.
            self._warn(f"Skipping {tag}: no data for {label_a} or {label_b}.")
            return

        # Same quantile levels for both samples, so differently sized
        # datasets can be compared point by point.
        q_levels = np.linspace(0.01, 0.99, self.quantiles)
        q_x = np.quantile(x, q_levels)
        q_y = np.quantile(y, q_levels)

        fig, ax = plt.subplots()
        try:
            ax.scatter(q_x, q_y, alpha=0.6, label="Empirical Q-Q")
            # Identity line: where the points would fall if both
            # distributions were identical.
            low = min(q_x.min(), q_y.min())
            high = max(q_x.max(), q_y.max())
            ax.plot([low, high], [low, high], 'r--', label="Ideal match")

            prefix = f"{self.title}: " if self.title else ""
            ax.set_title(f"{prefix}{label_a} vs {label_b}")
            ax.set_xlabel(f"{label_a} Quantiles")
            ax.set_ylabel(f"{label_b} Quantiles")
            ax.legend()

            if isinstance(logger, L.pytorch.loggers.TensorBoardLogger):
                logger.experiment.add_figure(tag, fig, global_step=0)
            elif isinstance(logger, L.pytorch.loggers.WandbLogger):
                logger.experiment.log({tag: fig})
            else:
                self._warn(f"Logger {type(logger)} not supported for figure logging.")
        finally:
            # Always release the figure, even if the logger call fails.
            plt.close(fig)

    def _get_dataset(self, datamodule, split: str):
        """Return the dataset behind ``datamodule.<split>_dataloader()``, or ``None``."""
        factory = getattr(datamodule, f"{split}_dataloader", None)
        if factory is None:
            return None
        try:
            loader = factory()
        except Exception as err:
            # The hook may be unimplemented, or its split may not be set up
            # yet; a missing split must not stop training.
            self._warn(f"Could not build the {split} dataloader: {err!r}")
            return None
        dataset = getattr(loader, "dataset", None)
        if dataset is None:
            # E.g. the hook returned a list/dict of loaders.
            self._warn(f"The {split} dataloader does not expose a single `.dataset`; skipping it.")
        return dataset

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
        """Extract the feature from every available split and log the pairwise Q-Q plots."""
        if not trainer.is_global_zero:
            # Avoid repeating the extraction on every rank.
            return

        dm = trainer.datamodule
        if dm is None:
            self._warn("No datamodule found. QQPlotLoggerCallback will not run.")
            return

        supported = (L.pytorch.loggers.TensorBoardLogger, L.pytorch.loggers.WandbLogger)
        loggers = [lg for lg in trainer.loggers if isinstance(lg, supported)]
        if not loggers:
            self._warn("No TensorBoard or Wandb logger found. QQPlotLoggerCallback will not run.")
            return

        data = {}
        for name in self._SPLITS:
            dataset = self._get_dataset(dm, name)
            if dataset is None:
                continue
            try:
                values = self._extract_data(dataset)
            except Exception as err:
                # Includes the configuration ValueError and datasets
                # without ``len()``; diagnostics stay best-effort.
                self._warn(f"Failed to extract data from {name} dataset: {err}")
                continue
            if values:
                data[name] = values

        for a, b in self._PAIRS:
            if a in data and b in data:
                for logger in loggers:
                    self._log_qq_plot(data[a], data[b], a, b, logger, tag=f"qqplot/{a}_vs_{b}")
Alternatives
As an alternative to running the callback in on_train_start,
one may run it on each epoch in cases where the datasets change dynamically during the training routine.
Additional context
No response