Skip to content

Commit 76f85cf

Browse files
committed
Add training code
1 parent fd3616d commit 76f85cf

File tree

9 files changed

+319
-38
lines changed

9 files changed

+319
-38
lines changed

docs/zh/examples/gencast.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,16 @@
55
- 下载目录`dm_graphcast/gencast/stats`下的所有文件放入`./data/stats/`目录下。
66
- 下载目录`dm_graphcast/gencast/dataset`下的任意或所有文件(例如:source-era5_date-2019-03-29_res-1.0_levels-13_steps-12.nc)放入`./data/dataset/`目录下。
77

8+
=== "模型训练命令"
9+
10+
``` sh
11+
# 设置路径到 PaddleScience/jointContribution 文件夹
12+
cd PaddleScience/jointContribution
13+
export PYTHONPATH=$PWD:$PYTHONPAT
14+
# 运行训练脚本
15+
python run_gencast.py mode=train
16+
```
17+
818
=== "模型评估命令"
919

1020
``` sh
@@ -15,7 +25,7 @@
1525
cd gencast/
1626
wget -nc https://paddle-org.bj.bcebos.com/paddlescience/models/gencast/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams -P ./data/params/
1727
# 运行评估脚本
18-
python run_gencast.py
28+
python run_gencast.py mode=eval
1929
```
2030

2131
## 1. 背景简介

jointContribution/gencast/conf/gencast.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ mean_path: data/stats/gencast_stats_mean_by_level.nc
2828
min_path: data/stats/gencast_stats_min_by_level.nc
2929
param_path: data/params/gencast_params_GenCast-1p0deg-Mini-_2019.pdparams
3030

31+
train:
32+
learning_rate: 0.001
33+
weight_decay: 0.1
34+
num_epochs: 2000000
35+
batch_size: 1
36+
snapshot_freq: 10
37+
3138
sampler_config:
3239
max_noise_level: 80.0
3340
min_noise_level: 0.03
@@ -63,8 +70,9 @@ denoiser_architecture_config:
6370
block_q_dkv: 512
6471
block_kv_dkv: 1024
6572
block_kv_dkv_compute: 1024
66-
ffw_winit_final_mult: 0.0
67-
attn_winit_final_mult: 0.0
73+
ffw_winit_final_mult: 1.0
74+
attn_winit_final_mult: 1.0
75+
attn_winit_mult: 2.0
6876
ffw_hidden: 2048
6977
mesh_node_dim: 186
7078
mesh_node_emb_dim: 512

jointContribution/gencast/denoiser.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
"""Support for wrapping a general Predictor to act as a Denoiser."""
1515

