Skip to content

Commit ca6d37f

Browse files
authored
feat(input): add FDB input class (#190)
1 parent fbc3399 commit ca6d37f

File tree

1 file changed

+88
-0
lines changed
  • src/anemoi/inference/inputs

1 file changed

+88
-0
lines changed

src/anemoi/inference/inputs/fdb.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# (C) Copyright 2024 Anemoi contributors.
2+
#
3+
# This software is licensed under the terms of the Apache Licence Version 2.0
4+
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
5+
#
6+
# In applying this licence, ECMWF does not waive the privileges and immunities
7+
# granted to it by virtue of its status as an intergovernmental organisation
8+
# nor does it submit to any jurisdiction.
9+
10+
import datetime
11+
import logging
12+
from typing import Any
13+
from typing import List
14+
from typing import Optional
15+
16+
import earthkit.data as ekd
17+
import numpy as np
18+
19+
from ..types import Date
20+
from ..types import State
21+
from . import input_registry
22+
from .grib import GribInput
23+
24+
LOG = logging.getLogger(__name__)
25+
26+
27+
@input_registry.register("fdb")
28+
class FDBInput(GribInput):
29+
"""Get input fields from FDB."""
30+
31+
trace_name = "fdb"
32+
33+
def __init__(
34+
self,
35+
context,
36+
*,
37+
namer=None,
38+
fdb_config: dict | None = None,
39+
fdb_userconfig: dict | None = None,
40+
**kwargs: dict[str, Any],
41+
):
42+
"""Initialise the FDB input.
43+
44+
Parameters
45+
----------
46+
context : dict
47+
The context runner.
48+
namer : optional
49+
The namer to use for the input.
50+
fdb_config : dict, optional
51+
The FDB config to use.
52+
fdb_userconfig : dict, optional
53+
The FDB userconfig to use.
54+
kwargs : dict, optional
55+
Additional keyword arguments for the request to FDB.
56+
"""
57+
super().__init__(context, namer=namer)
58+
self.kwargs = kwargs
59+
self.configs = {"config": fdb_config, "userconfig": fdb_userconfig}
60+
# NOTE: this is a temporary workaround for #191 thus not documented
61+
self.param_id_map = kwargs.pop("param_id_map", {})
62+
self.variables = self.checkpoint.variables_from_input(include_forcings=False)
63+
64+
def create_input_state(self, *, date: Optional[Date]) -> State:
65+
date = np.datetime64(date).astype(datetime.datetime)
66+
dates = [date + h for h in self.checkpoint.lagged]
67+
ds = self.retrieve(variables=self.variables, dates=dates)
68+
res = self._create_input_state(ds, variables=None, date=date)
69+
return res
70+
71+
def load_forcings_state(self, *, variables: List[str], dates: List[Date], current_state: State) -> State:
72+
ds = self.retrieve(variables=variables, dates=dates)
73+
return self._load_forcings_state(ds, variables=variables, dates=dates, current_state=current_state)
74+
75+
def retrieve(self, variables: List[str], dates: List[Date]) -> Any:
76+
requests = self.checkpoint.mars_requests(
77+
variables=variables,
78+
dates=dates,
79+
use_grib_paramid=self.context.use_grib_paramid,
80+
patch_request=self.context.patch_data_request,
81+
)
82+
requests = [self.kwargs | r for r in requests]
83+
# NOTE: this is a temporary workaround for #191
84+
for request in requests:
85+
request["param"] = [self.param_id_map.get(p, p) for p in request["param"]]
86+
sources = [ekd.from_source("fdb", request, stream=False, **self.configs) for request in requests]
87+
ds = ekd.from_source("multi", sources)
88+
return ds

0 commit comments

Comments
 (0)