Datasets correlation callback #20962

Open
@maoragai

Description & Motivation

Intro
This describes a method I used offline in my own project and thought might serve other PyTorch Lightning users.
The use case may be too specific, but I thought it worth suggesting, and I can submit a PR as well.

In many ML applications, the datasets used in the training procedure (train, val, and test), or their metadata features, should follow the same distribution. To visualize how closely the dataset distributions agree, one can use a QQ plot (quantile-quantile plot), which plots the quantiles of one distribution against the quantiles of the other on the XY plane (QQ plot wiki).
Logging these plots at the beginning of the training routine can help users judge how well their datasets 'match'.
I believe a lightning.pytorch.callbacks.Callback would be a clean way to log the plots.
This feature should:

  • Let users validate that their dataset splits are statistically aligned
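
To make the quantile comparison concrete, here is a minimal sketch of the numbers behind a QQ plot. It uses the standard library's statistics.quantiles in place of np.quantile so it runs without extra dependencies; the helper names qq_points and max_qq_gap and the synthetic Gaussian samples are illustrative assumptions, not part of the proposed callback.

```python
import random
from statistics import quantiles


def qq_points(sample_a, sample_b, n=100):
    """Pair the quantiles of two samples at the n - 1 interior cut points."""
    q_a = quantiles(sample_a, n=n, method="inclusive")
    q_b = quantiles(sample_b, n=n, method="inclusive")
    return list(zip(q_a, q_b))


def max_qq_gap(sample_a, sample_b, n=100):
    """Largest absolute distance of any QQ point from the y = x line."""
    return max(abs(a - b) for a, b in qq_points(sample_a, sample_b, n))


# Synthetic stand-ins for a scalar feature taken from each split (assumption).
rng = random.Random(0)
train = [rng.gauss(0.0, 1.0) for _ in range(5000)]
val = [rng.gauss(0.0, 1.0) for _ in range(5000)]
shifted = [rng.gauss(1.0, 1.0) for _ in range(5000)]

gap_matched = max_qq_gap(train, val)      # same distribution: points hug y = x
gap_shifted = max_qq_gap(train, shifted)  # mean shift: points sit off the line
```

When two splits follow the same distribution, the paired quantiles fall close to the diagonal; a shift in the mean moves the whole cloud away from it, which is what the plot makes visible at a glance.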

Pitch

A small example of how it might look:

import matplotlib.pyplot as plt
import numpy as np
import lightning as L
from typing import Optional, Callable, List, Any
from torch.utils.data import Dataset


class QQPlotLoggerCallback(L.Callback):
    def __init__(
        self,
        target_key: Optional[str] = None,
        extract_fn: Optional[Callable[[Any], float]] = None,
        num_samples: int = 1000,
        title: Optional[str] = "QQ Plot Data Comparison",
        quantiles: int = 100
    ):
        super().__init__()
        # Fail fast on misconfiguration instead of reporting it once per sample
        if target_key is None and extract_fn is None:
            raise ValueError("Either target_key or extract_fn must be provided.")
        self.target_key = target_key
        self.extract_fn = extract_fn
        self.num_samples = num_samples
        self.title = title
        self.quantiles = quantiles

    def _extract_data(self, dataset: Dataset) -> List[float]:
        # Pull up to num_samples scalar values from a dataset
        values = []
        for i in range(min(len(dataset), self.num_samples)):
            sample = dataset[i]
            try:
                if self.extract_fn is not None:  # an external extraction function takes precedence
                    val = self.extract_fn(sample)
                else:  # otherwise read the configured data feature
                    val = sample[self.target_key]
                values.append(float(val))
            except Exception as e:
                # Samples that cannot be reduced to a float are skipped, not fatal
                print(f"Skipping sample {i}: {e}")
        return values

    
    def _log_qq_plot(
        self,
        data_a: List[float],
        data_b: List[float],
        label_a: str,
        label_b: str,
        logger,
        tag: str
    ):
        # Create a QQ plot for the two samples and log it
        x = np.array(data_a)
        y = np.array(data_b)

        # Use quantiles to align distributions
        q_levels = np.linspace(0.01, 0.99, self.quantiles)
        q_x = np.quantile(x, q_levels)
        q_y = np.quantile(y, q_levels)

        fig, ax = plt.subplots()
        ax.scatter(q_x, q_y, alpha=0.6, label="Empirical Q-Q")
        min_val = min(q_x.min(), q_y.min())
        max_val = max(q_x.max(), q_y.max())
        ax.plot([min_val, max_val], [min_val, max_val], 'r--', label="Ideal match")
        ax.set_title(f"{self.title}: {label_a} vs {label_b}")
        ax.set_xlabel(f"{label_a} Quantiles")
        ax.set_ylabel(f"{label_b} Quantiles")
        ax.legend()

        if isinstance(logger, L.pytorch.loggers.TensorBoardLogger):
            logger.experiment.add_figure(tag, fig, global_step=0)
        elif isinstance(logger, L.pytorch.loggers.WandbLogger):
            logger.experiment.log({tag: fig})
        else:
            print(f"Logger {type(logger)} not supported for figure logging.")

        plt.close(fig)

    def on_train_start(self, trainer: L.Trainer, pl_module: L.LightningModule):
        dm = getattr(trainer, 'datamodule', None)
        if dm is None:
            print("No datamodule found. QQPlotLoggerCallback will not run.")
            return

        # hasattr() cannot detect a missing split because LightningDataModule
        # defines every *_dataloader hook, so call each one and skip failures
        # (e.g. a test set that is not set up during fit).
        datasets = {}
        for name in ('train', 'val', 'test'):
            try:
                datasets[name] = getattr(dm, f"{name}_dataloader")().dataset
            except Exception as e:
                print(f"Skipping {name} dataset: {e}")

        data = {}
        for name, ds in datasets.items():
            try:
                values = self._extract_data(ds)
                if values:
                    data[name] = values
            except Exception as e:
                print(f"Failed to extract data from {name} dataset: {e}")

        logger = trainer.logger
        pairs = [('train', 'val'), ('train', 'test'), ('val', 'test')]

        for a, b in pairs:
            if a in data and b in data:
                self._log_qq_plot(data[a], data[b], a, b, logger, tag=f"qqplot/{a}_vs_{b}")
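
The extract_fn argument is what adapts the callback to datasets whose samples are not dictionaries. The following is a small sketch of two such functions; label_from_pair, mean_of_feature, and the 'signal' key are hypothetical names chosen for illustration, and the commented wiring assumes the QQPlotLoggerCallback class above is in scope.

```python
from typing import Any, Mapping, Sequence, Tuple


def label_from_pair(sample: Tuple[Any, Any]) -> float:
    """Hypothetical extract_fn for datasets that yield (features, label) pairs."""
    _, label = sample
    return float(label)


def mean_of_feature(sample: Mapping[str, Sequence[float]]) -> float:
    """Hypothetical extract_fn: reduce a vector-valued 'signal' entry to its mean."""
    signal = sample["signal"]
    return sum(signal) / len(signal)


# Intended wiring (not executed here, since it needs Lightning installed):
# trainer = L.Trainer(callbacks=[QQPlotLoggerCallback(extract_fn=label_from_pair)])
```

Any callable that maps one sample to a single number should work, since the callback converts the result with float() before computing quantiles.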

Alternatives

As an alternative to running in on_train_start, the callback could log the plots on each epoch, for cases where the datasets change dynamically during the training routine.
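
One possible shape for the per-epoch variant is a mixin that forwards Lightning's on_train_epoch_start hook to the existing on_train_start logic. This is only a sketch: PerEpochQQMixin is a hypothetical name, and a small stand-in class replaces QQPlotLoggerCallback so the snippet runs without Lightning installed.

```python
class PerEpochQQMixin:
    """Hypothetical mixin: re-run the start-of-training logging every epoch."""

    def on_train_epoch_start(self, trainer, pl_module):
        # Delegate to the existing hook so the plotting logic stays in one place.
        self.on_train_start(trainer, pl_module)


class _RecordingCallback:
    """Stand-in for QQPlotLoggerCallback that only counts invocations."""

    def __init__(self):
        self.calls = 0

    def on_train_start(self, trainer, pl_module):
        self.calls += 1


class _Demo(PerEpochQQMixin, _RecordingCallback):
    pass


demo = _Demo()
for _ in range(3):  # simulate three epochs
    demo.on_train_epoch_start(trainer=None, pl_module=None)
```

With the real callback, two details would likely need attention: on_train_start would still fire once before the first epoch, so the first plot would be logged twice, and _log_qq_plot passes global_step=0, so a step derived from the current epoch would be needed to keep plots from different epochs apart in TensorBoard.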

Additional context

No response

cc @lantiga @Borda

    Labels

    callback, feature (is an improvement or enhancement), needs triage (waiting to be triaged by maintainers)