1616
import copy
17+
import math
1718
import os
1819
import pickle
1920
from typing import Optional
@@ -66,12 +67,16 @@ def __init__(
6667
self._num_frequencies = num_frequencies
6768
self._apply_log_first = apply_log_first
6869

69-
# 创建 MLP
70+
# Creating MLP
7071
layers = []
7172
input_size = 2 * num_frequencies
7273
num_layers = len(output_sizes)
7374
for i, output_size in enumerate(output_sizes):
74-
linear_layer = nn.Linear(input_size, output_size)
75+
limit = math.sqrt(6 / input_size)
76+
weight_attr = paddle.framework.ParamAttr(
77+
initializer=paddle.nn.initializer.Uniform(low=-limit, high=limit)
78+
)
79+
linear_layer = nn.Linear(input_size, output_size, weight_attr=weight_attr)
7580
layers.append(linear_layer)
7681
if i < num_layers - 1:
7782
layers.append(activation)
@@ -168,4 +173,12 @@ def forward(
168173
grid_node_outputs, noisy_targets
169174
)
170175

171-
return raw_predictions
176+
resolution = self.cfg.denoiser_architecture_config.resolution
177+
grid_lat = np.arange(-90.0, 90.0 + resolution, resolution).astype(np.float32)
178+
grid_lon = np.arange(0.0, 360.0, resolution).astype(np.float32)
179+
grid_shape = [grid_lat.shape[0], grid_lon.shape[0]]
180+
grid_outputs_lat_lon_leading = grid_node_outputs.reshape(
181+
grid_shape + grid_node_outputs.shape[1:]
182+
)
183+
184+
return raw_predictions, grid_outputs_lat_lon_leading

jointContribution/gencast/dpm_solver_plus_plus_2s.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def init_noise(template):
173173
mid_over_current = mid_noise_level / noise_level
174174
# x = xr.open_dataset('/workspace/workspace/graphcast/x.nc')
175175

176-
x_denoised = denoiser(noise_level, x)
176+
x_denoised, _ = denoiser(noise_level, x)
177177
# This turns out to be a convex combination of current and denoised x,
178178
# which isn't entirely apparent from the paper formulae:
179179
x_mid = (
@@ -182,7 +182,7 @@ def init_noise(template):
182182
)
183183

184184
next_over_current = next_noise_level / noise_level
185-
x_mid_denoised = denoiser(mid_noise_level, x_mid)
185+
x_mid_denoised, _ = denoiser(mid_noise_level, x_mid)
186186
x_next = (
187187
next_over_current.numpy() * x
188188
+ (1 - next_over_current.numpy()) * x_mid_denoised

jointContribution/gencast/gencast.py

Lines changed: 87 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,13 @@
2424

2525
import denoiser
2626
import dpm_solver_plus_plus_2s
27+
import losses
28+
import numpy as np
29+
import paddle
2730
import paddle.nn as nn
31+
import samplers_utils
2832
import xarray as xr
33+
from graphcast import datasets
2934

3035

3136
class GenCast(nn.Layer):
@@ -54,6 +59,7 @@ def __init__(
5459
self._sampler_config = cfg.sampler_config
5560
self._sampler = None
5661
self._noise_config = cfg.noise_config
62+
self.cfg = cfg
5763

5864
def _c_in(self, noise_scale: xr.DataArray) -> xr.DataArray:
5965
"""Scaling applied to the noisy targets input to the underlying network."""
@@ -81,22 +87,95 @@ def _preconditioned_denoiser(
8187
) -> xr.Dataset:
8288
"""The preconditioned denoising function D from the paper (Eqn 7)."""
8389
# Convert xarray DataArray to Paddle tensor for operations
84-
raw_predictions = self._denoiser(
90+
raw_predictions, grid_node_outputs = self._denoiser(
8591
inputs=inputs,
8692
noisy_targets=noisy_targets * self._c_in(noise_levels),
8793
noise_levels=noise_levels,
8894
forcings=forcings,
8995
**kwargs
9096
)
9197

92-
return raw_predictions * self._c_out(
93-
noise_levels
94-
) + noisy_targets * self._c_skip(noise_levels)
98+
stacked_noisy_targets = datasets.dataset_to_stacked(noisy_targets)
99+
stacked_noisy_targets = stacked_noisy_targets.transpose("lat", "lon", ...)
100+
101+
out = grid_node_outputs * paddle.to_tensor(self._c_out(noise_levels).data)
102+
skip = paddle.to_tensor(
103+
stacked_noisy_targets.data * self._c_skip(noise_levels).data
104+
)
105+
grid_node_outputs = out + skip
106+
107+
return (
108+
raw_predictions * self._c_out(noise_levels)
109+
+ noisy_targets * self._c_skip(noise_levels),
110+
grid_node_outputs,
111+
)
112+
113+
def loss(
114+
self,
115+
inputs: xr.Dataset,
116+
targets: xr.Dataset,
117+
forcings: Optional[xr.Dataset] = None,
118+
):
119+
120+
if self._noise_config is None:
121+
raise ValueError("Noise config must be specified to train GenCast.")
122+
123+
grid_node_outputs, denoised_predictions, noise_levels = self.forward(
124+
inputs, targets, forcings
125+
)
126+
127+
loss, diagnostics = losses.weighted_mse_loss_from_xarray(
128+
grid_node_outputs,
129+
targets,
130+
# Weights are same as we used for GraphCast.
131+
per_variable_weights={
132+
# Any variables not specified here are weighted as 1.0.
133+
# A single-level variable, but an important headline variable
134+
# and also one which we have struggled to get good performance
135+
# on at short lead times, so leaving it weighted at 1.0, equal
136+
# to the multi-level variables:
137+
"2m_temperature": 1.0,
138+
# New single-level variables, which we don't weight too highly
139+
# to avoid hurting performance on other variables.
140+
"10m_u_component_of_wind": 0.1,
141+
"10m_v_component_of_wind": 0.1,
142+
"mean_sea_level_pressure": 0.1,
143+
"sea_surface_temperature": 0.1,
144+
"total_precipitation_12hr": 0.1,
145+
},
146+
)
147+
loss *= paddle.to_tensor(self._loss_weighting(noise_levels).data)
148+
return loss, diagnostics
95149

96150
def forward(self, inputs, targets_template, forcings=None, **kwargs):
151+
if self.cfg.mode == "eval":
152+
if self._sampler is None:
153+
self._sampler = dpm_solver_plus_plus_2s.Sampler(
154+
self._preconditioned_denoiser, **self._sampler_config
155+
)
156+
return self._sampler(inputs, targets_template, forcings, **kwargs)
157+
if self.cfg.mode == "train":
158+
# Sample noise levels:
159+
batch_size = inputs.sizes["batch"]
160+
noise_levels = xr.DataArray(
161+
data=samplers_utils.rho_inverse_cdf(
162+
min_value=self._noise_config.training_min_noise_level,
163+
max_value=self._noise_config.training_max_noise_level,
164+
rho=self._noise_config.training_noise_level_rho,
165+
cdf=np.random.uniform(size=(batch_size,)).astype("float32"),
166+
),
167+
dims=("batch",),
168+
)
169+
170+
# Sample noise and apply it to targets:
171+
noise = (
172+
samplers_utils.spherical_white_noise_like(targets_template)
173+
* noise_levels
174+
)
175+
176+
noisy_targets = targets_template + noise
97177

98-
if self._sampler is None:
99-
self._sampler = dpm_solver_plus_plus_2s.Sampler(
100-
self._preconditioned_denoiser, **self._sampler_config
178+
denoised_predictions, grid_node_outputs = self._preconditioned_denoiser(
179+
inputs, noisy_targets, noise_levels, forcings
101180
)
102-
return self._sampler(inputs, targets_template, forcings, **kwargs)
181+
return grid_node_outputs, denoised_predictions, noise_levels

jointContribution/gencast/run_gencast.py

Lines changed: 82 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,88 @@
2626
from omegaconf import DictConfig
2727

2828

29-
def crps(targets, predictions, bias_corrected=True):
30-
if predictions.sizes.get("sample", 1) < 2:
31-
raise ValueError("predictions must have dim 'sample' with size at least 2.")
32-
sum_dims = ["sample", "sample2"]
33-
preds2 = predictions.rename({"sample": "sample2"})
34-
num_samps = predictions.sizes["sample"]
35-
num_samps2 = (num_samps - 1) if bias_corrected else num_samps
36-
mean_abs_diff = np.abs(predictions - preds2).sum(dim=sum_dims, skipna=False) / (
37-
num_samps * num_samps2
29+
class CustomDataLoader(paddle.io.Dataset):
30+
def __init__(self, target_lead_times, cfg):
31+
super(CustomDataLoader, self).__init__()
32+
33+
self.target_lead_times = target_lead_times
34+
self.cfg = cfg
35+
36+
def __len__(self):
37+
# Return the number of time steps in target_lead_times
38+
return len(self.target_lead_times)
39+
40+
def __getitem__(self, index):
41+
# Select a specific time step
42+
time_step = self.target_lead_times[index]
43+
44+
# Multiply by 12 to get 'a'
45+
a = time_step * 12
46+
47+
# Create a string in the format 'ah'
48+
ah_str = f"{a}h"
49+
50+
# Update the config with this new 'ah' string
51+
self.cfg["target_lead_times"] = ah_str
52+
53+
# Call the ERA5Data function/class
54+
# Assuming ERA5Data is a function or class that processes this config
55+
data = datasets.ERA5Data(config=self.cfg)
56+
57+
return data
58+
59+
60+
def train(cfg: DictConfig):
61+
# Initialize the GenCast model with the given configuration.
62+
model = gencast.GenCast(cfg)
63+
model.train()
64+
65+
# set optimizer
66+
optimizer = paddle.optimizer.AdamW(
67+
parameters=model.parameters(),
68+
learning_rate=cfg.train.learning_rate,
69+
weight_decay=cfg.train.weight_decay,
3870
)
39-
mean_abs_err = (
40-
np.abs(targets - predictions).sum(dim="sample", skipna=False) / num_samps
71+
# Load the dataset using the given configuration.
72+
nc_dataset = xarray.open_dataset(cfg.data_path)
73+
time_total = len(nc_dataset.time.data)
74+
train_loader = CustomDataLoader(
75+
target_lead_times=list(range(1, time_total - 1)),
76+
cfg=cfg,
4177
)
42-
return mean_abs_err - 0.5 * mean_abs_diff
78+
79+
best_loss = float("inf")
80+
for epoch in range(cfg.train.num_epochs):
81+
epoch_loss = 0
82+
for dataset in train_loader:
83+
# Forward pass and compute loss
84+
loss, diagnostics = model.loss(
85+
dataset.inputs_template,
86+
dataset.targets_template,
87+
dataset.forcings_template,
88+
)
89+
# Backward pass and optimization
90+
loss.backward()
91+
optimizer.step()
92+
optimizer.clear_grad()
93+
94+
epoch_loss += loss.item()
95+
96+
# Average loss for the epoch
97+
epoch_loss /= len(train_loader)
98+
logging.info(f"Epoch {epoch}: Loss = {epoch_loss:.6f}")
99+
if epoch % cfg.train.snapshot_freq == 0 or epoch == 1:
100+
model_save_path = os.path.join(
101+
cfg.output_dir, f"last_model_epoch_{epoch}.pdparams"
102+
)
103+
paddle.save(model.state_dict(), model_save_path)
104+
105+
# Save model if it has the best loss
106+
if epoch_loss < best_loss:
107+
best_loss = epoch_loss
108+
model_save_path = os.path.join(cfg.output_dir, "best_model_epoch.pdparams")
109+
paddle.save(model.state_dict(), model_save_path)
110+
logging.info(f"Best model saved at epoch {epoch} with loss {best_loss:.6f}")
43111

44112

45113
def eval(cfg: DictConfig):
@@ -113,6 +181,8 @@ def eval(cfg: DictConfig):
113181
def main(cfg: DictConfig):
114182
if cfg.mode == "eval":
115183
eval(cfg)
184+
elif cfg.mode == "train":
185+
train(cfg)
116186
else:
117187
raise ValueError(f"cfg.mode should in ['eval'], but got '{cfg.mode}'")
118188

0 commit comments

Comments
 (0)