diff --git a/.gitignore b/.gitignore index 3679b631..ee459e85 100644 --- a/.gitignore +++ b/.gitignore @@ -120,7 +120,6 @@ tmp/ temp/ logs/ _dev/ -outputs *tmp_data/ # Project specific diff --git a/src/anemoi/inference/context.py b/src/anemoi/inference/context.py index 90c7d300..2acf31de 100644 --- a/src/anemoi/inference/context.py +++ b/src/anemoi/inference/context.py @@ -18,16 +18,16 @@ from typing import List from typing import Optional -from anemoi.inference.input import Input -from anemoi.inference.output import Output from anemoi.inference.processor import Processor from anemoi.inference.types import IntArray if TYPE_CHECKING: + from anemoi.inference.input import Input + from anemoi.inference.output import Output + from .checkpoint import Checkpoint from .forcings import Forcings - LOG = logging.getLogger(__name__) @@ -64,7 +64,7 @@ def checkpoint(self) -> "Checkpoint": # expected to provide the forcings directly as input to the runner. ################################################################## - def create_input(self) -> Input: + def create_input(self) -> "Input": """Creates an input object for the inference. Returns @@ -74,7 +74,7 @@ def create_input(self) -> Input: """ raise NotImplementedError() - def create_output(self) -> Output: + def create_output(self) -> "Output": """Creates an output object for the inference. Returns diff --git a/src/anemoi/inference/output.py b/src/anemoi/inference/output.py index cc108ce4..d01dbd6e 100644 --- a/src/anemoi/inference/output.py +++ b/src/anemoi/inference/output.py @@ -180,7 +180,11 @@ class ForwardOutput(Output): """ def __init__( - self, context: "Context", output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None + self, + context: "Context", + output: dict, + output_frequency: Optional[int] = None, + write_initial_state: Optional[bool] = None, ): """Initialize the ForwardOutput object. @@ -188,12 +192,20 @@ def __init__( ---------- context : Context The context in which the output operates. + output : dict + The output configuration dictionary. output_frequency : Optional[int], optional The frequency at which to output states, by default None. write_initial_state : Optional[bool], optional Whether to write the initial state, by default None. """ + + from anemoi.inference.outputs import create_output + super().__init__(context, output_frequency=None, write_initial_state=write_initial_state) + + self.output = None if output is None else create_output(context, output) + if self.context.output_frequency is not None: LOG.warning("output_frequency is ignored for '%s'", self.__class__.__name__) @@ -201,3 +213,65 @@ def __init__( def output_frequency(self) -> Optional[datetime.timedelta]: """Get the output frequency.""" return None + + def modify_state(self, state: State) -> State: + """Modify the state before writing. + + Parameters + ---------- + state : State + The state to modify. + + Returns + ------- + State + The modified state. + """ + return state + + def open(self, state) -> None: + """Open the output for writing. + Parameters + ---------- + state : State + The initial state. + """ + self.output.open(self.modify_state(state)) + + def close(self) -> None: + """Close the output.""" + + self.output.close() + + def write_initial_step(self, state: State) -> None: + """Write the initial step of the state. + + Parameters + ---------- + state : State + The state dictionary. + """ + state.setdefault("step", datetime.timedelta(0)) + + self.output.write_initial_state(self.modify_state(state)) + + def write_step(self, state: State) -> None: + """Write a step of the state. + + Parameters + ---------- + state : State + The state to write. + """ + self.output.write_state(self.modify_state(state)) + + def print_summary(self, depth: int = 0) -> None: + """Print a summary of the output. + + Parameters + ---------- + depth : int, optional + The depth of the summary, by default 0. + """ + super().print_summary(depth) + self.output.print_summary(depth + 1) diff --git a/src/anemoi/inference/outputs/apply_mask.py b/src/anemoi/inference/outputs/apply_mask.py index 7a20ebd2..833301b6 100644 --- a/src/anemoi/inference/outputs/apply_mask.py +++ b/src/anemoi/inference/outputs/apply_mask.py @@ -12,17 +12,15 @@ from anemoi.inference.config import Configuration from anemoi.inference.context import Context -from anemoi.inference.types import State -from ..output import ForwardOutput -from . import create_output from . import output_registry +from .masked import MaskedOutput LOG = logging.getLogger(__name__) @output_registry.register("apply_mask") -class ApplyMaskOutput(ForwardOutput): +class ApplyMaskOutput(MaskedOutput): """Apply mask output class. Parameters @@ -48,75 +46,10 @@ def __init__( output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None, ) -> None: - super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) - self.mask = self.checkpoint.load_supporting_array(mask) - self.output = create_output(context, output) - - def __repr__(self) -> str: - """Return a string representation of the ApplyMaskOutput object.""" - return f"ApplyMaskOutput({self.mask}, {self.output})" - - def write_initial_step(self, state: State) -> None: - """Write the initial step of the state. - - Parameters - ---------- - state : State - The state dictionary. - """ - # Note: we foreward to 'state', so we write-up options again - self.output.write_initial_state(self._apply_mask(state)) - - def write_step(self, state: State) -> None: - """Write a step of the state. - - Parameters - ---------- - state : State - The state dictionary. - """ - # Note: we foreward to 'state', so we write-up options again - self.output.write_state(self._apply_mask(state)) - - def _apply_mask(self, state: State) -> State: - """Apply the mask to the state. - - Parameters - ---------- - state : State - The state dictionary. - - Returns - ------- - State - The masked state dictionary. - """ - state = state.copy() - state["fields"] = state["fields"].copy() - state["latitudes"] = state["latitudes"][self.mask] - state["longitudes"] = state["longitudes"][self.mask] - - for field in state["fields"]: - data = state["fields"][field] - if data.ndim == 1: - data = data[self.mask] - else: - data = data[..., self.mask] - state["fields"][field] = data - - return state - - def close(self) -> None: - """Close the output.""" - self.output.close() - - def print_summary(self, depth: int = 0) -> None: - """Print the summary of the output. - - Parameters - ---------- - depth : int, optional - The depth of the summary, by default 0. - """ - super().print_summary(depth) - self.output.print_summary(depth + 1) + super().__init__( + context, + mask=self.checkpoint.load_supporting_array(mask), + output=output, + output_frequency=output_frequency, + write_initial_state=write_initial_state, + ) diff --git a/src/anemoi/inference/outputs/extract_lam.py b/src/anemoi/inference/outputs/extract_lam.py index 80f5bde3..08a5ec0f 100644 --- a/src/anemoi/inference/outputs/extract_lam.py +++ b/src/anemoi/inference/outputs/extract_lam.py @@ -14,17 +14,15 @@ from anemoi.inference.config import Configuration from anemoi.inference.context import Context -from anemoi.inference.types import State -from ..output import ForwardOutput -from . import create_output from . import output_registry +from .masked import MaskedOutput LOG = logging.getLogger(__name__) @output_registry.register("extract_lam") -class ExtractLamOutput(ForwardOutput): +class ExtractLamOutput(MaskedOutput): """Extract LAM output class. Parameters @@ -50,91 +48,27 @@ def __init__( output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None, ) -> None: - super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) - if "cutout_mask" in self.checkpoint.supporting_arrays: + if "cutout_mask" in context.checkpoint.supporting_arrays: # Backwards compatibility - mask = self.checkpoint.load_supporting_array("cutout_mask") + mask = context.checkpoint.load_supporting_array("cutout_mask") points = slice(None, -np.sum(mask)) else: if "lam_0" not in lam: raise NotImplementedError("Only lam_0 is supported") - if "lam_1/cutout_mask" in self.checkpoint.supporting_arrays: + if "lam_1/cutout_mask" in context.checkpoint.supporting_arrays: raise NotImplementedError("Only lam_0 is supported") - mask = self.checkpoint.load_supporting_array(f"{lam}/cutout_mask") + mask = context.checkpoint.load_supporting_array(f"{lam}/cutout_mask") + assert len(mask) == np.sum(mask) points = slice(None, np.sum(mask)) - self.points = points - self.output = create_output(context, output) - - def __repr__(self) -> str: - """Return a string representation of the ExtractLamOutput object.""" - return f"ExtractLamOutput({self.points}, {self.output})" - - def write_initial_state(self, state: State) -> None: - """Write the initial step of the state. - - Parameters - ---------- - state : State - The state dictionary. - """ - # Note: we foreward to 'state', so we write-up options again - self.output.write_initial_state(self._apply_mask(state)) - - def write_step(self, state: State) -> None: - """Write a step of the state. - - Parameters - ---------- - state : State - The state dictionary. - """ - # Note: we foreward to 'state', so we write-up options again - self.output.write_state(self._apply_mask(state)) - - def _apply_mask(self, state: State) -> State: - """Apply the mask to the state. - - Parameters - ---------- - state : State - The state dictionary. - - Returns - ------- - State - The masked state dictionary. - """ - state = state.copy() - state["fields"] = state["fields"].copy() - state["latitudes"] = state["latitudes"][self.points] - state["longitudes"] = state["longitudes"][self.points] - - for field in state["fields"]: - data = state["fields"][field] - if data.ndim == 1: - data = data[self.points] - else: - data = data[..., self.points] - state["fields"][field] = data - - return state - - def close(self) -> None: - """Close the output.""" - self.output.close() - - def print_summary(self, depth: int = 0) -> None: - """Print the summary of the output. - - Parameters - ---------- - depth : int, optional - The depth of the summary, by default 0. - """ - super().print_summary(depth) - self.output.print_summary(depth + 1) + super().__init__( + context, + mask=points, + output=output, + output_frequency=output_frequency, + write_initial_state=write_initial_state, + ) diff --git a/src/anemoi/inference/outputs/masked.py b/src/anemoi/inference/outputs/masked.py new file mode 100644 index 00000000..16ecf012 --- /dev/null +++ b/src/anemoi/inference/outputs/masked.py @@ -0,0 +1,82 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +from typing import Any +from typing import Optional + +from anemoi.inference.config import Configuration +from anemoi.inference.context import Context +from anemoi.inference.types import State + +from ..output import ForwardOutput + +LOG = logging.getLogger(__name__) + + +class MaskedOutput(ForwardOutput): + """Apply mask output class. + + Parameters + ---------- + context : dict + The context dictionary. + mask : Any + The mask. + output : dict + The output configuration dictionary. + output_frequency : int, optional + The frequency of output, by default None. + write_initial_state : bool, optional + Whether to write the initial state, by default None. + """ + + def __init__( + self, + context: Context, + *, + mask: Any, + output: Configuration, + output_frequency: Optional[int] = None, + write_initial_state: Optional[bool] = None, + ) -> None: + super().__init__(context, output, output_frequency=output_frequency, write_initial_state=write_initial_state) + self.mask = mask + + def modify_state(self, state: State) -> State: + """Apply the mask to the state. + + Parameters + ---------- + state : State + The state dictionary. + + Returns + ------- + State + The masked state dictionary. + """ + state = state.copy() + state["fields"] = state["fields"].copy() + state["latitudes"] = state["latitudes"][self.mask] + state["longitudes"] = state["longitudes"][self.mask] + + for field in state["fields"]: + data = state["fields"][field] + if data.ndim == 1: + data = data[self.mask] + else: + data = data[..., self.mask] + state["fields"][field] = data + + return state + + def __repr__(self) -> str: + """Return a string representation of the object.""" + return f"{self.__class__.__name__}({self.mask}, {self.output})" diff --git a/src/anemoi/inference/outputs/netcdf.py b/src/anemoi/inference/outputs/netcdf.py index b5579b3c..cb7349d2 100644 --- a/src/anemoi/inference/outputs/netcdf.py +++ b/src/anemoi/inference/outputs/netcdf.py @@ -31,19 +31,7 @@ @output_registry.register("netcdf") @main_argument("path") class NetCDFOutput(Output): - """NetCDF output class. - - Parameters - ---------- - context : dict - The context dictionary. - path : str - The path to save the NetCDF file. - output_frequency : int, optional - The frequency of output, by default None. - write_initial_state : bool, optional - Whether to write the initial state, by default None. - """ + """NetCDF output class.""" def __init__( self, @@ -51,14 +39,39 @@ def __init__( path: str, output_frequency: Optional[int] = None, write_initial_state: Optional[bool] = None, + float_size: str = "f4", + missing_value: Optional[float] = np.nan, ) -> None: + """Initialize the NetCDF output object. + + Parameters + ---------- + context : dict + The context dictionary. + path : str + The path to save the NetCDF file. + output_frequency : int, optional + The frequency of output, by default None. + write_initial_state : bool, optional + Whether to write the initial state, by default None. + float_size : str, optional + The size of the float, by default "f4". + missing_value : float, optional + The missing value, by default np.nan. + """ + super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) from netCDF4 import Dataset self.path = path self.ncfile: Optional[Dataset] = None - self.float_size = "f4" + self.float_size = float_size + self.missing_value = missing_value + if self.write_step_zero: + self.extra_time = 1 + else: + self.extra_time = 0 def __repr__(self) -> str: """Return a string representation of the NetCDFOutput object.""" @@ -95,6 +108,8 @@ def open(self, state: State) -> None: lead_time := getattr(self.context, "lead_time", None) ): time = lead_time // time_step + time += self.extra_time + if reference_date := getattr(self.context, "reference_date", None): self.reference_date = reference_date @@ -123,7 +138,6 @@ def open(self, state: State) -> None: self.longitude_var[:] = longitudes self.vars = {} - self.ensure_variables(state) self.n = 0 @@ -135,6 +149,7 @@ def ensure_variables(self, state: State) -> None: state : State The state dictionary. """ + values = len(state["latitudes"]) compression = {} # dict(zlib=False, complevel=0) @@ -148,14 +163,20 @@ def ensure_variables(self, state: State) -> None: chunksizes = tuple(int(np.ceil(x / 2)) for x in chunksizes) with LOCK: + missing_value = self.missing_value + self.vars[name] = self.ncfile.createVariable( name, self.float_size, ("time", "values"), chunksizes=chunksizes, + fill_value=missing_value, **compression, ) + self.vars[name].fill_value = missing_value + self.vars[name].missing_value = missing_value + def write_step(self, state: State) -> None: """Write the state. @@ -172,6 +193,7 @@ def write_step(self, state: State) -> None: for name, value in state["fields"].items(): with LOCK: + LOG.info(f"🚧🚧🚧🚧🚧🚧 XXXXXX {name}, {self.n}, {value.shape}") self.vars[name][self.n] = value self.n += 1 diff --git a/src/anemoi/inference/outputs/tee.py b/src/anemoi/inference/outputs/tee.py index 543ba73f..ef6ae2a5 100644 --- a/src/anemoi/inference/outputs/tee.py +++ b/src/anemoi/inference/outputs/tee.py @@ -54,7 +54,7 @@ def __init__( **kwargs : Any Additional keyword arguments. """ - super().__init__(context, output_frequency=output_frequency, write_initial_state=write_initial_state) + super().__init__(context, None, output_frequency=output_frequency, write_initial_state=write_initial_state) if outputs is None: outputs = args diff --git a/src/anemoi/inference/outputs/truth.py b/src/anemoi/inference/outputs/truth.py index 56f3ac11..96d0426f 100644 --- a/src/anemoi/inference/outputs/truth.py +++ b/src/anemoi/inference/outputs/truth.py @@ -15,8 +15,6 @@ from ..context import Context from ..output import ForwardOutput -from ..output import Output -from . import create_output from . import output_registry LOG = logging.getLogger(__name__) @@ -42,19 +40,8 @@ def __init__(self, context: Context, output: Configuration, **kwargs: Any) -> No kwargs : dict Additional keyword arguments. """ - super().__init__(context, **kwargs) + super().__init__(context, output, **kwargs) self._input = self.context.create_input() - self.output: Output = create_output(context, output) - - def write_initial_state(self, state: State) -> None: - """Write the initial state. - - Parameters - ---------- - state : State - The initial state to write. - """ - self.output.write_initial_state(state) def write_step(self, state: State) -> None: """Write a step of the state. @@ -68,20 +55,6 @@ def write_step(self, state: State) -> None: reduced_state = self.reduce(truth_state) self.output.write_state(reduced_state) - def open(self, state: State) -> None: - """Open the output for writing. - - Parameters - ---------- - state : State - The state to open. - """ - self.output.open(state) - - def close(self) -> None: - """Close the output.""" - self.output.close() - def __repr__(self) -> str: """Return a string representation of the TruthOutput. @@ -91,14 +64,3 @@ def __repr__(self) -> str: String representation of the TruthOutput. """ return f"TruthOutput({self.output})" - - def print_summary(self, depth: int = 0) -> None: - """Print a summary of the output. - - Parameters - ---------- - depth : int, optional - The depth of the summary, by default 0. - """ - super().print_summary(depth) - self.output.print_summary(depth + 1) diff --git a/src/anemoi/inference/runners/default.py b/src/anemoi/inference/runners/default.py index d3280931..2be2d7e1 100644 --- a/src/anemoi/inference/runners/default.py +++ b/src/anemoi/inference/runners/default.py @@ -144,8 +144,7 @@ def create_output(self) -> Output: The created output. """ output = create_output(self, self.config.output) - LOG.info("Output:") - output.print_summary() + LOG.info("Output: %s", output) return output def create_constant_computed_forcings(self, variables: List[str], mask: IntArray) -> List[Forcings]: