Skip to content

Commit 79ba08c

Browse files
authored
Merge pull request #32 from ecmwf/feat/reorganise_functions
Feat/reorganise functions
2 parents 5af881b + 6478ed9 commit 79ba08c

File tree

10 files changed

+856
-426
lines changed

10 files changed

+856
-426
lines changed

README.md

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -65,65 +65,57 @@ Loads a precomputed `RiverNetwork`. Current options are
6565

6666

6767
```
68-
ekh.from_netcdf_d8(filename)
68+
ekh.create_river_network(path, river_network_format, source)
6969
```
70-
Creates a `RiverNetwork` from a D8 (PCRaster LDD convention) NetCDF format.
70+
Creates a `RiverNetwork`. Current options are
71+
- river_network_format: "esri_d8", "pcr_d8", "cama" or "precomputed"
72+
- source: An earthkit-data compatible source. See [list](https://earthkit-data.readthedocs.io/en/latest/guide/sources.html)
7173

72-
```
73-
ekh.from_netcdf_cama(filename, type)
74-
```
75-
Creates a `RiverNetwork` from a CaMa-Flood NetCDF format of type "downxy" or "nextxy".
76-
77-
```
78-
ekh.from_bin_cama(filename, type)
79-
```
80-
Creates a `RiverNetwork` from a CaMa-Flood bin format of type "downxy" or "nextxy".
81-
82-
### RiverNetwork methods
74+
### Methods
8375

8476
```
85-
network.accuflux(field)
77+
ekh.flow_downstream(river_network, field)
8678
```
8779
Calculates the total accumulated flux down a river network.\
8880
$$v_i^{\prime}=v_i+\sum_{j \rightarrow i}~v_j^{\prime}$$
8981

9082
<img src="docs/images/accuflux.gif" width="200px" height="160px" />
9183

9284
```
93-
network.upstream(field)
85+
ekh.move_downstream(river_network, field)
9486
```
9587
Updates each node with the sum of its upstream nodes.\
9688
$$v_i^{\prime}=\sum_{j \rightarrow i}~v_j$$
9789

9890
```
99-
network.downstream(field)
91+
ekh.move_upstream(river_network, field)
10092
```
10193
Updates each node with its downstream node.\
10294
$$v_i^{\prime} = v_j, ~j ~ \text{s.t.} ~ i \rightarrow j$$
10395

10496
```
105-
network.catchment(field)
97+
ekh.find_catchments(river_network, field)
10698
```
10799
Finds the catchments (all upstream nodes of specified nodes, with overwriting).\
108100
$$v_i^{\prime} = v_j^{\prime} ~ \text{if} ~ v_j^{\prime} \neq 0 ~ \text{else} ~ v_i, ~j ~ \text{s.t.} ~ i \rightarrow j$$
109101

110102
<img src="docs/images/catchment.gif" width="200px" height="160px" />
111103

112104
```
113-
network.subcatchment(field)
105+
ekh.find_subcatchments(river_network, field)
114106
```
115107
Finds the subcatchments (all upstream nodes of specified nodes, without overwriting).\
116108
$$v_i^{\prime} = v_j^{\prime} ~ \text{if} ~ (v_j^{\prime} \neq 0 ~ \text{and} ~ v_j = 0) ~ \text{else} ~ v_i, ~j ~ \text{s.t.} ~ i \rightarrow j$$
117109

118110
<img src="docs/images/subcatchment.gif" width="200px" height="160px" />
119111

120112
```
121-
network.create_subnetwork(mask)
113+
river_network.create_subnetwork(field)
122114
```
123-
Computes the river subnetwork defined by a mask of the domain.
115+
Computes the river subnetwork defined by a field mask of the domain.
124116

125117
```
126-
network.export(filename)
118+
river_network.export(filename)
127119
```
128120
Exports the `RiverNetwork` as a joblib pickle.
129121

src/earthkit/hydro/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
from .readers import from_cama_downxy, from_cama_nextxy, from_d8, load_river_network, create_river_network
22
from .river_network import RiverNetwork
3+
from .accumulation import flow_downstream
4+
from .movement import move_downstream, move_upstream
5+
from .catchment import find_catchments, find_subcatchments
6+
from .core import flow

src/earthkit/hydro/accumulation.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import numpy as np
2+
from .utils import mask_and_unmask_data, check_missing, is_missing
3+
from .core import flow
4+
5+
6+
@mask_and_unmask_data
def flow_downstream(river_network, field, mv=np.nan, in_place=False, ufunc=np.add, accept_missing=False):
    """
    Accumulate field values downstream along a river network.

    The network's topological groups are visited in their stored order
    (sources to sinks) and, for every group, ``ufunc`` combines each node's
    value into the value held by its downstream node.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The input field.
    mv : scalar, optional
        The missing value indicator. Default is np.nan.
    in_place : bool, optional
        If True, the input field is modified directly; otherwise a copy is
        made first. Default is False.
    ufunc : numpy.ufunc, optional
        The binary universal function used to accumulate values.
        Default is np.add.
    accept_missing : bool, optional
        Forwarded to ``check_missing`` to decide whether missing values in
        the field are tolerated. Default is False.

    Returns
    -------
    numpy.ndarray
        The field values accumulated downstream.
    """
    has_missing = check_missing(field, mv, accept_missing)

    # Only duplicate the data when the caller did not ask for in-place work.
    work = field if in_place else field.copy()

    # The plain kernel is enough when nothing is missing, or when the missing
    # value is NaN (presumably because NaN propagates through the ufunc by
    # itself -- TODO confirm for non-arithmetic ufuncs).
    if not has_missing or np.isnan(mv):
        step = _ufunc_to_downstream
    elif work.ndim == 1:
        step = _ufunc_to_downstream_missing_values_2D
    else:
        step = _ufunc_to_downstream_missing_values_ND

    def apply_step(network, values, grouping, missing_value):
        # Adapt the selected kernel to the four-argument callback used by flow().
        return step(network, values, grouping, missing_value, ufunc=ufunc)

    return flow(river_network, work, False, apply_step, mv)
48+
49+
50+
def _ufunc_to_downstream(river_network, field, grouping, mv, ufunc):
    """
    Combine the values of a group of nodes into their downstream nodes.

    The field is updated in place with an unbuffered ``ufunc.at`` call, so a
    downstream node that receives several contributions from the same group
    accumulates all of them. No special treatment is given to missing values.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The field to update in place.
    grouping : numpy.ndarray
        An array of node indices.
    mv : scalar
        A missing value indicator (unused here; kept so that all kernels
        share the same signature).
    ufunc : numpy.ufunc
        A binary universal function from numpy applied to the field data.
        Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html.

    Returns
    -------
    None
    """
    receivers = river_network.downstream_nodes[grouping]
    contributions = field[grouping]
    ufunc.at(field, receivers, contributions)
73+
74+
75+
def _ufunc_to_downstream_missing_values_2D(river_network, field, grouping, mv, ufunc):
    """
    Combine a group of nodes into their downstream nodes, propagating missing values.

    Works on a field array indexed by node along a single axis. The field is
    updated in place: ``ufunc`` is applied as in the plain kernel, and every
    downstream node that was itself missing, or that received a missing
    contribution, is then set to ``mv``.

    NOTE(review): "2D" in the name presumably refers to the original 2D grid
    the one-dimensional node array was extracted from -- confirm.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The field to update in place.
    grouping : numpy.ndarray
        An array of node indices.
    mv : scalar
        A missing value indicator.
    ufunc : numpy.ufunc
        A binary universal function from numpy applied to the field data.
        Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html.

    Returns
    -------
    None
    """
    receivers = river_network.downstream_nodes[grouping]
    contributions = field[grouping]

    # Missingness must be recorded before the ufunc overwrites the receivers.
    contribution_missing = is_missing(contributions, mv)
    receiver_missing = is_missing(field[receivers], mv)
    tainted = np.logical_or(contribution_missing, receiver_missing)

    ufunc.at(field, receivers, contributions)

    # Restore the missing marker wherever a missing value took part.
    field[receivers[tainted]] = mv
102+
103+
104+
def _ufunc_to_downstream_missing_values_ND(river_network, field, grouping, mv, ufunc):
    """
    Combine a group of nodes into their downstream nodes for multi-dimensional fields.

    The first axis of ``field`` is indexed by node; any further axes are
    carried along unchanged. The field is updated in place: ``ufunc`` is
    applied to the downstream nodes, and every element that was missing
    beforehand, or that received a missing contribution, is then set to ``mv``.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The field to update in place.
    grouping : numpy.ndarray
        An array of node indices.
    mv : scalar
        A missing value indicator.
    ufunc : numpy.ufunc
        A binary universal function from numpy applied to the field data.
        Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html.

    Returns
    -------
    None
    """
    receivers = river_network.downstream_nodes[grouping]
    contributions = field[grouping]

    # Missingness must be recorded before the ufunc overwrites the receivers.
    tainted = np.logical_or(is_missing(contributions, mv), is_missing(field[receivers], mv))

    ufunc.at(field, receivers, contributions)

    # The mask is relative to the group: translate its first axis from
    # "position within the group" to the receiving node index, keeping the
    # coordinates along the remaining axes as they are.
    group_position, *trailing_axes = np.where(tainted)
    field[(receivers[group_position], *trailing_axes)] = mv

src/earthkit/hydro/catchment.py

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
import numpy as np
2+
from .core import flow
3+
from .utils import mask_and_unmask_data, is_missing
4+
5+
6+
@mask_and_unmask_data
def find_catchments(river_network, field, mv=0, in_place=False):
    """
    Label the catchments given a field of labelled sinks.

    The network is traversed from sinks to sources and every node takes the
    label of its downstream node whenever that downstream node is labelled,
    replacing any label the node already had.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The input field of labels.
    mv : scalar, optional
        The missing value indicator, i.e. the value marking unlabelled nodes.
        Default is 0.
    in_place : bool, optional
        If True, the input field is modified directly; otherwise a copy is
        made first. Default is False.

    Returns
    -------
    numpy.ndarray
        The field with labels propagated upstream through the network.
    """
    labels = field if in_place else field.copy()

    # One-dimensional arrays use the simpler kernel; anything else goes
    # through the N-dimensional one.
    kernel = _find_catchments_2D if labels.ndim == 1 else _find_catchments_ND

    def propagate(network, values, grouping, missing_value):
        # Catchments overwrite labels that are already present.
        return kernel(network, values, grouping, missing_value, overwrite=True)

    return flow(river_network, labels, True, propagate, mv)
38+
39+
40+
@mask_and_unmask_data
def find_subcatchments(river_network, field, mv=0, in_place=False):
    """
    Label the subcatchments given a field of labelled sinks.

    The network is traversed from sinks to sources and every unlabelled node
    takes the label of its downstream node whenever that downstream node is
    labelled. Nodes that already carry a label keep it.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The input field of labels.
    mv : scalar, optional
        The missing value indicator, i.e. the value marking unlabelled nodes.
        Default is 0.
    in_place : bool, optional
        If True, the input field is modified directly; otherwise a copy is
        made first. Default is False.

    Returns
    -------
    numpy.ndarray
        The field with labels propagated upstream through the network.
    """
    labels = field if in_place else field.copy()

    # One-dimensional arrays use the simpler kernel; anything else goes
    # through the N-dimensional one.
    kernel = _find_catchments_2D if labels.ndim == 1 else _find_catchments_ND

    def propagate(network, values, grouping, missing_value):
        # Subcatchments never overwrite labels that are already present.
        return kernel(network, values, grouping, missing_value, overwrite=False)

    return flow(river_network, labels, True, propagate, mv)
72+
73+
74+
def _find_catchments_2D(river_network, field, grouping, mv, overwrite):
    """
    Copy downstream labels onto a group of nodes, for one-dimensional field arrays.

    The field is updated in place. A node is updated only when its downstream
    node holds a non-missing value; with ``overwrite=False`` the node must
    additionally be missing itself.

    NOTE(review): "2D" in the name presumably refers to the original 2D grid
    the one-dimensional node array was extracted from -- confirm.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The field to update in place.
    grouping : numpy.ndarray
        The array of node indices.
    mv : scalar
        The missing value indicator.
    overwrite : bool
        If True, overwrite existing non-missing values in the field array.

    Returns
    -------
    None
    """
    downstream_of_group = river_network.downstream_nodes[grouping]

    # Keep only the nodes whose downstream node already belongs to a catchment.
    downstream_labelled = ~is_missing(field[downstream_of_group], mv)
    targets = grouping[downstream_labelled]

    if not overwrite:
        # Leave nodes that already carry a label untouched.
        targets = targets[is_missing(field[targets], mv)]

    field[targets] = field[river_network.downstream_nodes[targets]]
101+
102+
103+
def _find_catchments_ND(river_network, field, grouping, mv, overwrite):
    """
    Copy downstream labels onto a group of nodes, for multi-dimensional fields.

    The first axis of ``field`` is indexed by node; any further axes are
    treated independently. The field is updated in place. An element is
    updated only when the matching element of its downstream node is
    non-missing; with ``overwrite=False`` the element must additionally be
    missing itself.

    Parameters
    ----------
    river_network : earthkit.hydro.RiverNetwork
        An earthkit-hydro river network object.
    field : numpy.ndarray
        The field to update in place.
    grouping : numpy.ndarray
        The array of node indices.
    mv : scalar
        The missing value indicator.
    overwrite : bool
        If True, overwrite existing non-missing values in the field array.

    Returns
    -------
    None
    """
    downstream_of_group = river_network.downstream_nodes[grouping]

    # Boolean mask of shape (len(grouping), ...): True where the downstream
    # element already belongs to a catchment.
    update_mask = ~is_missing(field[downstream_of_group], mv)

    if not overwrite:
        # Restrict to elements that are still unlabelled. Combining masks keeps
        # the coordinates along every axis, which re-deriving indices from a
        # flattened selection would lose.
        update_mask = np.logical_and(update_mask, is_missing(field[grouping], mv))

    # Coordinates of the selected elements: the first entry is the position
    # within the group, the remaining entries index the trailing axes.
    group_position, *trailing_axes = np.nonzero(update_mask)

    # Only the node axis is translated (to the node itself for the target, to
    # its downstream node for the source); trailing-axis coordinates are shared.
    targets = (grouping[group_position], *trailing_axes)
    sources = (downstream_of_group[group_position], *trailing_axes)
    field[targets] = field[sources]

src/earthkit/hydro/core.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
def flow(river_network, field, invert_graph, operation, mv):
    """
    Apply an operation to a field, one topological group at a time.

    Every topological group of the river network except the last one is
    visited, and ``operation`` is called once per group. The operation is
    expected to modify ``field`` in place; its return value is ignored.

    Parameters
    ----------
    river_network : RiverNetwork
        The river network object containing topological groups.
    field : ndarray
        The field data to be modified in place.
    invert_graph : bool
        If True, process the river network from sinks to sources.
        If False, process from sources to sinks.
    operation : callable
        The operation to apply to the field. It is called as
        ``operation(river_network, field, grouping, mv)`` and is responsible
        for handling missing values if they are allowed in the input.
    mv : any
        The value representing missing data in the field.

    Returns
    -------
    ndarray
        The same field object, after the operation has been applied along
        the river network.
    """
    # The final group is never visited (presumably the terminal sinks, which
    # have nothing downstream of them -- TODO confirm).
    ordered_groups = river_network.topological_groups[:-1]
    if invert_graph:
        # Walk from sinks to sources instead of sources to sinks.
        ordered_groups = ordered_groups[::-1]

    for node_group in ordered_groups:
        operation(river_network, field, node_group, mv)

    return field

0 commit comments

Comments
 (0)