Skip to content

Commit 36a08aa

Browse files
authored
Merge pull request #3 from ecmwf/feat/subnetwork
Feat/subnetwork
2 parents dcca6e7 + bdafb9f commit 36a08aa

File tree

6 files changed

+397
-54
lines changed

6 files changed

+397
-54
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@ repos:
88
rev: v0.5.6
99
hooks:
1010
- id: ruff # fix linting violations
11+
types_or: [ python, pyi, jupyter ]
1112
args: [ --fix ]
1213
- id: ruff-format # fix formatting
14+
types_or: [ python, pyi, jupyter ]
1315
- repo: https://github.com/pre-commit/pre-commit-hooks
1416
rev: v4.4.0
1517
hooks:

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,11 @@ Finds the subcatchments (all upstream nodes of specified nodes, without overwrit
107107

108108
<img src="docs/images/subcatchment.gif" width="200px" height="160px" />
109109

110+
```
111+
network.create_subnetwork(mask)
112+
```
113+
Computes the river subnetwork defined by a mask of the domain.
114+
110115
```
111116
network.export(filename)
112117
```

docs/notebooks/example.ipynb

Lines changed: 237 additions & 28 deletions
Large diffs are not rendered by default.

src/earthkit/hydro/river_network.py

Lines changed: 113 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,9 @@ def check_missing(field, mv, accept_missing):
7272
return missing_values_present
7373

7474

75-
def mask_data(func):
75+
def mask_2d(func):
7676
"""
77-
Decorator to allow function to accept 2d inputs.
77+
Decorator to allow function to mask 2d inputs to the river network.
7878
7979
Parameters
8080
----------
@@ -107,6 +107,52 @@ def wrapper(self, field, *args, **kwargs):
107107
numpy.ndarray
108108
The processed field.
109109
"""
110+
if field.shape[-2:] == self.mask.shape:
111+
return func(self, field[..., self.mask].T, *args, **kwargs)
112+
else:
113+
return func(self, field.T, *args, **kwargs)
114+
115+
return wrapper
116+
117+
118+
def mask_and_unmask_data(func):
119+
"""
120+
Decorator to convert masked 2d inputs back to 1d.
121+
122+
Parameters
123+
----------
124+
func : callable
125+
The function to be wrapped and executed with masking applied.
126+
127+
Returns
128+
-------
129+
callable
130+
The wrapped function.
131+
"""
132+
133+
def wrapper(self, field, *args, **kwargs):
134+
"""
135+
Wrapper masking 2d data fields to allow for processing along the river network, then undoing the masking.
136+
137+
Parameters
138+
----------
139+
self : object
140+
The RiverNetwork instance calling the method.
141+
field : numpy.ndarray
142+
The input data field to be processed.
143+
*args : tuple
144+
Positional arguments passed to the wrapped function.
145+
**kwargs : dict
146+
Keyword arguments passed to the wrapped function.
147+
148+
Returns
149+
-------
150+
numpy.ndarray
151+
The processed field.
152+
"""
153+
# gets the missing value from the keyword arguments if it is present, otherwise takes default value of mv from func
154+
mv = kwargs.get("mv")
155+
mv = mv if mv is not None else func.__defaults__[0]
110156
if field.shape[-2:] == self.mask.shape:
111157
in_place = kwargs.get("in_place", False)
112158
if in_place:
@@ -115,10 +161,6 @@ def wrapper(self, field, *args, **kwargs):
115161
out_field = np.empty(field.shape, dtype=field.dtype)
116162
out_field[..., self.mask] = func(self, field[..., self.mask].T, *args, **kwargs).T
117163

118-
# gets the missing value from the keyword arguments if it is present, otherwise takes default value of mv from func
119-
mv = kwargs.get("mv")
120-
mv = mv if mv is not None else func.__defaults__[0]
121-
122164
out_field[..., ~self.mask] = mv
123165
return out_field
124166
else:
@@ -149,7 +191,7 @@ class RiverNetwork:
149191
Groups of nodes sorted in topological order.
150192
"""
151193

152-
def __init__(self, nodes, downstream, mask) -> None:
194+
def __init__(self, nodes, downstream, mask, sinks=None, sources=None, topological_labels=None) -> None:
153195
"""
154196
Initialises the RiverNetwork with nodes, downstream nodes, and a mask.
155197
@@ -166,11 +208,57 @@ def __init__(self, nodes, downstream, mask) -> None:
166208
self.n_nodes = len(nodes)
167209
self.downstream_nodes = downstream
168210
self.mask = mask
169-
self.sinks = self.nodes[self.downstream_nodes == self.n_nodes] # nodes with no downstreams
170-
print("finding sources")
171-
self.sources = self.get_sources() # nodes with no upstreams
172-
print("topological sorting")
173-
self.topological_groups = self.topological_sort()
211+
self.sinks = (
212+
sinks if sinks is not None else self.nodes[self.downstream_nodes == self.n_nodes]
213+
) # nodes with no downstreams
214+
self.sources = sources if sources is not None else self.get_sources() # nodes with no upstreams
215+
self.topological_labels = (
216+
topological_labels if topological_labels is not None else self.compute_topological_labels()
217+
)
218+
self.topological_groups = self.topological_groups_from_labels()
219+
220+
@mask_2d
221+
def create_subnetwork(self, field, recompute=False, *args, **kwargs):
222+
"""
223+
Creates a subnetwork from the river network based on a mask.
224+
225+
Parameters
226+
----------
227+
field : numpy.ndarray
228+
A boolean mask to subset the river network.
229+
recompute : bool, optional
230+
If True, recomputes the topological labels for the subnetwork (default is False).
231+
232+
Returns
233+
-------
234+
RiverNetwork
235+
A subnetwork of the river network.
236+
"""
237+
river_network_mask = field
238+
valid_indices = np.where(self.mask)
239+
new_valid_indices = (valid_indices[0][river_network_mask], valid_indices[1][river_network_mask])
240+
domain_mask = np.full(self.mask.shape, False)
241+
domain_mask[new_valid_indices] = True
242+
243+
downstream_indices = self.downstream_nodes[river_network_mask]
244+
n_nodes = len(downstream_indices) # number of nodes in the subnetwork
245+
# create new array of network nodes, setting all nodes not in mask to n_nodes
246+
subnetwork_nodes = np.full(self.n_nodes, n_nodes)
247+
subnetwork_nodes[river_network_mask] = np.arange(n_nodes)
248+
# get downstream nodes in the subnetwork
249+
non_sinks = np.where(downstream_indices != self.n_nodes)
250+
downstream = np.full(n_nodes, n_nodes)
251+
downstream[non_sinks] = subnetwork_nodes[downstream_indices[non_sinks]]
252+
nodes = np.arange(n_nodes)
253+
254+
if not recompute:
255+
sinks = nodes[downstream == n_nodes]
256+
topological_labels = self.topological_labels[river_network_mask]
257+
topological_labels[sinks] = self.n_nodes
258+
259+
return RiverNetwork(nodes, downstream, domain_mask, sinks=sinks, topological_labels=topological_labels)
260+
else:
261+
return RiverNetwork(nodes, downstream, domain_mask)
174262

175263
def get_sources(self):
176264
"""
@@ -187,14 +275,14 @@ def get_sources(self):
187275
inlets = tmp_nodes[tmp_nodes != -1] # sources are nodes that are not downstream nodes
188276
return inlets
189277

190-
def topological_sort(self):
278+
def compute_topological_labels(self):
191279
"""
192-
Performs a topological sorting of the nodes in the river network.
280+
Finds the topological distance labels for each node in the river network.
193281
194282
Returns
195283
-------
196-
list of numpy.ndarray
197-
A list of groups of nodes sorted in topological order.
284+
numpy.ndarray
285+
Array of topological distance labels for each node.
198286
"""
199287
inlets = self.sources
200288
labels = np.zeros(self.n_nodes, dtype=int)
@@ -209,10 +297,9 @@ def topological_sort(self):
209297
n += 1
210298
current_sum = np.sum(labels)
211299
labels[self.sinks] = n # put all sinks in last group in topological ordering
212-
groups = self.group_labels(labels)
213-
return groups
300+
return labels
214301

215-
def group_labels(self, labels):
302+
def topological_groups_from_labels(self):
216303
"""
217304
Groups nodes by their topological distance labels.
218305
@@ -226,14 +313,14 @@ def group_labels(self, labels):
226313
list of numpy.ndarray
227314
A list of subarrays, each containing nodes with the same label.
228315
"""
229-
sorted_indices = np.argsort(labels) # sort by labels
316+
sorted_indices = np.argsort(self.topological_labels) # sort by labels
230317
sorted_array = self.nodes[sorted_indices]
231-
sorted_labels = labels[sorted_indices]
318+
sorted_labels = self.topological_labels[sorted_indices]
232319
_, indices = np.unique(sorted_labels, return_index=True) # find index of first occurrence of each label
233320
subarrays = np.split(sorted_array, indices[1:]) # split array at each first occurrence of a label
234321
return subarrays
235322

236-
@mask_data
323+
@mask_and_unmask_data
237324
def accuflux(self, field, mv=np.nan, in_place=False, operation=np.add, accept_missing=False):
238325
"""
239326
Accumulate a field downstream along the river network.
@@ -274,7 +361,7 @@ def accuflux(self, field, mv=np.nan, in_place=False, operation=np.add, accept_mi
274361
field[nodes_to_update[missing_indices]] = mv
275362
return field
276363

277-
@mask_data
364+
@mask_and_unmask_data
278365
def upstream(self, field, mv=np.nan, operation=np.add, accept_missing=False):
279366
"""
280367
Sets each node to be the sum of its upstream nodes values, or a missing value.
@@ -307,7 +394,7 @@ def upstream(self, field, mv=np.nan, operation=np.add, accept_missing=False):
307394
ups[nodes_to_update[missing_indices]] = mv
308395
return ups
309396

310-
@mask_data
397+
@mask_and_unmask_data
311398
def downstream(self, field, mv=np.nan, accept_missing=False):
312399
"""
313400
Sets each node to be its downstream node value, or a missing value.
@@ -333,7 +420,7 @@ def downstream(self, field, mv=np.nan, accept_missing=False):
333420
down[mask] = field[self.downstream_nodes[mask]]
334421
return down
335422

336-
@mask_data
423+
@mask_and_unmask_data
337424
def catchment(self, field, mv=0, overwrite=True):
338425
"""
339426
Propagates a field upstream to find catchments.
@@ -361,7 +448,7 @@ def catchment(self, field, mv=0, overwrite=True):
361448
field[valid_group] = field[self.downstream_nodes[valid_group]]
362449
return field
363450

364-
@mask_data
451+
@mask_and_unmask_data
365452
def subcatchment(self, field, mv=0):
366453
"""
367454
Propagates a field upstream to find subcatchments.

tests/conftest.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,25 @@ def catchment_1():
244244
@fixture
245245
def catchment_2():
246246
return np.array([4, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 2, 4, 2, 2, 2])
247+
248+
249+
# @fixture
250+
# def mask_2():
251+
# return np.array(
252+
# [
253+
# [True, True, True, True],
254+
# [True, True, False, True,],
255+
# [True, True, True, True,],
256+
# [True, False, True, True,],
257+
# ]
258+
# )
259+
260+
261+
@fixture
262+
def mask_2():
263+
return np.array([True, True, True, True, True, True, False, True, True, True, True, True, True, False, True, True])
264+
265+
266+
@fixture
267+
def masked_unit_accuflux_2():
268+
return np.array([2, 1, 2, 1, 1, 2, 3, 1, 1, 3, 6, 1, 1, 2])

tests/test_river_network.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,21 @@ def test_catchment_2d(reader, map_name, query_field, catchment):
258258
print(network_catchment)
259259
np.testing.assert_array_equal(network_catchment[network.mask], catchment)
260260
np.testing.assert_array_equal(network_catchment[~network.mask], 0)
261+
262+
263+
@parametrize(
264+
"reader,map_name,mask,accuflux",
265+
[
266+
("d8_ldd", d8_ldd_2, mask_2, masked_unit_accuflux_2),
267+
("cama_downxy", cama_downxy_2, mask_2, masked_unit_accuflux_2),
268+
("cama_nextxy", cama_nextxy_2, mask_2, masked_unit_accuflux_2),
269+
],
270+
)
271+
def test_subnetwork(reader, map_name, mask, accuflux):
272+
network = read_network(reader, map_name)
273+
network = network.create_subnetwork(mask, on_domain=False)
274+
field = np.ones(network.n_nodes)
275+
accum = network.accuflux(field)
276+
print(accum)
277+
print(accuflux)
278+
np.testing.assert_array_equal(accum, accuflux)

0 commit comments

Comments
 (0)