Skip to content

Commit 57828ac

Browse files
clessigChristian Lessig
andauthored
Fixed path handling to have multiple paths for different stream types. (#50)
Co-authored-by: Christian Lessig <[email protected]>
1 parent 664b663 commit 57828ac

File tree

4 files changed

+19
-17
lines changed

4 files changed

+19
-17
lines changed

src/weathergen/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def train(run_id=None) -> None:
7575
# directory where input streams are specified
7676
# cf.streams_directory = './streams_large/'
7777
cf.streams_directory = "./config/streams/streams_anemoi/"
78+
# cf.streams_directory = "./config/streams/streams_mixed/"
7879
# cf.streams_directory = "./streams_mixed/"
7980

8081
# embed_orientation : 'channels' or 'columns'
@@ -181,10 +182,11 @@ def train(run_id=None) -> None:
181182
cf.norm_type = "LayerNorm" #'LayerNorm' #'RMSNorm'
182183
cf.nn_module = "te"
183184

184-
cf.data_path = private_cf["data_path"]
185-
# "/home/mlx/ai-ml/datasets/stable/"
186-
# cf.data_path = '/lus/h2resw01/fws4/lb/project/ai-ml/observations/v1'
187-
# cf.data_path = '/leonardo_scratch/large/userexternal/clessig0/obs/v1'
185+
# merge private config
186+
for k, v in private_cf.items():
187+
setattr(cf, k, v)
188+
cf.data_path = private_cf["data_path_anemoi"] # for backward compatibility
189+
188190
cf.start_date = 201301010000
189191
cf.end_date = 202012310000
190192
cf.start_date_val = 202101010000

src/weathergen/datasets/multi_stream_data_sampler.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class MultiStreamDataSampler(torch.utils.data.IterableDataset):
3131
###################################################
3232
def __init__(
3333
self,
34-
data_path,
34+
cf,
3535
rank,
3636
num_ranks,
3737
streams,
@@ -91,9 +91,8 @@ def __init__(
9191
for fname in stream_info["filenames"]:
9292
ds = None
9393
if stream_info["type"] == "obs":
94-
c_data_path = "/gpfs/scratch/ehpc01/dop/v1/"
9594
ds = ObsDataset(
96-
c_data_path + "/" + fname,
95+
cf.data_path_obs + "/" + fname,
9796
start_date,
9897
end_date_padded,
9998
len_hrs,
@@ -127,12 +126,13 @@ def __init__(
127126
stats_offset = 0
128127

129128
elif stream_info["type"] == "anemoi":
130-
c_data_path = data_path
131-
if "CERRA" in stream_info["name"]:
132-
c_data_path = "/gpfs/scratch/ehpc03/weathergen/"
133-
134129
ds = AnemoiDataset(
135-
c_data_path + "/" + fname, start_date, end_date, len_hrs, step_hrs, False
130+
cf.data_path_anemoi + "/" + fname,
131+
start_date,
132+
end_date,
133+
len_hrs,
134+
step_hrs,
135+
False,
136136
)
137137
do = 0
138138
geoinfo_idx = [0, 1]

src/weathergen/train/trainer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def evaluate(self, cf, run_id_trained, epoch, run_id_new=False):
9999
self.init(cf, run_id_trained, epoch, run_id_new, run_mode="evaluate")
100100

101101
self.dataset_val = MultiStreamDataSampler(
102-
cf.data_path,
102+
cf,
103103
cf.rank,
104104
cf.num_ranks,
105105
cf.streams,
@@ -164,7 +164,7 @@ def evaluate_jac(self, cf, run_id, epoch, mode="row", date=None, obs_id=0, sampl
164164
self.init(cf, run_id, epoch, run_id_new=True, run_mode="offline")
165165

166166
self.dataset = MultiStreamDataSampler(
167-
cf.streams,
167+
cf,
168168
cf.start_date_val,
169169
cf.end_date_val,
170170
cf.delta_time,
@@ -298,7 +298,7 @@ def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=Fa
298298
self.init(cf, run_id_contd, epoch_contd, run_id_new)
299299

300300
self.dataset = MultiStreamDataSampler(
301-
cf.data_path,
301+
cf,
302302
cf.rank,
303303
cf.num_ranks,
304304
cf.streams,
@@ -324,7 +324,7 @@ def run(self, cf, private_cf, run_id_contd=None, epoch_contd=None, run_id_new=Fa
324324
sampling_rate_target=cf.sampling_rate_target,
325325
)
326326
self.dataset_val = MultiStreamDataSampler(
327-
cf.data_path,
327+
cf,
328328
cf.rank,
329329
cf.num_ranks,
330330
cf.streams,

src/weathergen/utils/validation_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def write_validation(
127127
ds_source.create_dataset("targets_lens", data=targets_lens_k)
128128
else:
129129
rn = rn + f"/{fs}"
130-
if source_lens_k.sum() > 0 :
130+
if source_lens_k.sum() > 0:
131131
ds[f"{rn}/sources"].append(source_k)
132132
ds[f"{rn}/sources_lens"].append(source_lens_k)
133133
ds[f"{rn}/preds"].append(preds_k)

0 commit comments

Comments
 (0)