Skip to content

feat(models): adding leaky boundings #256
Merged · 4 commits · Apr 11, 2025
41 changes: 41 additions & 0 deletions models/src/anemoi/models/layers/activations.py
@@ -0,0 +1,41 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import torch


def leaky_hardtanh(
    input: torch.Tensor,
    min_val: float = -1.0,
    max_val: float = 1.0,
    negative_slope: float = 0.01,
    positive_slope: float = 0.01,
) -> torch.Tensor:
    """Leaky version of hardtanh where regions outside [min_val, max_val] have small non-zero slopes.

    Args:
        input: Input tensor
        min_val: Minimum value for the hardtanh region
        max_val: Maximum value for the hardtanh region
        negative_slope: Slope for values below min_val
        positive_slope: Slope for values above max_val

    Returns:
        Tensor with leaky hardtanh applied
    """
    below_min = input < min_val
    above_max = input > max_val
    # Standard hardtanh behavior for the middle region
    result = torch.clamp(input, min_val, max_val)
    # Add leaky behavior for regions outside the clamped range
    result = torch.where(below_min, min_val + negative_slope * (input - min_val), result)
    result = torch.where(above_max, max_val + positive_slope * (input - max_val), result)
    return result
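For quick intuition, a numeric sanity check of the piecewise behaviour with the default slopes (illustrative values only):

import torch

from anemoi.models.layers.activations import leaky_hardtanh

x = torch.tensor([-3.0, -1.0, 0.0, 1.0, 3.0])
print(leaky_hardtanh(x))
# tensor([-1.0200, -1.0000,  0.0000,  1.0000,  1.0200])
# inside [-1, 1] values pass through the clamp unchanged; outside,
# the overshoot is scaled by the 0.01 slope instead of being flattened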
56 changes: 53 additions & 3 deletions models/src/anemoi/models/layers/bounding.py
@@ -17,6 +17,7 @@
from torch import nn

from anemoi.models.data_indices.tensor import InputTensorIndex
from anemoi.models.layers.activations import leaky_hardtanh


class BaseBounding(nn.Module, ABC):
@@ -82,6 +83,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class LeakyReluBounding(BaseBounding):
    """Initializes the bounding with a Leaky ReLU activation (a leaky variant of zero clamping)."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x[..., self.data_index] = torch.nn.functional.leaky_relu(x[..., self.data_index])
        return x


class NormalizedReluBounding(BaseBounding):
"""Bounding variable with a ReLU activation and customizable normalized thresholds."""

@@ -175,6 +184,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


class LeakyNormalizedReluBounding(NormalizedReluBounding):
    """Initializes the bounding with a Leaky ReLU activation and customizable normalized thresholds."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x[..., self.data_index] = (
            torch.nn.functional.leaky_relu(x[..., self.data_index] - self.norm_min_val) + self.norm_min_val
        )
        return x
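The shift-and-unshift above is the whole trick: leaky_relu(x - t) + t is the identity at and above the normalized threshold t and has slope 0.01 below it. A minimal sketch with an arbitrary threshold:

import torch
import torch.nn.functional as F

t = 2.0  # stand-in for norm_min_val
x = torch.tensor([-1.0, 2.0, 4.0])
print(F.leaky_relu(x - t) + t)
# tensor([1.9700, 2.0000, 4.0000]): slope 0.01 below t, identity above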


class HardtanhBounding(BaseBounding):
    """Initializes the bounding with specified minimum and maximum values for bounding.

@@ -188,6 +207,10 @@ class HardtanhBounding(BaseBounding):
        The minimum value for the HardTanh activation.
    max_val : float
        The maximum value for the HardTanh activation.
    statistics : dict, optional
        A dictionary containing the statistics of the variables.
    name_to_index_stats : dict, optional
        A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
    """

    def __init__(
@@ -211,7 +234,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x


-class FractionBounding(HardtanhBounding):
class LeakyHardtanhBounding(HardtanhBounding):
    """Initializes the bounding with a Leaky HardTanh activation."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x[..., self.data_index] = leaky_hardtanh(x[..., self.data_index], min_val=self.min_val, max_val=self.max_val)
        return x


class FractionBounding(BaseBounding):
    """Initializes the FractionBounding with specified parameters.

    Parameters
@@ -227,6 +258,10 @@ class FractionBounding(HardtanhBounding):
    total_var : str
        A string representing a variable from which a secondary variable is derived. For
        example, in the case of convective precipitation (Cp), total_var = Tp (total precipitation).
    statistics : dict, optional
        A dictionary containing the statistics of the variables.
    name_to_index_stats : dict, optional
        A dictionary mapping the variable names to their corresponding indices in the statistics dictionary.
    """

    def __init__(
@@ -240,12 +275,27 @@ def __init__(
        statistics: Optional[dict] = None,
        name_to_index_stats: Optional[dict] = None,
    ) -> None:
-        super().__init__(variables=variables, name_to_index=name_to_index, min_val=min_val, max_val=max_val)
        super().__init__(variables=variables, name_to_index=name_to_index)
        self.min_val = min_val
        self.max_val = max_val
        self.total_variable = self._create_index(variables=[total_var])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the HardTanh bounding to the data_index variables
-        x = super().forward(x)
        x[..., self.data_index] = torch.nn.functional.hardtanh(
            x[..., self.data_index], min_val=self.min_val, max_val=self.max_val
        )
        # Calculate the fraction of the total variable
        x[..., self.data_index] *= x[..., self.total_variable]
        return x


class LeakyFractionBounding(FractionBounding):
    """Initializes the bounding with a Leaky HardTanh activation and a fraction of the total variable."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the LeakyHardTanh bounding to the data_index variables
        x[..., self.data_index] = leaky_hardtanh(x[..., self.data_index], min_val=self.min_val, max_val=self.max_val)
        # Calculate the fraction of the total variable
        x[..., self.data_index] *= x[..., self.total_variable]
        return x
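For orientation, a minimal usage sketch of the new leaky fraction bounding; the variable names and index layout here are hypothetical, not taken from the test fixtures:

import torch

from anemoi.models.layers.bounding import LeakyFractionBounding

name_to_index = {"cp": 0, "tp": 1, "t2m": 2}  # hypothetical layout
bounding = LeakyFractionBounding(
    variables=["cp"],
    name_to_index=name_to_index,
    min_val=0.0,
    max_val=1.0,
    total_var="tp",
)
x = torch.tensor([[2.0, 5.0, 280.0]])
out = bounding(x.clone())  # forward mutates its input in place, hence the clone
# cp -> leaky_hardtanh(2.0) * tp = (1 + 0.01 * 1) * 5.0 = 5.05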
133 changes: 132 additions & 1 deletion models/tests/layers/test_bounding.py
@@ -14,6 +14,10 @@

from anemoi.models.layers.bounding import FractionBounding
from anemoi.models.layers.bounding import HardtanhBounding
from anemoi.models.layers.bounding import LeakyFractionBounding
from anemoi.models.layers.bounding import LeakyHardtanhBounding
from anemoi.models.layers.bounding import LeakyNormalizedReluBounding
from anemoi.models.layers.bounding import LeakyReluBounding
from anemoi.models.layers.bounding import NormalizedReluBounding
from anemoi.models.layers.bounding import ReluBounding
from anemoi.utils.config import DotDict
@@ -108,26 +112,153 @@ def test_multi_chained_bounding(config, name_to_index, input_tensor):
    assert torch.equal(output, expected_output)


-def test_hydra_instantiate_bounding(config, name_to_index, input_tensor):
def test_hydra_instantiate_bounding(config, name_to_index, name_to_index_stats, input_tensor, statistics):
    layer_definitions = [
        {
            "_target_": "anemoi.models.layers.bounding.ReluBounding",
            "variables": config.variables,
        },
        {
            "_target_": "anemoi.models.layers.bounding.LeakyReluBounding",
            "variables": config.variables,
        },
        {
            "_target_": "anemoi.models.layers.bounding.HardtanhBounding",
            "variables": config.variables,
            "min_val": 0.0,
            "max_val": 1.0,
        },
        {
            "_target_": "anemoi.models.layers.bounding.LeakyHardtanhBounding",
            "variables": config.variables,
            "min_val": 0.0,
            "max_val": 1.0,
        },
        {
            "_target_": "anemoi.models.layers.bounding.FractionBounding",
            "variables": config.variables,
            "min_val": 0.0,
            "max_val": 1.0,
            "total_var": config.total_var,
        },
        {
            "_target_": "anemoi.models.layers.bounding.LeakyFractionBounding",
            "variables": config.variables,
            "min_val": 0.0,
            "max_val": 1.0,
            "total_var": config.total_var,
        },
        {
            "_target_": "anemoi.models.layers.bounding.LeakyNormalizedReluBounding",
            "variables": config.variables,
            "min_val": [2.0, 2.0],
            "normalizer": ["mean-std", "min-max"],
            "statistics": statistics,
            "name_to_index_stats": name_to_index_stats,
        },
    ]
    for layer_definition in layer_definitions:
        bounding = instantiate(layer_definition, name_to_index=name_to_index)
        bounding(input_tensor.clone())


def test_leaky_relu_bounding(config, name_to_index, input_tensor):
    bounding = LeakyReluBounding(variables=config.variables, name_to_index=name_to_index)
    output = bounding(input_tensor.clone())
    # LeakyReLU should keep negative values but scale them by 0.01 (default negative_slope)
    expected_output = torch.tensor([[-0.01, 2.0, 3.0], [4.0, -0.05, 6.0], [0.5, 0.5, 0.5]])
    assert torch.allclose(output, expected_output, atol=1e-4)


def test_leaky_hardtanh_bounding(config, name_to_index, input_tensor):
    minimum, maximum = -1.0, 1.0
    bounding = LeakyHardtanhBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
    )
    output = bounding(input_tensor.clone())
    # Values below min_val should be min_val + 0.01 * (input - min_val)
    # Values above max_val should be max_val + 0.01 * (input - max_val)
    expected_output = torch.tensor(
        [
            [minimum + 0.01 * (-1.0 - minimum), maximum + 0.01 * (2.0 - maximum), 3.0],
            [maximum + 0.01 * (4.0 - maximum), minimum + 0.01 * (-5.0 - minimum), 6.0],
            [0.5, 0.5, 0.5],
        ]
    )
    assert torch.allclose(output, expected_output, atol=1e-4)


def test_leaky_fraction_bounding(config, name_to_index, input_tensor):
    bounding = LeakyFractionBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=0.0, max_val=1.0, total_var=config.total_var
    )
    output = bounding(input_tensor.clone())
    # First apply leaky hardtanh, then multiply by total_var
    expected_output = torch.tensor(
        [
            [-0.03, 3.03, 3.0],  # [-1, 2] -> leaky hardtanh: [-0.01, 1.01] -> * total (3): [-0.03, 3.03]
            [6.18, -0.3, 6.0],  # [4, -5] -> leaky hardtanh: [1.03, -0.05] -> * total (6): [6.18, -0.3]
            [0.25, 0.25, 0.5],  # [0.5, 0.5] -> unchanged in [0, 1] -> * total (0.5): [0.25, 0.25]
        ]
    )
    assert torch.allclose(output, expected_output, atol=1e-4)


def test_multi_chained_bounding_with_leaky(config, name_to_index, input_tensor):
    # Apply LeakyReLU first on the first variable only
    bounding1 = LeakyReluBounding(variables=config.variables[:-1], name_to_index=name_to_index)
    expected_output = torch.tensor([[-0.01, 2.0, 3.0], [4.0, -5.0, 6.0], [0.5, 0.5, 0.5]])
    # Check intermediate result
    assert torch.allclose(bounding1(input_tensor.clone()), expected_output, atol=1e-4)

    minimum, maximum = 0.5, 1.75
    bounding2 = LeakyHardtanhBounding(
        variables=config.variables, name_to_index=name_to_index, min_val=minimum, max_val=maximum
    )
    # Use full chaining on the input tensor
    output = bounding2(bounding1(input_tensor.clone()))
    # Data with LeakyReLU applied first and then LeakyHardtanh
    expected_output = torch.tensor(
        [
            [minimum + 0.01 * (-0.01 - minimum), maximum + 0.01 * (2.0 - maximum), 3.0],
            [maximum + 0.01 * (4.0 - maximum), minimum + 0.01 * (-5.0 - minimum), 6.0],
            [0.5, 0.5, 0.5],
        ]
    )
    assert torch.allclose(output, expected_output, atol=1e-4)


def test_leaky_normalized_relu_bounding(config, name_to_index, name_to_index_stats, input_tensor, statistics):
    bounding = LeakyNormalizedReluBounding(
        variables=config.variables,
        name_to_index=name_to_index,
        min_val=[2.0, 2.0],
        normalizer=["mean-std", "min-max"],
        statistics=statistics,
        name_to_index_stats=name_to_index_stats,
    )
    output = bounding(input_tensor.clone())

    # The bounding computes leaky_relu(x - norm_min_val) + norm_min_val, where
    # norm_min_val is min_val converted into normalized space:
    # mean-std: (min_val - mean) / stdev = (2 - 1) / 0.5 = 2.0
    # min-max:  (min_val - min) / (max - min) = (2 - 1) / (10 - 1) = 0.1111

    # First variable (mean-std, norm_min_val = 2.0):
    # [-1, 4, 0.5] -> leaky_relu([-3, 2, -1.5]) + 2 = [-0.03, 2, -0.015] + 2 = [1.97, 4.0, 1.985]

    # Second variable (min-max, norm_min_val = 0.1111):
    # [2, -5, 0.5] -> leaky_relu([1.8889, -5.1111, 0.3889]) + 0.1111 = [2.0, 0.06, 0.5]

    expected_output = torch.tensor(
        [
            [1.97, 2.0, 3.0],  # [-1, 2, 3] -> [1.97, 2.0, 3.0]
            [4.0, 0.06, 6.0],  # [4, -5, 6] -> [4.0, 0.06, 6.0]
            [1.985, 0.5, 0.5],  # [0.5, 0.5, 0.5] -> [1.985, 0.5, 0.5]
        ]
    )
    assert torch.allclose(output, expected_output, atol=1e-4)
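Under the fixture statistics assumed in the comments above (mean 1, stdev 0.5, min 1, max 10), the expected tensor can be reproduced by hand:

import torch
import torch.nn.functional as F

stats = {"mean": 1.0, "stdev": 0.5, "min": 1.0, "max": 10.0}  # assumed fixture values
x = torch.tensor([[-1.0, 2.0], [4.0, -5.0], [0.5, 0.5]])  # the two bounded columns

t_meanstd = (2.0 - stats["mean"]) / stats["stdev"]  # 2.0
t_minmax = (2.0 - stats["min"]) / (stats["max"] - stats["min"])  # 0.1111
thresholds = torch.tensor([t_meanstd, t_minmax])

print(F.leaky_relu(x - thresholds) + thresholds)
# tensor([[1.9700, 2.0000],
#         [4.0000, 0.0600],
#         [1.9850, 0.5000]])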
31 changes: 30 additions & 1 deletion training/src/anemoi/training/schemas/models/models.py
@@ -65,6 +65,11 @@ class ReluBoundingSchema(BaseModel):
    "List of variables to bound using the Relu method."


class LeakyReluBoundingSchema(ReluBoundingSchema):
    target_: Literal["anemoi.models.layers.bounding.LeakyReluBounding"] = Field(..., alias="_target_")
    "Leaky Relu bounding object defined in anemoi.models.layers.bounding."


class FractionBoundingSchema(BaseModel):
    target_: Literal["anemoi.models.layers.bounding.FractionBounding"] = Field(..., alias="_target_")
    "Fraction bounding object defined in anemoi.models.layers.bounding."
@@ -79,6 +84,11 @@ class FractionBoundingSchema(BaseModel):
    For example, convective precipitation should be a fraction of total precipitation."


class LeakyFractionBoundingSchema(FractionBoundingSchema):
    target_: Literal["anemoi.models.layers.bounding.LeakyFractionBounding"] = Field(..., alias="_target_")
    "Leaky fraction bounding object defined in anemoi.models.layers.bounding."


class HardtanhBoundingSchema(BaseModel):
    target_: Literal["anemoi.models.layers.bounding.HardtanhBounding"] = Field(..., alias="_target_")
    "Hard tanh bounding method function from anemoi.models.layers.bounding."
@@ -90,6 +100,11 @@ class HardtanhBoundingSchema(BaseModel):
    "The maximum value for the HardTanh activation."


class LeakyHardtanhBoundingSchema(HardtanhBoundingSchema):
    target_: Literal["anemoi.models.layers.bounding.LeakyHardtanhBounding"] = Field(..., alias="_target_")
    "Leaky hard tanh bounding method function from anemoi.models.layers.bounding."


class NormalizedReluBoundingSchema(BaseModel):
    target_: Literal["anemoi.models.layers.bounding.NormalizedReluBounding"] = Field(..., alias="_target_")
    variables: list[str]
@@ -107,8 +122,22 @@ def check_num_normalizers_and_min_val_matches_num_variables(self) -> NormalizedReluBoundingSchema:
        return self


class LeakyNormalizedReluBoundingSchema(NormalizedReluBoundingSchema):
    target_: Literal["anemoi.models.layers.bounding.LeakyNormalizedReluBounding"] = Field(..., alias="_target_")
    "Leaky normalized Relu bounding object defined in anemoi.models.layers.bounding."


Bounding = Annotated[
-    Union[ReluBoundingSchema, FractionBoundingSchema, HardtanhBoundingSchema, NormalizedReluBoundingSchema],
    Union[
        ReluBoundingSchema,
        LeakyReluBoundingSchema,
        FractionBoundingSchema,
        LeakyFractionBoundingSchema,
        HardtanhBoundingSchema,
        LeakyHardtanhBoundingSchema,
        NormalizedReluBoundingSchema,
        LeakyNormalizedReluBoundingSchema,
    ],
    Field(discriminator="target_"),
]
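A hedged sketch of how the extended discriminated union dispatches on the _target_ alias; assumes pydantic v2, with a config dict mirroring the Hydra layer definitions used in the tests:

from pydantic import TypeAdapter

cfg = {
    "_target_": "anemoi.models.layers.bounding.LeakyHardtanhBounding",
    "variables": ["cp"],
    "min_val": 0.0,
    "max_val": 1.0,
}
# Bounding is the Annotated union defined above
schema = TypeAdapter(Bounding).validate_python(cfg)
assert type(schema).__name__ == "LeakyHardtanhBoundingSchema"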
