Skip to content

Commit 26e4f65

Browse files
tjhunterkacpnowakclessigiluiseSindhu-Vasireddy
authored andcommitted
[251] Merge new IO class (ecmwf#469)
* Implement mock IO (ecmwf#336) * Adapt score class score class (ecmwf#339) * Implement mock IO * Adapt score class * Removing unused file (ecmwf#349) * remove database folder (ecmwf#355) * Small change - CI - pinning the version of formatting (ecmwf#361) * changes * changes * Update INSTALL.md * Update INSTALL.md * Fixed Exxx lint issues (ecmwf#284) * Rebased to the latest changes and linted new changes * addressed review comments * addressed review comments * Linted the latest changes. * corrected the formating * corrected the formating * configured ruff to use LF line endings in pyproject.toml * [357] Sub-package for evaluation (ecmwf#359) * working * changes * removing deps from non-core project * changes * fixes * comments * Iluise quick fix stac (ecmwf#374) * remove database folder * fix database * Simplifying workflow for plot_training (ecmwf#368) * Simplifying workflow for plot_training * Ruffed * Working on implementing exclude_source * Remove unused code * Fixed ruff issue * Fixing bug in lat handling (377) (ecmwf#378) * Fixing bug in lat handling * Added comment --------- Co-authored-by: Seb Hickman <[email protected]> * recover num_ranks from previous run to calculate epoch_base (ecmwf#317) * recover num_ranks from previous run to calculate epoch_base * set email settings for commits * addressing Tim's comment * make ruff happy * improve style * changes (ecmwf#385) Linter rule so np.ndarray is not used as type * changed the script name from evaluate to inference as it simply gener… (ecmwf#376) * changed the script name from evaluate to inference as it simply generate infer samples * changed evaluate to inference in the main scripts and corresponding calls in the config * update the main function for the inference script * changed evaluate to inference also in docstring, unit test scripts, and integration test scripts --------- Co-authored-by: Patnala,Ankit <[email protected]> * Introduce tuples instead for strings to avoid TypeError (ecmwf#392) * Exclude channels from src / target (ecmwf#363) * Exclude channels from src / target * Simplified code and added comment that pattern matching is used * Adding new stream config * Fixing bug that led to error when accessing self.ds when dataset is empty * Wokign on exlcude_source * work in progress * Fixing incorrect formating for logger (ecmwf#388) * Ruffed * Refactored and cleaned up channel selection. Also added check that channels are not empty * Cleaned channel parsing and selection * Adjustments * Removing asserts incompatible with empty dataset --------- Co-authored-by: Christian Lessig <[email protected]> * add embed_dropout_rate to config v1 (ecmwf#358) * [402] adds checks to the pull request (ecmwf#403) * chanegs * mistake * mistake * mistake * changes * doc * Introduce masking class and incorporate in TokenizerMasking (ecmwf#383) * creating masking class and adapting tokenizer_masking to use this class * minor changes to masking.py and tokenizer_masking * removed old tokenizer_masking * include masking_strategy in default_config * change ValueError to assert * linting formatting changes files * further linting of docstrings * create mask_source and mask_target in Masker, and update tokenizer_masking to use these, then style improvements * linted masking, tokenizer_masking * modify masker, rng and perm_sel now part of class, remove extra masking_rate, update comments, remove archived class * remove check if all masked, not masked * remove self.masking_rate from MultiStreamDS class, and masking args from batchify_source * update tokenizer utils with description of idx_ord_lens in comment * remove masking args from batchify_, perm_sel removed now internal to Masker class, remove handling special cases of masking (all masked) * adding masking_strategy: to config * remove unused mentions of masking_combination * removed comment about streams * changed assert to check self perm_sel is not None * ruff masking, tokenizer_masking * Ruffed * Added warning to capture corner case, likely due to incorrect user settings. * Fixed incorrect call twice * Fixed missing conditional for logger statement * Required changes for better handling of rngs * Improved handling of rngs * Improved handling of rng --------- Co-authored-by: Christian Lessig <[email protected]> * Implement per-channel logging (ecmwf#283) * Fix bug with seed being divided by 0 for worker ID=0 * Fix bug causing crash when secrets aren't in private config * Implement logging losses per channel * Fix issue with empty targets * Rework loss logging * ruff * Remove computing max_channels * Change variables names * ruffed * Remove redundant enumerations * Use stages for logging * Add type hints * Apply the review * ruff * fix * Fix type hints * ruff --------- Co-authored-by: Tim Hunter <[email protected]> * [346] Passing options through the slurm script (ecmwf#400) * changes * fixes * refactor `validation_io.write_validation` to make it more readable * remove legacy code `validation_io.read_validation` * encapsulate artifact path logic in config module * remove redundant attribute `Trainer.path_run` * use config to look up base_path in `write_validation` * remove unused `write_validation` args: `base_path`, `rank` * ensure correct type for pathes * remove streams initialization from `Trainer` * remove path logic from `Trainer.save_model` * simplify conditional * rename mock io module * update uv to include dask * Implement io module to support reading/writing model output * implement new validation_io routine * use new write_validation routine * remove unused code * rename output routine to `write_output` * ruffed and added comments * fixed annotation * use simple __init__ method for `OutputItem` instead of dataclasses magic * address reviewers comments * rename method * add simple docstrings * ruffed * typehint fixes * refactor names * update comments and typehints, dont import pytorch * remove `__post_init__` methods, cache properties * fixes and integration test * final fixes :) * changes * changes * changes * changes * changes * more work * changes * changes * changes * ruffed * ruffed * improve logging and comments * Update to score-class according to internal discussions and feedback in PR. * Add license header. * Ruffed code. * Update to score-class according to internal discussions and feedback in PR. * Add license header. * Ruffed code. * Add doc-string to call-method and provide example usage for efficient graph-construction. * Some fixes to score-class. * Some fixes to handling aggregation dimension. * Add missing import of MockIO. * changes * changes * removing the scores * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes * changes --------- Co-authored-by: Kacper Nowak <[email protected]> Co-authored-by: Christian Lessig <[email protected]> Co-authored-by: iluise <[email protected]> Co-authored-by: Sindhu-Vasireddy <[email protected]> Co-authored-by: Seb Hickman <[email protected]> Co-authored-by: Julian Kuehnert <[email protected]> Co-authored-by: ankitpatnala <[email protected]> Co-authored-by: Patnala,Ankit <[email protected]> Co-authored-by: Savvas Melidonis <[email protected]> Co-authored-by: Christian Lessig <[email protected]> Co-authored-by: Till Hauer <[email protected]> Co-authored-by: Simon Grasse <[email protected]> Co-authored-by: Michael <[email protected]>
1 parent ae5c1ec commit 26e4f65

File tree

18 files changed

+612
-163
lines changed

18 files changed

+612
-163
lines changed

config/mixed.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# The default configuration file for multi streams training.
2+
streams_directory: "./config/streams/streams_mixed/"

config/streams/streams_mixed/era5.yml

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99

1010
ERA5 :
1111
type : anemoi
12-
# filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
13-
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2023-1h-v1-with-ERA51.zarr']
14-
frequency : "02:00:00"
12+
filenames : ['aifs-ea-an-oper-0001-mars-o96-1979-2022-6h-v6.zarr']
1513
loss_weight : 1.
1614
source_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp']
1715
target_exclude : ['w_', 'skt', 'sp', 'tcw', 'cp', 'tp']
@@ -23,16 +21,17 @@ ERA5 :
2321
embed :
2422
net : transformer
2523
num_tokens : 1
26-
num_heads : 8
27-
dim_embed : 512
24+
num_heads : 4
25+
dim_embed : 128
2826
num_blocks : 2
2927
embed_target_coords :
3028
net : linear
31-
dim_embed : 256
29+
dim_embed : 128
3230
target_readout :
3331
type : 'obs_value'
3432
num_layers : 2
3533
num_heads : 4
3634
pred_head :
3735
ens_size : 1
3836
num_layers : 1
37+

config/streams/streams_mixed/npp_atms.yml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
# granted to it by virtue of its status as an intergovernmental organisation
88
# nor does it submit to any jurisdiction.
99

10-
NPP, ATMS :
10+
NPPATMS :
1111
type : obs
1212
filenames : ['observations-ea-ofb-0001-2012-2023-npp-atms-radiances-v2.zarr']
1313
loss_weight : 1.0
@@ -27,5 +27,4 @@ NPP, ATMS :
2727
num_heads : 4
2828
pred_head :
2929
ens_size : 1
30-
num_layers : 1
31-
30+
num_layers : 1

config/streams/streams_mixed/synop.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# 1 : geostationay satellites
44
# 2 : conventional observations
55

6-
Surface Combined :
6+
SurfaceCombined :
77
type : obs
88
filenames : ['observations-ea-ofb-0001-1979-2023-combined-surface-v2.zarr']
99
loss_weight : 1.0

integration_tests/small1_test.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515

1616
import pytest
1717

18-
from weathergen import inference_from_args, train_with_args
18+
import weathergen.common.io as io
19+
import weathergen.utils.config as config
20+
from weathergen.run_train import inference_from_args, train_with_args
1921

2022
logger = logging.getLogger(__name__)
2123

@@ -55,6 +57,7 @@ def test_train(setup, test_run_id):
5557
f"{weathergen_home}/config/streams/streams_test/",
5658
)
5759

60+
logger.info("run inference")
5861
inference_from_args(
5962
["-start", "2022-10-10", "-end", "2022-10-11", "--samples", "10", "--epoch", "0"]
6063
+ [
@@ -66,12 +69,35 @@ def test_train(setup, test_run_id):
6669
f"{weathergen_home}/integration_tests/small1.yaml",
6770
]
6871
)
72+
logger.info("run evaluation")
73+
evaluate_results(test_run_id)
6974
assert_missing_metrics_file(test_run_id)
7075
assert_train_loss_below_threshold(test_run_id)
7176
assert_val_loss_below_threshold(test_run_id)
7277
logger.info("end test_train")
7378

7479

80+
def evaluate_results(run_id):
81+
cf = config.load_model_config(run_id, None, None)
82+
data_root = config.get_path_output(cf, 0)
83+
84+
with io.ZarrIO(data_root) as reader:
85+
samples = reader.samples
86+
fsteps = reader.forecast_steps
87+
streams = reader.streams
88+
89+
item = reader.get_data(samples[0], streams[0], fsteps[0])
90+
ds = item.prediction.as_xarray()
91+
logger.info(ds)
92+
item.target.as_xarray()
93+
logger.info(ds)
94+
if item.key.with_source:
95+
ds = item.source.as_xarray()
96+
logger.info(ds)
97+
98+
# TODO: test concat multiple samples
99+
100+
75101
def load_metrics(run_id):
76102
"""Helper function to load metrics"""
77103
file_path = f"{weathergen_home}/results/{run_id}/metrics.json"

packages/common/pyproject.toml

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,18 @@ dev = [
1616

1717

1818

19+
[tool.black]
20+
21+
# Wide rows
22+
line-length = 100
23+
24+
25+
# The linting configuration
26+
[tool.ruff]
27+
28+
# Wide rows
29+
line-length = 100
30+
1931
[tool.ruff.lint]
2032
# All disabled until the code is formatted.
2133
select = [
@@ -31,24 +43,26 @@ select = [
3143
"SIM",
3244
# isort
3345
"I",
46+
# Banned imports
47+
"TID"
3448
]
3549

3650
# These rules are sensible and should be enabled at a later stage.
3751
ignore = [
38-
"E501",
39-
"E721",
40-
"E722",
4152
# "B006",
4253
"B011",
4354
"UP008",
4455
"SIM117",
4556
"SIM118",
4657
"SIM102",
4758
"SIM401",
59+
"UP040", # TODO: enable later
4860
# To ignore, not relevant for us
49-
"E741",
61+
"SIM108" # in case additional norm layer supports are added in future
5062
]
5163

64+
65+
5266
[build-system]
5367
requires = ["hatchling"]
5468
build-backend = "hatchling.build"

0 commit comments

Comments
 (0)