|
| 1 | +# (C) Copyright 2024 Anemoi contributors. |
| 2 | +# |
| 3 | +# This software is licensed under the terms of the Apache Licence Version 2.0 |
| 4 | +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. |
| 5 | +# |
| 6 | +# In applying this licence, ECMWF does not waive the privileges and immunities |
| 7 | +# granted to it by virtue of its status as an intergovernmental organisation |
| 8 | +# nor does it submit to any jurisdiction. |
| 9 | + |
| 10 | +import datetime |
| 11 | +import logging |
| 12 | +from typing import Any |
| 13 | +from typing import List |
| 14 | +from typing import Optional |
| 15 | + |
| 16 | +import earthkit.data as ekd |
| 17 | +import numpy as np |
| 18 | + |
| 19 | +from ..types import Date |
| 20 | +from ..types import State |
| 21 | +from . import input_registry |
| 22 | +from .grib import GribInput |
| 23 | + |
| 24 | +LOG = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +@input_registry.register("fdb") |
| 28 | +class FDBInput(GribInput): |
| 29 | + """Get input fields from FDB.""" |
| 30 | + |
| 31 | + trace_name = "fdb" |
| 32 | + |
| 33 | + def __init__( |
| 34 | + self, |
| 35 | + context, |
| 36 | + *, |
| 37 | + namer=None, |
| 38 | + fdb_config: dict | None = None, |
| 39 | + fdb_userconfig: dict | None = None, |
| 40 | + **kwargs: dict[str, Any], |
| 41 | + ): |
| 42 | + """Initialise the FDB input. |
| 43 | +
|
| 44 | + Parameters |
| 45 | + ---------- |
| 46 | + context : dict |
| 47 | + The context runner. |
| 48 | + namer : optional |
| 49 | + The namer to use for the input. |
| 50 | + fdb_config : dict, optional |
| 51 | + The FDB config to use. |
| 52 | + fdb_userconfig : dict, optional |
| 53 | + The FDB userconfig to use. |
| 54 | + kwargs : dict, optional |
| 55 | + Additional keyword arguments for the request to FDB. |
| 56 | + """ |
| 57 | + super().__init__(context, namer=namer) |
| 58 | + self.kwargs = kwargs |
| 59 | + self.configs = {"config": fdb_config, "userconfig": fdb_userconfig} |
| 60 | + # NOTE: this is a temporary workaround for #191 thus not documented |
| 61 | + self.param_id_map = kwargs.pop("param_id_map", {}) |
| 62 | + self.variables = self.checkpoint.variables_from_input(include_forcings=False) |
| 63 | + |
| 64 | + def create_input_state(self, *, date: Optional[Date]) -> State: |
| 65 | + date = np.datetime64(date).astype(datetime.datetime) |
| 66 | + dates = [date + h for h in self.checkpoint.lagged] |
| 67 | + ds = self.retrieve(variables=self.variables, dates=dates) |
| 68 | + res = self._create_input_state(ds, variables=None, date=date) |
| 69 | + return res |
| 70 | + |
| 71 | + def load_forcings_state(self, *, variables: List[str], dates: List[Date], current_state: State) -> State: |
| 72 | + ds = self.retrieve(variables=variables, dates=dates) |
| 73 | + return self._load_forcings_state(ds, variables=variables, dates=dates, current_state=current_state) |
| 74 | + |
| 75 | + def retrieve(self, variables: List[str], dates: List[Date]) -> Any: |
| 76 | + requests = self.checkpoint.mars_requests( |
| 77 | + variables=variables, |
| 78 | + dates=dates, |
| 79 | + use_grib_paramid=self.context.use_grib_paramid, |
| 80 | + patch_request=self.context.patch_data_request, |
| 81 | + ) |
| 82 | + requests = [self.kwargs | r for r in requests] |
| 83 | + # NOTE: this is a temporary workaround for #191 |
| 84 | + for request in requests: |
| 85 | + request["param"] = [self.param_id_map.get(p, p) for p in request["param"]] |
| 86 | + sources = [ekd.from_source("fdb", request, stream=False, **self.configs) for request in requests] |
| 87 | + ds = ekd.from_source("multi", sources) |
| 88 | + return ds |
0 commit comments