Multiple Datasets I - Base support in anemoi-core #230

Open
JPXKQX opened this issue Apr 7, 2025 · 3 comments
Labels
enhancement New feature or request

Comments

@JPXKQX
Member

JPXKQX commented Apr 7, 2025

Implement base support for handling multiple datasets in the anemoi-core data pipeline, specifically for the GraphForecaster model.

Goal

Train a forecasting model (era → era) using the DataHandler design proposed in #69, without rollout or diagnostics.

Scope

  • Datamodule: Implement the new DataHandler class. The goal of this task is that the PyTorch Lightning datamodule returns batches of type:
{"era": torch.Tensor()}

instead of

torch.Tensor
  • Model:
    Refactor the model interface to accept and return dictionaries in the same format. For this stage, just handle a single input/output dictionary of {"era": torch.Tensor}.

  • Training loop:
    Modify the loss computation to operate over {pred_dict} -> {target_dict} by comparing predicted and target tensors per dataset key (see the sketch after this list).
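
A minimal sketch of the intended dict-based batch and per-key loss, written in plain PyTorch; dict_mse and the shapes are hypothetical, not anemoi-core API:

import torch

def dict_mse(pred: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:
    # Compare predicted and target tensors per dataset key and sum the terms.
    assert pred.keys() == target.keys()
    return sum(torch.nn.functional.mse_loss(pred[k], target[k]) for k in pred)

# The dataloader now yields {"era": tensor} instead of a bare tensor.
batch = {"era": torch.randn(2, 40, 4)}   # shapes purely illustrative
pred = {"era": torch.randn(2, 40, 4)}    # model output in the same dict format
loss = dict_mse(pred, batch)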

Notes

  • This is the first step toward full multi-resolution support.
  • Keep the scope limited to make integration and review straightforward.
  • Should be compatible with the metadata changes defined in New Metadata Schema #229.
JPXKQX added the enhancement label Apr 7, 2025
JPXKQX added this to the Multiple datasets milestone Apr 7, 2025
@havardhhaugen
Contributor

havardhhaugen commented Apr 25, 2025

I have done an initial implementation of the training loop and model, to the point where I can run training_step/validation_step (without rollout/diagnostics) with mocked input/output data in a dictionary. Biggest changes / points for discussion:
Model
9624ba2

  • ModelIndex - for now I have made a dummy ModelIndex object, which the model sets up based on input/output. This is sent back to the forecaster and used instead of dataloader.data_indices. Is this a good solution?

Training
023fb21

  • GraphForecasterMultiDataset as a temporary replacement for GraphForecaster while we work on this
  • DictLoss wrapper loss function with dictionary input/output. Moved most of the logic for setting up the loss function from __init__ to a separate function (see the sketch after this list).
  • Training default config (refactored_default.yaml) with one dict per output (for everything output-specific).
  • Training/validation_step now takes batch = (batch_input, batch_target) as input, which moves some of the slicing work from the forecaster to the datahandler, but will probably make advance_input more complicated.
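
A minimal sketch of what the DictLoss wrapper could look like, assuming it wraps an existing per-tensor loss; the actual implementation in 023fb21 may differ:

import torch
from torch import nn

class DictLoss(nn.Module):
    # Hypothetical wrapper: applies a per-tensor loss to matching keys of the
    # prediction/target dictionaries and sums the contributions.
    def __init__(self, base_loss: nn.Module) -> None:
        super().__init__()
        self.base_loss = base_loss  # e.g. the loss built by the separate setup function

    def forward(self, pred: dict[str, torch.Tensor], target: dict[str, torch.Tensor]) -> torch.Tensor:
        return sum(self.base_loss(pred[k], target[k]) for k in pred)

# Usage with the (batch_input, batch_target) tuple mentioned above:
# loss = DictLoss(nn.MSELoss())(pred_dict, batch_target)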

@VeraChristina
Collaborator

VeraChristina commented Apr 29, 2025

Nice work getting it to run, @havardhhaugen!

RE ModelIndex/data_indices, I think it would be good to still have this information as an attribute of AnemoiTrainer, but to collect it from the relevant config entries instead of taking the info from the dataloader -- similar to what you did in the current version. (In view of the multiple inputs/outputs, we'll probably want it to be a dictionary that contains similar information to the previous data_indices but for all input/output data sources.)

Potential upsides of having it as an attribute of AnemoiTrainer:

  • We don't need to do this separately for all tasks (and we'll want this information for all tasks)
  • We can implement a fallback to the old approach at this level (if we want that)
  • The info is included in the metadata stored in the checkpoint (which we need for inference)
  • We could add a basic check at this level on whether the data provided, input specs, and encoder specs line up (once we have added multiple encoder specs to the configs)

I'll start moving it there. Do let me know if you have any feedback on the above!
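
For illustration, a rough sketch of how a per-source data_indices attribute on AnemoiTrainer might be shaped; the helper, dataclass and config fields below are assumptions, not anemoi-core API:

from dataclasses import dataclass

@dataclass
class SourceIndices:
    # Mirrors the role of the old data_indices, but for a single data source.
    prognostic: list[int]
    forcing: list[int]
    diagnostic: list[int]

def build_data_indices(config: dict) -> dict[str, SourceIndices]:
    # Collect index info per input/output data source from config entries
    # (instead of from the dataloader), e.g. {"era": SourceIndices(...)}.
    return {name: SourceIndices(**spec) for name, spec in config.get("data_sources", {}).items()}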

@JPXKQX
Member Author

JPXKQX commented May 7, 2025

Update on the dataloading side
The current design of the data loading is shown below:

  • data_handlers: a DataHandlers(dict) object where each DataHandler has its own time resolution, spatial resolution, and variables.
  • sample_provider: dict[str, SampleProvider]. In most cases it will have two keys, input and output, but it may have a third one, graph (for multi-domain training). The keys of this dictionary will match the keys of the batch returned by the dataloader.
  • datasets: dict[str, NativeGridMultDatasets(torch.IterableDataset)].
  • Data module: AnemoiDataModule(pl.LightningDataModule)
  • sampler: AnemoiSampler. This class synchronises the reference date for each sample. It sets the valid_time_indices. Should this class handle the shuffling?

Please check conftest.py to see the structure of the new config file.
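
For illustration, a rough sketch of how the pieces above fit together; the real DataHandlers, SampleProvider, AnemoiDataModule and AnemoiSampler classes live in anemoi-core, and the keys, nesting and shapes below are assumptions:

import torch

# One data handler entry per dataset, each with its own resolutions and variables.
data_handlers_spec = {"era": {"time_res": "6h", "spatial_res": "o96", "variables": ["2t", "msl"]}}

# The sample_provider keys ("input", "output", optionally "graph") match the batch keys,
# so a batch produced by AnemoiDataModule could look roughly like (nesting assumed):
batch = {
    "input": {"era": torch.randn(2, 40, 2)},    # shapes purely illustrative
    "output": {"era": torch.randn(2, 40, 2)},
}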
