Skip to content

Commit e8b712c

Browse files
committed
edge weights to flow_downstream and flow_upstream
1 parent 58dbe02 commit e8b712c

File tree

6 files changed

+104
-99
lines changed

6 files changed

+104
-99
lines changed

src/earthkit/hydro/accumulation.py

Lines changed: 81 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ def flow_downstream(
2222
ufunc=np.add,
2323
accept_missing=False,
2424
skip_missing_check=False,
25-
additive_weight=None,
26-
multiplicative_weight=None,
27-
modifier_use_upstream=True,
25+
node_additive_weight=None,
26+
node_multiplicative_weight=None,
27+
node_modifier_use_upstream=True,
28+
edge_additive_weight=None,
29+
edge_multiplicative_weight=None,
2830
):
2931
"""Accumulates field values downstream.
3032
@@ -70,22 +72,24 @@ def flow_downstream(
7072
def operation(
7173
river_network,
7274
field,
73-
up_ids,
74-
down_ids,
75+
grouping,
7576
mv,
76-
additive_weight,
77-
multiplicative_weight,
78-
modifier_use_upstream,
77+
node_additive_weight,
78+
node_multiplicative_weight,
79+
node_modifier_use_upstream,
80+
edge_additive_weight,
81+
edge_multiplicative_weight,
7982
):
8083
return op(
8184
river_network,
8285
field,
83-
up_ids,
84-
down_ids,
86+
grouping,
8587
mv,
86-
additive_weight,
87-
multiplicative_weight,
88-
modifier_use_upstream,
88+
node_additive_weight,
89+
node_multiplicative_weight,
90+
node_modifier_use_upstream,
91+
edge_additive_weight,
92+
edge_multiplicative_weight,
8993
ufunc=ufunc,
9094
)
9195

@@ -95,9 +99,11 @@ def operation(
9599
False,
96100
operation,
97101
mv,
98-
additive_weight,
99-
multiplicative_weight,
100-
modifier_use_upstream,
102+
node_additive_weight,
103+
node_multiplicative_weight,
104+
node_modifier_use_upstream,
105+
edge_additive_weight,
106+
edge_multiplicative_weight,
101107
)
102108

103109
return nan_to_missing(field, field_dtype, mv)
@@ -106,12 +112,13 @@ def operation(
106112
def _ufunc_to_downstream(
107113
river_network,
108114
field,
109-
up_ids,
110-
down_ids,
115+
grouping,
111116
mv,
112-
additive_weight,
113-
multiplicative_weight,
114-
modifier_use_upstream,
117+
node_additive_weight,
118+
node_multiplicative_weight,
119+
node_modifier_use_upstream,
120+
edge_additive_weight,
121+
edge_multiplicative_weight,
115122
ufunc,
116123
):
117124
"""Updates field in-place by applying a ufunc at the downstream nodes of
@@ -144,20 +151,19 @@ def _ufunc_to_downstream(
144151
None
145152
146153
"""
147-
modifier_group = up_ids if modifier_use_upstream else down_ids
148-
if additive_weight is None:
149-
if multiplicative_weight is None:
150-
modifier_field = field[..., up_ids]
151-
else:
152-
modifier_field = field[..., up_ids] * multiplicative_weight[modifier_group]
153-
else:
154-
if multiplicative_weight is None:
155-
modifier_field = field[..., up_ids] + additive_weight[modifier_group]
156-
else:
157-
modifier_field = (
158-
field[..., up_ids] * multiplicative_weight[modifier_group]
159-
+ additive_weight[modifier_group]
160-
)
154+
up_ids, down_ids = river_network.get_up_down(grouping)
155+
modifier_group = up_ids if node_modifier_use_upstream else down_ids
156+
157+
modifier_field = field[..., up_ids]
158+
if node_multiplicative_weight is not None:
159+
modifier_field *= node_multiplicative_weight[modifier_group]
160+
if edge_multiplicative_weight is not None:
161+
modifier_field *= edge_multiplicative_weight[grouping]
162+
if node_additive_weight is not None:
163+
modifier_field += node_additive_weight[modifier_group]
164+
if edge_additive_weight is not None:
165+
modifier_field += edge_additive_weight[grouping]
166+
161167
ufunc.at(
162168
field,
163169
(*[slice(None)] * (field.ndim - 1), down_ids),
@@ -174,9 +180,11 @@ def flow_upstream(
174180
ufunc=np.add,
175181
accept_missing=False,
176182
skip_missing_check=False,
177-
additive_weight=None,
178-
multiplicative_weight=None,
179-
modifier_use_upstream=True,
183+
node_additive_weight=None,
184+
node_multiplicative_weight=None,
185+
node_modifier_use_upstream=True,
186+
edge_additive_weight=None,
187+
edge_multiplicative_weight=None,
180188
):
181189
"""Accumulates field values upstream.
182190
@@ -221,22 +229,24 @@ def flow_upstream(
221229
def operation(
222230
river_network,
223231
field,
224-
up_ids,
225-
down_ids,
232+
grouping,
226233
mv,
227-
additive_weight,
228-
multiplicative_weight,
229-
modifier_use_upstream,
234+
node_additive_weight,
235+
node_multiplicative_weight,
236+
node_modifier_use_upstream,
237+
edge_additive_weight,
238+
edge_multiplicative_weight,
230239
):
231240
return op(
232241
river_network,
233242
field,
234-
up_ids,
235-
down_ids,
243+
grouping,
236244
mv,
237-
additive_weight,
238-
multiplicative_weight,
239-
modifier_use_upstream,
245+
node_additive_weight,
246+
node_multiplicative_weight,
247+
node_modifier_use_upstream,
248+
edge_additive_weight,
249+
edge_multiplicative_weight,
240250
ufunc=ufunc,
241251
)
242252

@@ -246,9 +256,11 @@ def operation(
246256
True,
247257
operation,
248258
mv,
249-
additive_weight,
250-
multiplicative_weight,
251-
modifier_use_upstream,
259+
node_additive_weight,
260+
node_multiplicative_weight,
261+
node_modifier_use_upstream,
262+
edge_additive_weight,
263+
edge_multiplicative_weight,
252264
)
253265

254266
return nan_to_missing(field, field_dtype, mv)
@@ -257,12 +269,13 @@ def operation(
257269
def _ufunc_to_upstream(
258270
river_network,
259271
field,
260-
up_ids,
261-
down_ids,
272+
grouping,
262273
mv,
263-
additive_weight,
264-
multiplicative_weight,
265-
modifier_use_upstream,
274+
node_additive_weight,
275+
node_multiplicative_weight,
276+
node_modifier_use_upstream,
277+
edge_additive_weight,
278+
edge_multiplicative_weight,
266279
ufunc,
267280
):
268281
"""Updates field in-place by applying a ufunc at the nodes of
@@ -295,25 +308,18 @@ def _ufunc_to_upstream(
295308
None
296309
297310
"""
298-
down_group = down_ids
299-
modifier_group = up_ids if modifier_use_upstream else down_ids
300-
if additive_weight is None:
301-
if multiplicative_weight is None:
302-
modifier_field = field[..., down_group]
303-
else:
304-
modifier_field = (
305-
field[..., down_group] * multiplicative_weight[..., modifier_group]
306-
)
307-
else:
308-
if multiplicative_weight is None:
309-
modifier_field = (
310-
field[..., down_group] + additive_weight[..., modifier_group]
311-
)
312-
else:
313-
modifier_field = (
314-
field[..., down_group] * multiplicative_weight[..., modifier_group]
315-
+ additive_weight[..., modifier_group]
316-
)
311+
up_ids, down_ids = river_network.get_up_down(grouping)
312+
modifier_group = up_ids if node_modifier_use_upstream else down_ids
313+
314+
modifier_field = field[..., down_ids]
315+
if node_multiplicative_weight is not None:
316+
modifier_field *= node_multiplicative_weight[modifier_group]
317+
if edge_multiplicative_weight is not None:
318+
modifier_field *= edge_multiplicative_weight[grouping]
319+
if node_additive_weight is not None:
320+
modifier_field += node_additive_weight[modifier_group]
321+
if edge_additive_weight is not None:
322+
modifier_field += edge_additive_weight[grouping]
317323

318324
ufunc.at(
319325
field,

src/earthkit/hydro/catchments.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def find(river_network, field, mv=0, in_place=False):
129129
else:
130130
op = _find_catchments_ND
131131

132-
def operation(river_network, field, up_ids, down_ids, mv):
133-
return op(river_network, field, up_ids, down_ids, mv, overwrite=True)
132+
def operation(river_network, field, grouping, mv):
133+
return op(river_network, field, grouping, mv, overwrite=True)
134134

135135
return flow(river_network, field, True, operation, mv)
136136

@@ -142,7 +142,7 @@ def operation(river_network, field, up_ids, down_ids, mv):
142142
globals()[metric] = func
143143

144144

145-
def _find_catchments_2D(river_network, field, grouping, down_ids, mv, overwrite):
145+
def _find_catchments_2D(river_network, field, grouping, mv, overwrite):
146146
"""Updates field in-place with the value of its downstream nodes, dealing
147147
with missing values for 2D fields.
148148
@@ -172,7 +172,7 @@ def _find_catchments_2D(river_network, field, grouping, down_ids, mv, overwrite)
172172
field[..., valid_group] = field[..., river_network.downstream_nodes[valid_group]]
173173

174174

175-
def _find_catchments_ND(river_network, field, grouping, down_ids, mv, overwrite):
175+
def _find_catchments_ND(river_network, field, grouping, mv, overwrite):
176176
"""Updates field in-place with the value of its downstream nodes, dealing
177177
with missing values for ND fields.
178178

src/earthkit/hydro/core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ def flow(river_network, field, invert_graph, operation, *args, **kwargs):
4141
groupings = river_network.topological_groups_edges
4242

4343
for grouping in groupings:
44-
up_ids, down_ids = river_network.get_up_down(grouping)
4544
# modify field in_place with desired operation
4645
# NB: this function needs to handle missing values
4746
# mv if they are allowed in input
48-
operation(river_network, field, up_ids, down_ids, *args, **kwargs)
47+
operation(river_network, field, grouping, *args, **kwargs)
4948

5049
return field

src/earthkit/hydro/distance.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ def min(
6666
field,
6767
mv,
6868
ufunc=np.minimum,
69-
additive_weight=weights,
70-
modifier_use_upstream=True,
69+
node_additive_weight=weights,
70+
node_modifier_use_upstream=True,
7171
)
7272
if upstream:
7373
field = flow_upstream(
7474
river_network,
7575
field,
7676
mv,
7777
ufunc=np.minimum,
78-
additive_weight=weights,
79-
modifier_use_upstream=True,
78+
node_additive_weight=weights,
79+
node_modifier_use_upstream=True,
8080
)
8181

8282
out_field = np.empty(river_network.shape, dtype=field.dtype)
@@ -152,17 +152,17 @@ def max(
152152
field,
153153
mv,
154154
ufunc=np.maximum,
155-
additive_weight=weights,
156-
modifier_use_upstream=True,
155+
node_additive_weight=weights,
156+
node_modifier_use_upstream=True,
157157
)
158158
if upstream:
159159
field = flow_upstream(
160160
river_network,
161161
field,
162162
mv,
163163
ufunc=np.maximum,
164-
additive_weight=weights,
165-
modifier_use_upstream=True,
164+
node_additive_weight=weights,
165+
node_modifier_use_upstream=True,
166166
)
167167

168168
field = np.nan_to_num(field, neginf=np.inf)

src/earthkit/hydro/length.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,17 @@ def min(
6767
field,
6868
mv,
6969
ufunc=np.minimum,
70-
additive_weight=weights,
71-
modifier_use_upstream=False,
70+
node_additive_weight=weights,
71+
node_modifier_use_upstream=False,
7272
)
7373
if upstream:
7474
field = flow_upstream(
7575
river_network,
7676
field,
7777
mv,
7878
ufunc=np.minimum,
79-
additive_weight=weights,
80-
modifier_use_upstream=True,
79+
node_additive_weight=weights,
80+
node_modifier_use_upstream=True,
8181
)
8282

8383
out_field = np.empty(river_network.shape, dtype=field.dtype)
@@ -153,17 +153,17 @@ def max(
153153
field,
154154
mv,
155155
ufunc=np.maximum,
156-
additive_weight=weights,
157-
modifier_use_upstream=False,
156+
node_additive_weight=weights,
157+
node_modifier_use_upstream=False,
158158
)
159159
if upstream:
160160
field = flow_upstream(
161161
river_network,
162162
field,
163163
mv,
164164
ufunc=np.maximum,
165-
additive_weight=weights,
166-
modifier_use_upstream=True,
165+
node_additive_weight=weights,
166+
node_modifier_use_upstream=True,
167167
)
168168

169169
field = np.nan_to_num(field, neginf=np.inf)

src/earthkit/hydro/subcatchments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ def find(river_network, field, mv=0, in_place=False):
117117
else:
118118
op = _find_catchments_ND
119119

120-
def operation(river_network, field, grouping, down_ids, mv):
121-
return op(river_network, field, grouping, down_ids, mv, overwrite=False)
120+
def operation(river_network, field, grouping, mv):
121+
return op(river_network, field, grouping, mv, overwrite=False)
122122

123123
return flow(river_network, field, True, operation, mv)
124124

0 commit comments

Comments
 (0)