Skip to content

fix: support for netcdf missing values #214

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,6 @@ tmp/
temp/
logs/
_dev/
outputs
*tmp_data/

# Project specific
Expand Down
10 changes: 5 additions & 5 deletions src/anemoi/inference/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
from typing import List
from typing import Optional

from anemoi.inference.input import Input
from anemoi.inference.output import Output
from anemoi.inference.processor import Processor
from anemoi.inference.types import IntArray

if TYPE_CHECKING:
from anemoi.inference.input import Input
from anemoi.inference.output import Output

from .checkpoint import Checkpoint
from .forcings import Forcings


LOG = logging.getLogger(__name__)


Expand Down Expand Up @@ -64,7 +64,7 @@ def checkpoint(self) -> "Checkpoint":
# expected to provide the forcings directly as input to the runner.
##################################################################

def create_input(self) -> Input:
def create_input(self) -> "Input":
"""Creates an input object for the inference.

Returns
Expand All @@ -74,7 +74,7 @@ def create_input(self) -> Input:
"""
raise NotImplementedError()

def create_output(self) -> Output:
def create_output(self) -> "Output":
"""Creates an output object for the inference.

Returns
Expand Down
76 changes: 75 additions & 1 deletion src/anemoi/inference/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,24 +180,98 @@ class ForwardOutput(Output):
"""

def __init__(
self, context: "Context", output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None
self,
context: "Context",
output: dict,
output_frequency: Optional[int] = None,
write_initial_state: Optional[bool] = None,
):
"""Initialize the ForwardOutput object.

Parameters
----------
context : Context
The context in which the output operates.
output : dict
The output configuration dictionary.
output_frequency : Optional[int], optional
The frequency at which to output states, by default None.
write_initial_state : Optional[bool], optional
Whether to write the initial state, by default None.
"""

from anemoi.inference.outputs import create_output

super().__init__(context, output_frequency=None, write_initial_state=write_initial_state)

self.output = None if output is None else create_output(context, output)

if self.context.output_frequency is not None:
LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__)

@cached_property
def output_frequency(self) -> Optional[datetime.timedelta]:
"""Get the output frequency."""
return None

def modify_state(self, state: State) -> State:
"""Modify the state before writing.

Parameters
----------
state : State
The state to modify.

Returns
-------
State
The modified state.
"""
return state

def open(self, state) -> None:
"""Open the output for writing.
Parameters
----------
state : State
The initial state.
"""
self.output.open(self.modify_state(state))

def close(self) -> None:
"""Close the output."""

self.output.close()

def write_initial_step(self, state: State) -> None:
"""Write the initial step of the state.

Parameters
----------
state : State
The state dictionary.
"""
state.setdefault("step", datetime.timedelta(0))

self.output.write_initial_state(self.modify_state(state))

def write_step(self, state: State) -> None:
"""Write a step of the state.

Parameters
----------
state : State
The state to write.
"""
self.output.write_state(self.modify_state(state))

def print_summary(self, depth: int = 0) -> None:
"""Print a summary of the output.

Parameters
----------
depth : int, optional
The depth of the summary, by default 0.
"""
super().print_summary(depth)
self.output.print_summary(depth + 1)
85 changes: 9 additions & 76 deletions src/anemoi/inference/outputs/apply_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@

from anemoi.inference.config import Configuration
from anemoi.inference.context import Context
from anemoi.inference.types import State

from ..output import ForwardOutput
from . import create_output
from . import output_registry
from .masked import MaskedOutput

LOG = logging.getLogger(__name__)


@output_registry.register("apply_mask")
class ApplyMaskOutput(ForwardOutput):
class ApplyMaskOutput(MaskedOutput):
"""Apply mask output class.

Parameters
Expand All @@ -48,75 +46,10 @@ def __init__(
output_frequency: Optional[int] = None,
write_initial_state: Optional[bool] = None,
) -> None:
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)
self.mask = self.checkpoint.load_supporting_array(mask)
self.output = create_output(context, output)

def __repr__(self) -> str:
"""Return a string representation of the ApplyMaskOutput object."""
return f"ApplyMaskOutput({self.mask}, {self.output})"

def write_initial_step(self, state: State) -> None:
"""Write the initial step of the state.

Parameters
----------
state : State
The state dictionary.
"""
# Note: we foreward to 'state', so we write-up options again
self.output.write_initial_state(self._apply_mask(state))

def write_step(self, state: State) -> None:
"""Write a step of the state.

