diff --git a/models/src/anemoi/models/layers/activations.py b/models/src/anemoi/models/layers/activations.py
new file mode 100644
index 000000000..cfd80ae26
--- /dev/null
+++ b/models/src/anemoi/models/layers/activations.py
@@ -0,0 +1,41 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+from __future__ import annotations
+
+import torch
+
+
+def leaky_hardtanh(
+    input: torch.Tensor,
+    min_val: float = -1.0,
+    max_val: float = 1.0,
+    negative_slope: float = 0.01,
+    positive_slope: float = 0.01,
+) -> torch.Tensor:
+    """Leaky version of hardtanh where regions outside [min_val, max_val] have small non-zero slopes.
+
+    Args:
+        input: Input tensor
+        min_val: Minimum value for the hardtanh region
+        max_val: Maximum value for the hardtanh region
+        negative_slope: Slope for values below min_val
+        positive_slope: Slope for values above max_val
+
+    Returns:
+        Tensor with leaky hardtanh applied
+    """
+    below_min = input < min_val
+    above_max = input > max_val
+    # Standard hardtanh behavior for the middle region
+    result = torch.clamp(input, min_val, max_val)
+    # Add leaky behavior for regions outside the clamped range
+    result = torch.where(below_min, min_val + negative_slope * (input - min_val), result)
+    result = torch.where(above_max, max_val + positive_slope * (input - max_val), result)
+    return result
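
A minimal sketch of the new activation's behavior (illustrative values only, not repository code): inside [min_val, max_val] it matches hardtanh, and outside that range it keeps a small slope instead of saturating flat.

import torch

from anemoi.models.layers.activations import leaky_hardtanh

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print(leaky_hardtanh(x, min_val=-1.0, max_val=1.0))
# tensor([-1.0200, -1.0000,  0.0000,  1.0000,  1.0200])
# e.g. -3 -> -1 + 0.01 * (-3 - (-1)) = -1.02, and 3 -> 1 + 0.01 * (3 - 1) = 1.02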
diff --git a/models/src/anemoi/models/layers/bounding.py b/models/src/anemoi/models/layers/bounding.py
index 54b5846d0..d68f676c8 100644
--- a/models/src/anemoi/models/layers/bounding.py
+++ b/models/src/anemoi/models/layers/bounding.py
@@ -17,6 +17,7 @@
 from torch import nn
 
 from anemoi.models.data_indices.tensor import InputTensorIndex
+from anemoi.models.layers.activations import leaky_hardtanh
 
 
 class BaseBounding(nn.Module, ABC):
@@ -82,6 +83,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class LeakyReluBounding(BaseBounding):
+    """Initializes the bounding with a Leaky ReLU activation / zero clamping."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x[..., self.data_index] = torch.nn.functional.leaky_relu(x[..., self.data_index])
+        return x
+
+
 class NormalizedReluBounding(BaseBounding):
     """Bounding variable with a ReLU activation and customizable normalized thresholds."""
 
@@ -175,6 +184,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+class LeakyNormalizedReluBounding(NormalizedReluBounding):
+    """Initializes the bounding with a Leaky ReLU activation and customizable normalized thresholds."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x[..., self.data_index] = (
+            torch.nn.functional.leaky_relu(x[..., self.data_index] - self.norm_min_val) + self.norm_min_val
+        )
+        return x
+
+
 class HardtanhBounding(BaseBounding):
     """Initializes the bounding with specified minimum and maximum values for bounding.
 
@@ -188,6 +207,10 @@ class HardtanhBounding(BaseBounding):
         The minimum value for the HardTanh activation.
     max_val : float
         The maximum value for the HardTanh activation.
+    statistics : dict, optional
+        A dictionary containing the statistics of the variables.
+    name_to_index_stats : dict, optional
+        A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
     """
 
     def __init__(
@@ -211,7 +234,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class FractionBounding(HardtanhBounding):
+class LeakyHardtanhBounding(HardtanhBounding):
+    """Initializes the bounding with a Leaky HardTanh activation."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x[..., self.data_index] = leaky_hardtanh(x[..., self.data_index], min_val=self.min_val, max_val=self.max_val)
+        return x
+
+
+class FractionBounding(BaseBounding):
     """Initializes the FractionBounding with specified parameters.
 
     Parameters
@@ -227,6 +258,10 @@ class FractionBounding(HardtanhBounding):
     total_var : str
         A string representing a variable from which a secondary variable is derived. For example,
         in the case of convective precipitation (Cp), total_var = Tp (total precipitation).
+    statistics : dict, optional
+        A dictionary containing the statistics of the variables.
+    name_to_index_stats : dict, optional
+        A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
     """
 
     def __init__(
@@ -240,12 +275,27 @@ def __init__(
         statistics: Optional[dict] = None,
         name_to_index_stats: Optional[dict] = None,
     ) -> None:
-        super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
+        super().__init__(variables=variables, name_to_index=name_to_index)
+        self.min_val = min_val
+        self.max_val = max_val
         self.total_variable = self._create_index(variables=[total_var])
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Apply the HardTanh bounding to the data_index variables
-        x = super().forward(x)
+        x[..., self.data_index] = torch.nn.functional.hardtanh(
+            x[..., self.data_index], min_val=self.min_val, max_val=self.max_val
+        )
+        # Calculate the fraction of the total variable
+        x[..., self.data_index] *= x[..., self.total_variable]
+        return x
+
+
+class LeakyFractionBounding(FractionBounding):
+    """Initializes the bounding with a Leaky HardTanh activation and a fraction of the total variable."""
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Apply the LeakyHardTanh bounding to the data_index variables
+        x[..., self.data_index] = leaky_hardtanh(x[..., self.data_index], min_val=self.min_val, max_val=self.max_val)
         # Calculate the fraction of the total variable
         x[..., self.data_index] *= x[..., self.total_variable]
         return x
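
A minimal usage sketch of the new fraction bounding (the variable names and values here are hypothetical, chosen to mirror the docstring's convective/total precipitation example):

import torch

from anemoi.models.layers.bounding import LeakyFractionBounding

name_to_index = {"cp": 0, "csnow": 1, "tp": 2}  # hypothetical variable layout
bounding = LeakyFractionBounding(
    variables=["cp", "csnow"],
    name_to_index=name_to_index,
    min_val=0.0,
    max_val=1.0,
    total_var="tp",
)

x = torch.tensor([[-1.0, 2.0, 3.0]])
out = bounding(x.clone())
# leaky_hardtanh maps -1 -> -0.01 and 2 -> 1.01; both are then scaled by tp = 3,
# giving tensor([[-0.03, 3.03, 3.00]]).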
diff --git a/models/tests/layers/test_bounding.py b/models/tests/layers/test_bounding.py
index a7d8b4206..ed38fcb28 100644
--- a/models/tests/layers/test_bounding.py
+++ b/models/tests/layers/test_bounding.py
@@ -14,6 +14,10 @@
 from anemoi.models.layers.bounding import FractionBounding
 from anemoi.models.layers.bounding import HardtanhBounding
+from anemoi.models.layers.bounding import LeakyFractionBounding
+from anemoi.models.layers.bounding import LeakyHardtanhBounding
+from anemoi.models.layers.bounding import LeakyNormalizedReluBounding
+from anemoi.models.layers.bounding import LeakyReluBounding
 from anemoi.models.layers.bounding import NormalizedReluBounding
 from anemoi.models.layers.bounding import ReluBounding
 from anemoi.utils.config import DotDict
@@ -108,18 +112,28 @@ def test_multi_chained_bounding(config, name_to_index, input_tensor):
     assert torch.equal(output, expected_output)
 
 
-def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
+def test_hydra_instantiate_bounding(config, name_to_index, name_to_index_stats, input_tensor, statistics):
     layer_definitions = [
         {
             "_target_": "anemoi.models.layers.bounding.ReluBounding",
             "variables": config.variables,
         },
+        {
+            "_target_": "anemoi.models.layers.bounding.LeakyReluBounding",
+            "variables": config.variables,
+        },
         {
             "_target_": "anemoi.models.layers.bounding.HardtanhBounding",
             "variables": config.variables,
             "min_val": 0.0,
             "max_val": 1.0,
         },
+        {
+            "_target_": "anemoi.models.layers.bounding.LeakyHardtanhBounding",
+            "variables": config.variables,
+            "min_val": 0.0,
+            "max_val": 1.0,
+        },
         {
             "_target_": "anemoi.models.layers.bounding.FractionBounding",
             "variables": config.variables,
@@ -127,7 +141,124 @@ def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
             "max_val": 1.0,
             "total_var": config.total_var,
         },
+        {
+            "_target_": "anemoi.models.layers.bounding.LeakyFractionBounding",
+            "variables": config.variables,
+            "min_val": 0.0,
+            "max_val": 1.0,
+            "total_var": config.total_var,
+        },
+        {
+            "_target_": "anemoi.models.layers.bounding.LeakyNormalizedReluBounding",
+            "variables": config.variables,
+            "min_val": [2.0, 2.0],
+            "normalizer": ["mean-std", "min-max"],
+            "statistics": statistics,
+            "name_to_index_stats": name_to_index_stats,
+        },
     ]
     for layer_definition in layer_definitions:
         bounding = instantiate(layer_definition, name_to_index=name_to_index)
         bounding(input_tensor.clone())
+
+
+def test_leaky_relu_bounding(config, name_to_index, input_tensor):
+    bounding = LeakyReluBounding(variables=config.variables, name_to_index=name_to_index)
+    output = bounding(input_tensor.clone())
+    # LeakyReLU should keep negative values but scale them by 0.01 (the default negative_slope)
+    expected_output = torch.tensor([[-0.01, 2.0, 3.0], [4.0, -0.05, 6.0], [0.5, 0.5, 0.5]])
+    assert torch.allclose(output, expected_output, atol=1e-4)
+
+
+def test_leaky_hardtanh_bounding(config, name_to_index, input_tensor):
+    minimum, maximum = -1.0, 1.0
+    bounding = LeakyHardtanhBounding(
+        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
+    )
+    output = bounding(input_tensor.clone())
+    # Values below min_val should be min_val + 0.01 * (input - min_val)
+    # Values above max_val should be max_val + 0.01 * (input - max_val)
+    expected_output = torch.tensor(
+        [
+            [minimum + 0.01 * (-1.0 - minimum), maximum + 0.01 * (2.0 - maximum), 3.0],
+            [maximum + 0.01 * (4.0 - maximum), minimum + 0.01 * (-5.0 - minimum), 6.0],
+            [0.5, 0.5, 0.5],
+        ]
+    )
+    assert torch.allclose(output, expected_output, atol=1e-4)
+
+
+def test_leaky_fraction_bounding(config, name_to_index, input_tensor):
+    bounding = LeakyFractionBounding(
+        variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var
+    )
+    output = bounding(input_tensor.clone())
+    # First apply leaky hardtanh, then multiply by total_var
+    expected_output = torch.tensor(
+        [
+            [-0.03, 3.03, 3.0],  # [-1, 2, 3] -> leaky hardtanh [-0.01, 1.01, 3] -> scaled by tp=3
+            [6.18, -0.3, 6.0],  # [4, -5, 6] -> leaky hardtanh [1.03, -0.05, 6] -> scaled by tp=6
+            [0.25, 0.25, 0.5],  # [0.5, 0.5, 0.5] -> unchanged by leaky hardtanh -> scaled by tp=0.5
+        ]
+    )
+    assert torch.allclose(output, expected_output, atol=1e-4)
+
+
+def test_multi_chained_bounding_with_leaky(config, name_to_index, input_tensor):
+    # Apply LeakyReLU first on the first variable only
+    bounding1 = LeakyReluBounding(variables=config.variables[:-1], name_to_index=name_to_index)
+    expected_output = torch.tensor([[-0.01, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])
+    # Check intermediate result
+    assert torch.allclose(bounding1(input_tensor.clone()), expected_output, atol=1e-4)
+
+    minimum, maximum = 0.5, 1.75
+    bounding2 = LeakyHardtanhBounding(
+        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
+    )
+    # Use full chaining on the input tensor
+    output = bounding2(bounding1(input_tensor.clone()))
+    # Data with LeakyReLU applied first and then LeakyHardtanh
+    expected_output = torch.tensor(
+        [
+            [minimum + 0.01 * (-0.01 - minimum), maximum + 0.01 * (2.0 - maximum), 3.0],
+            [maximum + 0.01 * (4.0 - maximum), minimum + 0.01 * (-5.0 - minimum), 6.0],
+            [0.5, 0.5, 0.5],
+        ]
+    )
+    assert torch.allclose(output, expected_output, atol=1e-4)
+
+
+def test_leaky_normalized_relu_bounding(config, name_to_index, name_to_index_stats, input_tensor, statistics):
+    bounding = LeakyNormalizedReluBounding(
+        variables=config.variables,
+        name_to_index=name_to_index,
+        min_val=[2.0, 2.0],
+        normalizer=["mean-std", "min-max"],
+        statistics=statistics,
+        name_to_index_stats=name_to_index_stats,
+    )
+    output = bounding(input_tensor.clone())
+
+    # The raw threshold min_val = 2.0 is first normalized with the chosen normalizer:
+    # mean-std: norm_min_val = (min_val - mean) / stdev = (2 - 1) / 0.5 = 2
+    # min-max:  norm_min_val = (min_val - min) / (max - min) = (2 - 1) / (10 - 1) = 0.111
+    # The bounding then computes leaky_relu(x - norm_min_val) + norm_min_val.
+
+    # First variable (mean-std, norm_min_val = 2):
+    # [-1, 4, 0.5] -> leaky_relu([-3, 2, -1.5]) + 2 = [-0.03, 2, -0.015] + 2 = [1.97, 4.0, 1.985]
+
+    # Second variable (min-max, norm_min_val = 0.111):
+    # [2, -5, 0.5] -> leaky_relu([1.889, -5.111, 0.389]) + 0.111
+    # = [1.889, -0.0511, 0.389] + 0.111 = [2.0, 0.06, 0.5]
+
+    expected_output = torch.tensor(
+        [
+            [1.97, 2.0, 3.0],  # [-1, 2, 3] -> [1.97, 2.0, 3.0]
+            [4.0, 0.06, 6.0],  # [4, -5, 6] -> [4.0, 0.06, 6.0]
+            [1.985, 0.5, 0.5],  # [0.5, 0.5, 0.5] -> [1.985, 0.5, 0.5]
+        ]
+    )
+    assert torch.allclose(output, expected_output, atol=1e-4)
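
A standalone check of the threshold arithmetic in the test above, assuming the fixture statistics implied by its comments (mean = 1.0, stdev = 0.5 for the mean-std variable):

import torch

norm_min_val = 2.0  # (min_val - mean) / stdev = (2 - 1) / 0.5 for the mean-std case
x = torch.tensor([-1.0, 4.0, 0.5])
out = torch.nn.functional.leaky_relu(x - norm_min_val) + norm_min_val
print(out)  # tensor([1.9700, 4.0000, 1.9850])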
diff --git a/training/src/anemoi/training/schemas/models/models.py b/training/src/anemoi/training/schemas/models/models.py
index 9f9e5d166..f537a79b0 100644
--- a/training/src/anemoi/training/schemas/models/models.py
+++ b/training/src/anemoi/training/schemas/models/models.py
@@ -65,6 +65,11 @@ class ReluBoundingSchema(BaseModel):
     "List of variables to bound using the Relu method."
 
 
+class LeakyReluBoundingSchema(ReluBoundingSchema):
+    target_: Literal["anemoi.models.layers.bounding.LeakyReluBounding"] = Field(..., alias="_target_")
+    "Leaky Relu bounding object defined in anemoi.models.layers.bounding."
+
+
 class FractionBoundingSchema(BaseModel):
     target_: Literal["anemoi.models.layers.bounding.FractionBounding"] = Field(..., alias="_target_")
     "Fraction bounding object defined in anemoi.models.layers.bounding."
@@ -79,6 +84,11 @@ class FractionBoundingSchema(BaseModel):
     For example, convective precipitation should be a fraction of total precipitation."
 
 
+class LeakyFractionBoundingSchema(FractionBoundingSchema):
+    target_: Literal["anemoi.models.layers.bounding.LeakyFractionBounding"] = Field(..., alias="_target_")
+    "Leaky fraction bounding object defined in anemoi.models.layers.bounding."
+
+
 class HardtanhBoundingSchema(BaseModel):
     target_: Literal["anemoi.models.layers.bounding.HardtanhBounding"] = Field(..., alias="_target_")
     "Hard tanh bounding method function from anemoi.models.layers.bounding."
@@ -90,6 +100,11 @@ class HardtanhBoundingSchema(BaseModel):
     "The maximum value for the HardTanh activation."
 
 
+class LeakyHardtanhBoundingSchema(HardtanhBoundingSchema):
+    target_: Literal["anemoi.models.layers.bounding.LeakyHardtanhBounding"] = Field(..., alias="_target_")
+    "Leaky hard tanh bounding method function from anemoi.models.layers.bounding."
+
+
 class NormalizedReluBoundingSchema(BaseModel):
     target_: Literal["anemoi.models.layers.bounding.NormalizedReluBounding"] = Field(..., alias="_target_")
     variables: list[str]
@@ -107,8 +122,22 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedR
         return self
 
 
+class LeakyNormalizedReluBoundingSchema(NormalizedReluBoundingSchema):
+    target_: Literal["anemoi.models.layers.bounding.LeakyNormalizedReluBounding"] = Field(..., alias="_target_")
+    "Leaky normalized Relu bounding object defined in anemoi.models.layers.bounding."
+
+
 Bounding = Annotated[
-    Union[ReluBoundingSchema, FractionBoundingSchema, HardtanhBoundingSchema, NormalizedReluBoundingSchema],
+    Union[
+        ReluBoundingSchema,
+        LeakyReluBoundingSchema,
+        FractionBoundingSchema,
+        LeakyFractionBoundingSchema,
+        HardtanhBoundingSchema,
+        LeakyHardtanhBoundingSchema,
+        NormalizedReluBoundingSchema,
+        LeakyNormalizedReluBoundingSchema,
+    ],
     Field(discriminator="target_"),
 ]
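
A minimal sketch of how the extended discriminated union validates a config entry, assuming standard Pydantic v2 TypeAdapter semantics (illustrative only, not repository code; the "cp" variable name is hypothetical):

from pydantic import TypeAdapter

from anemoi.training.schemas.models.models import Bounding

entry = {
    "_target_": "anemoi.models.layers.bounding.LeakyHardtanhBounding",
    "variables": ["cp"],
    "min_val": 0.0,
    "max_val": 1.0,
}
# The "_target_" value is matched against the `target_` discriminator, so this
# entry is validated as a LeakyHardtanhBoundingSchema.
schema = TypeAdapter(Bounding).validate_python(entry)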