Skip to content

Commit 6478ed9

Browse files
committed
merge _accumulate.py to accumulate.py and _catchment.py to catchment.py
1 parent 9932ac1 commit 6478ed9

File tree

4 files changed

+149
-158
lines changed

4 files changed

+149
-158
lines changed

src/earthkit/hydro/_accumulation.py

Lines changed: 0 additions & 87 deletions
This file was deleted.

src/earthkit/hydro/_catchment.py

Lines changed: 0 additions & 63 deletions
This file was deleted.

src/earthkit/hydro/accumulation.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
import numpy as np
2-
from .utils import mask_and_unmask_data, check_missing
3-
from ._accumulation import (
4-
_ufunc_to_downstream,
5-
_ufunc_to_downstream_missing_values_2D,
6-
_ufunc_to_downstream_missing_values_ND,
7-
)
2+
from .utils import mask_and_unmask_data, check_missing, is_missing
83
from .core import flow
94

105

@@ -50,3 +45,88 @@ def operation(river_network, field, grouping, mv):
5045
return op(river_network, field, grouping, mv, ufunc=ufunc)
5146

5247
return flow(river_network, field, False, operation, mv)
48+
49+
50+
def _ufunc_to_downstream(river_network, field, grouping, mv, ufunc):
51+
"""
52+
Updates field in-place by applying a ufunc at the downstream nodes of the grouping, ignoring missing values.
53+
54+
Parameters
55+
----------
56+
river_network : earthkit.hydro.RiverNetwork
57+
An earthkit-hydro river network object.
58+
field : numpy.ndarray
59+
The input field.
60+
grouping : numpy.ndarray
61+
An array of indices.
62+
mv : scalar
63+
A missing value indicator (not used in the function but kept for consistency).
64+
ufunc : numpy.ufunc
65+
A universal function from the numpy library to be applied to the field data.
66+
Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html. Note: must allow two operands.
67+
68+
Returns
69+
-------
70+
None
71+
"""
72+
ufunc.at(field, river_network.downstream_nodes[grouping], field[grouping])
73+
74+
75+
def _ufunc_to_downstream_missing_values_2D(river_network, field, grouping, mv, ufunc):
76+
"""
77+
Applies a universal function (ufunc) to downstream nodes in a river network, dealing with missing values for 2D fields.
78+
79+
Parameters
80+
----------
81+
river_network : earthkit.hydro.RiverNetwork
82+
An earthkit-hydro river network object.
83+
field : numpy.ndarray
84+
The input field.
85+
grouping : numpy.ndarray
86+
An array of indices.
87+
mv : scalar
88+
A missing value indicator.
89+
ufunc : numpy.ufunc
90+
A universal function from the numpy library to be applied to the field data.
91+
Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html. Note: must allow two operands.
92+
93+
Returns
94+
-------
95+
None
96+
"""
97+
nodes_to_update = river_network.downstream_nodes[grouping]
98+
values_to_add = field[grouping]
99+
missing_indices = np.logical_or(is_missing(values_to_add, mv), is_missing(field[nodes_to_update], mv))
100+
ufunc.at(field, nodes_to_update, values_to_add)
101+
field[nodes_to_update[missing_indices]] = mv
102+
103+
104+
def _ufunc_to_downstream_missing_values_ND(river_network, field, grouping, mv, ufunc):
105+
"""
106+
Applies a universal function (ufunc) to downstream nodes in a river network, dealing with missing values for ND fields.
107+
108+
Parameters
109+
----------
110+
river_network : earthkit.hydro.RiverNetwork
111+
An earthkit-hydro river network object.
112+
field : numpy.ndarray
113+
The input field.
114+
grouping : numpy.ndarray
115+
An array of indices.
116+
mv : scalar
117+
A missing value indicator.
118+
ufunc : numpy.ufunc
119+
A universal function from the numpy library to be applied to the field data.
120+
Available ufuncs: https://numpy.org/doc/2.2/reference/ufuncs.html. Note: must allow two operands.
121+
122+
Returns
123+
-------
124+
None
125+
"""
126+
nodes_to_update = river_network.downstream_nodes[grouping]
127+
values_to_add = field[grouping]
128+
missing_indices = np.logical_or(is_missing(values_to_add, mv), is_missing(field[nodes_to_update], mv))
129+
ufunc.at(field, nodes_to_update, values_to_add)
130+
missing_indices = np.array(np.where(missing_indices))
131+
missing_indices[0] = nodes_to_update[missing_indices[0]]
132+
field[tuple(missing_indices)] = mv

src/earthkit/hydro/catchment.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import numpy as np
12
from .core import flow
2-
from .utils import mask_and_unmask_data
3-
from ._catchment import _find_catchments_2D, _find_catchments_ND
3+
from .utils import mask_and_unmask_data, is_missing
44

55

66
@mask_and_unmask_data
@@ -69,3 +69,64 @@ def operation(river_network, field, grouping, mv):
6969
return op(river_network, field, grouping, mv, overwrite=False)
7070

7171
return flow(river_network, field, True, operation, mv)
72+
73+
74+
def _find_catchments_2D(river_network, field, grouping, mv, overwrite):
75+
"""
76+
Updates field in-place with the value of its downstream nodes, dealing with missing values for 2D fields.
77+
78+
Parameters
79+
----------
80+
river_network : earthkit.hydro.RiverNetwork
81+
An earthkit-hydro river network object.
82+
field : numpy.ndarray
83+
The input field.
84+
grouping : numpy.ndarray
85+
The array of node indices.
86+
mv : scalar
87+
The missing value indicator.
88+
overwrite : bool
89+
If True, overwrite existing non-missing values in the field array.
90+
91+
Returns
92+
-------
93+
None
94+
"""
95+
valid_group = grouping[
96+
~is_missing(field[river_network.downstream_nodes[grouping]], mv)
97+
] # only update nodes where the downstream belongs to a catchment
98+
if not overwrite:
99+
valid_group = valid_group[is_missing(field[valid_group], mv)]
100+
field[valid_group] = field[river_network.downstream_nodes[valid_group]]
101+
102+
103+
def _find_catchments_ND(river_network, field, grouping, mv, overwrite):
104+
"""
105+
Updates field in-place with the value of its downstream nodes, dealing with missing values for ND fields.
106+
107+
Parameters
108+
----------
109+
river_network : earthkit.hydro.RiverNetwork
110+
An earthkit-hydro river network object.
111+
field : numpy.ndarray
112+
The input field.
113+
grouping : numpy.ndarray
114+
The array of node indices.
115+
mv : scalar
116+
The missing value indicator.
117+
overwrite : bool
118+
If True, overwrite existing non-missing values in the field array.
119+
120+
Returns
121+
-------
122+
None
123+
"""
124+
valid_mask = ~is_missing(field[river_network.downstream_nodes[grouping]], mv)
125+
valid_indices = np.array(np.where(valid_mask))
126+
valid_indices[0] = grouping[valid_indices[0]]
127+
if not overwrite:
128+
temp_valid_indices = valid_indices[0]
129+
valid_mask = is_missing(field[valid_indices], mv)
130+
valid_indices = np.array(np.where(valid_mask))
131+
valid_indices[0] = temp_valid_indices[valid_indices[0]]
132+
field[tuple(valid_indices)] = field[river_network.downstream_nodes[valid_indices]]

0 commit comments

Comments
 (0)