Parameters
----------
state : State
The state dictionary.
"""
# Note: we foreward to 'state', so we write-up options again
self.output.write_state(self._apply_mask(state))

def _apply_mask(self, state: State) -> State:
"""Apply the mask to the state.

Parameters
----------
state : State
The state dictionary.

Returns
-------
State
The masked state dictionary.
"""
state = state.copy()
state["fields"] = state["fields"].copy()
state["latitudes"] = state["latitudes"][self.mask]
state["longitudes"] = state["longitudes"][self.mask]

for field in state["fields"]:
data = state["fields"][field]
if data.ndim == 1:
data = data[self.mask]
else:
data = data[..., self.mask]
state["fields"][field] = data

return state

def close(self) -> None:
"""Close the output."""
self.output.close()

def print_summary(self, depth: int = 0) -> None:
"""Print the summary of the output.

Parameters
----------
depth : int, optional
The depth of the summary, by default 0.
"""
super().print_summary(depth)
self.output.print_summary(depth + 1)
super().__init__(
context,
mask=self.checkpoint.load_supporting_array(mask),
output=output,
output_frequency=output_frequency,
write_initial_state=write_initial_state,
)
94 changes: 14 additions & 80 deletions src/anemoi/inference/outputs/extract_lam.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@

from anemoi.inference.config import Configuration
from anemoi.inference.context import Context
from anemoi.inference.types import State

from ..output import ForwardOutput
from . import create_output
from . import output_registry
from .masked import MaskedOutput

LOG = logging.getLogger(__name__)


@output_registry.register("extract_lam")
class ExtractLamOutput(ForwardOutput):
class ExtractLamOutput(MaskedOutput):
"""Extract LAM output class.

Parameters
Expand All @@ -50,91 +48,27 @@ def __init__(
output_frequency: Optional[int] = None,
write_initial_state: Optional[bool] = None,
) -> None:
super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state)

if "cutout_mask" in self.checkpoint.supporting_arrays:
if "cutout_mask" in context.checkpoint.supporting_arrays:
# Backwards compatibility
mask = self.checkpoint.load_supporting_array("cutout_mask")
mask = context.checkpoint.load_supporting_array("cutout_mask")
points = slice(None, -np.sum(mask))
else:
if "lam_0" not in lam:
raise NotImplementedError("Only lam_0 is supported")

if "lam_1/cutout_mask" in self.checkpoint.supporting_arrays:
if "lam_1/cutout_mask" in context.checkpoint.supporting_arrays:
raise NotImplementedError("Only lam_0 is supported")

mask = self.checkpoint.load_supporting_array(f"{lam}/cutout_mask")
mask = context.checkpoint.load_supporting_array(f"{lam}/cutout_mask")

assert len(mask) == np.sum(mask)
points = slice(None, np.sum(mask))

self.points = points
self.output = create_output(context, output)

def __repr__(self) -> str:
"""Return a string representation of the ExtractLamOutput object."""
return f"ExtractLamOutput({self.points}, {self.output})"

def write_initial_state(self, state: State) -> None:
"""Write the initial step of the state.

Parameters
----------
state : State
The state dictionary.
"""
# Note: we foreward to 'state', so we write-up options again
self.output.write_initial_state(self._apply_mask(state))

def write_step(self, state: State) -> None:
"""Write a step of the state.

Parameters
----------
state : State
The state dictionary.
"""
# Note: we foreward to 'state', so we write-up options again
self.output.write_state(self._apply_mask(state))

def _apply_mask(self, state: State) -> State:
"""Apply the mask to the state.

Parameters
----------
state : State
The state dictionary.

Returns
-------
State
The masked state dictionary.
"""
state = state.copy()
state["fields"] = state["fields"].copy()
state["latitudes"] = state["latitudes"][self.points]
state["longitudes"] = state["longitudes"][self.points]

for field in state["fields"]:
data = state["fields"][field]
if data.ndim == 1:
data = data[self.points]
else:
data = data[..., self.points]
state["fields"][field] = data

return state

def close(self) -> None:
"""Close the output."""
self.output.close()

def print_summary(self, depth: int = 0) -> None:
"""Print the summary of the output.

Parameters
----------
depth : int, optional
The depth of the summary, by default 0.
"""
super().print_summary(depth)
self.output.print_summary(depth + 1)
super().__init__(
context,
mask=points,
output=output,
output_frequency=output_frequency,
write_initial_state=write_initial_state,
)
Loading
Loading