@@ -72,9 +72,9 @@ def check_missing(field, mv, accept_missing):
72
72
return missing_values_present
73
73
74
74
75
- def mask_data (func ):
75
+ def mask_2d (func ):
76
76
"""
77
- Decorator to allow function to accept 2d inputs.
77
+ Decorator to allow function to mask 2d inputs to the river network .
78
78
79
79
Parameters
80
80
----------
@@ -107,6 +107,52 @@ def wrapper(self, field, *args, **kwargs):
107
107
numpy.ndarray
108
108
The processed field.
109
109
"""
110
+ if field .shape [- 2 :] == self .mask .shape :
111
+ return func (self , field [..., self .mask ].T , * args , ** kwargs )
112
+ else :
113
+ return func (self , field .T , * args , ** kwargs )
114
+
115
+ return wrapper
116
+
117
+
118
+ def mask_and_unmask_data (func ):
119
+ """
120
+ Decorator to convert masked 2d inputs back to 1d.
121
+
122
+ Parameters
123
+ ----------
124
+ func : callable
125
+ The function to be wrapped and executed with masking applied.
126
+
127
+ Returns
128
+ -------
129
+ callable
130
+ The wrapped function.
131
+ """
132
+
133
+ def wrapper (self , field , * args , ** kwargs ):
134
+ """
135
+ Wrapper masking 2d data fields to allow for processing along the river network, then undoing the masking.
136
+
137
+ Parameters
138
+ ----------
139
+ self : object
140
+ The RiverNetwork instance calling the method.
141
+ field : numpy.ndarray
142
+ The input data field to be processed.
143
+ *args : tuple
144
+ Positional arguments passed to the wrapped function.
145
+ **kwargs : dict
146
+ Keyword arguments passed to the wrapped function.
147
+
148
+ Returns
149
+ -------
150
+ numpy.ndarray
151
+ The processed field.
152
+ """
153
+ # gets the missing value from the keyword arguments if it is present, otherwise takes default value of mv from func
154
+ mv = kwargs .get ("mv" )
155
+ mv = mv if mv is not None else func .__defaults__ [0 ]
110
156
if field .shape [- 2 :] == self .mask .shape :
111
157
in_place = kwargs .get ("in_place" , False )
112
158
if in_place :
@@ -115,10 +161,6 @@ def wrapper(self, field, *args, **kwargs):
115
161
out_field = np .empty (field .shape , dtype = field .dtype )
116
162
out_field [..., self .mask ] = func (self , field [..., self .mask ].T , * args , ** kwargs ).T
117
163
118
- # gets the missing value from the keyword arguments if it is present, otherwise takes default value of mv from func
119
- mv = kwargs .get ("mv" )
120
- mv = mv if mv is not None else func .__defaults__ [0 ]
121
-
122
164
out_field [..., ~ self .mask ] = mv
123
165
return out_field
124
166
else :
@@ -149,7 +191,7 @@ class RiverNetwork:
149
191
Groups of nodes sorted in topological order.
150
192
"""
151
193
152
- def __init__ (self , nodes , downstream , mask ) -> None :
194
+ def __init__ (self , nodes , downstream , mask , sinks = None , sources = None , topological_labels = None ) -> None :
153
195
"""
154
196
Initialises the RiverNetwork with nodes, downstream nodes, and a mask.
155
197
@@ -166,11 +208,57 @@ def __init__(self, nodes, downstream, mask) -> None:
166
208
self .n_nodes = len (nodes )
167
209
self .downstream_nodes = downstream
168
210
self .mask = mask
169
- self .sinks = self .nodes [self .downstream_nodes == self .n_nodes ] # nodes with no downstreams
170
- print ("finding sources" )
171
- self .sources = self .get_sources () # nodes with no upstreams
172
- print ("topological sorting" )
173
- self .topological_groups = self .topological_sort ()
211
+ self .sinks = (
212
+ sinks if sinks is not None else self .nodes [self .downstream_nodes == self .n_nodes ]
213
+ ) # nodes with no downstreams
214
+ self .sources = sources if sources is not None else self .get_sources () # nodes with no upstreams
215
+ self .topological_labels = (
216
+ topological_labels if topological_labels is not None else self .compute_topological_labels ()
217
+ )
218
+ self .topological_groups = self .topological_groups_from_labels ()
219
+
220
+ @mask_2d
221
+ def create_subnetwork (self , field , recompute = False , * args , ** kwargs ):
222
+ """
223
+ Creates a subnetwork from the river network based on a mask.
224
+
225
+ Parameters
226
+ ----------
227
+ field : numpy.ndarray
228
+ A boolean mask to subset the river network.
229
+ recompute : bool, optional
230
+ If True, recomputes the topological labels for the subnetwork (default is False).
231
+
232
+ Returns
233
+ -------
234
+ RiverNetwork
235
+ A subnetwork of the river network.
236
+ """
237
+ river_network_mask = field
238
+ valid_indices = np .where (self .mask )
239
+ new_valid_indices = (valid_indices [0 ][river_network_mask ], valid_indices [1 ][river_network_mask ])
240
+ domain_mask = np .full (self .mask .shape , False )
241
+ domain_mask [new_valid_indices ] = True
242
+
243
+ downstream_indices = self .downstream_nodes [river_network_mask ]
244
+ n_nodes = len (downstream_indices ) # number of nodes in the subnetwork
245
+ # create new array of network nodes, setting all nodes not in mask to n_nodes
246
+ subnetwork_nodes = np .full (self .n_nodes , n_nodes )
247
+ subnetwork_nodes [river_network_mask ] = np .arange (n_nodes )
248
+ # get downstream nodes in the subnetwork
249
+ non_sinks = np .where (downstream_indices != self .n_nodes )
250
+ downstream = np .full (n_nodes , n_nodes )
251
+ downstream [non_sinks ] = subnetwork_nodes [downstream_indices [non_sinks ]]
252
+ nodes = np .arange (n_nodes )
253
+
254
+ if not recompute :
255
+ sinks = nodes [downstream == n_nodes ]
256
+ topological_labels = self .topological_labels [river_network_mask ]
257
+ topological_labels [sinks ] = self .n_nodes
258
+
259
+ return RiverNetwork (nodes , downstream , domain_mask , sinks = sinks , topological_labels = topological_labels )
260
+ else :
261
+ return RiverNetwork (nodes , downstream , domain_mask )
174
262
175
263
def get_sources (self ):
176
264
"""
@@ -187,14 +275,14 @@ def get_sources(self):
187
275
inlets = tmp_nodes [tmp_nodes != - 1 ] # sources are nodes that are not downstream nodes
188
276
return inlets
189
277
190
- def topological_sort (self ):
278
+ def compute_topological_labels (self ):
191
279
"""
192
- Performs a topological sorting of the nodes in the river network.
280
+ Finds the topological distance labels for each node in the river network.
193
281
194
282
Returns
195
283
-------
196
- list of numpy.ndarray
197
- A list of groups of nodes sorted in topological order .
284
+ numpy.ndarray
285
+ Array of topological distance labels for each node .
198
286
"""
199
287
inlets = self .sources
200
288
labels = np .zeros (self .n_nodes , dtype = int )
@@ -209,10 +297,9 @@ def topological_sort(self):
209
297
n += 1
210
298
current_sum = np .sum (labels )
211
299
labels [self .sinks ] = n # put all sinks in last group in topological ordering
212
- groups = self .group_labels (labels )
213
- return groups
300
+ return labels
214
301
215
- def group_labels (self , labels ):
302
+ def topological_groups_from_labels (self ):
216
303
"""
217
304
Groups nodes by their topological distance labels.
218
305
@@ -226,14 +313,14 @@ def group_labels(self, labels):
226
313
list of numpy.ndarray
227
314
A list of subarrays, each containing nodes with the same label.
228
315
"""
229
- sorted_indices = np .argsort (labels ) # sort by labels
316
+ sorted_indices = np .argsort (self . topological_labels ) # sort by labels
230
317
sorted_array = self .nodes [sorted_indices ]
231
- sorted_labels = labels [sorted_indices ]
318
+ sorted_labels = self . topological_labels [sorted_indices ]
232
319
_ , indices = np .unique (sorted_labels , return_index = True ) # find index of first occurrence of each label
233
320
subarrays = np .split (sorted_array , indices [1 :]) # split array at each first occurrence of a label
234
321
return subarrays
235
322
236
- @mask_data
323
+ @mask_and_unmask_data
237
324
def accuflux (self , field , mv = np .nan , in_place = False , operation = np .add , accept_missing = False ):
238
325
"""
239
326
Accumulate a field downstream along the river network.
@@ -274,7 +361,7 @@ def accuflux(self, field, mv=np.nan, in_place=False, operation=np.add, accept_mi
274
361
field [nodes_to_update [missing_indices ]] = mv
275
362
return field
276
363
277
- @mask_data
364
+ @mask_and_unmask_data
278
365
def upstream (self , field , mv = np .nan , operation = np .add , accept_missing = False ):
279
366
"""
280
367
Sets each node to be the sum of its upstream nodes values, or a missing value.
@@ -307,7 +394,7 @@ def upstream(self, field, mv=np.nan, operation=np.add, accept_missing=False):
307
394
ups [nodes_to_update [missing_indices ]] = mv
308
395
return ups
309
396
310
- @mask_data
397
+ @mask_and_unmask_data
311
398
def downstream (self , field , mv = np .nan , accept_missing = False ):
312
399
"""
313
400
Sets each node to be its downstream node value, or a missing value.
@@ -333,7 +420,7 @@ def downstream(self, field, mv=np.nan, accept_missing=False):
333
420
down [mask ] = field [self .downstream_nodes [mask ]]
334
421
return down
335
422
336
- @mask_data
423
+ @mask_and_unmask_data
337
424
def catchment (self , field , mv = 0 , overwrite = True ):
338
425
"""
339
426
Propagates a field upstream to find catchments.
@@ -361,7 +448,7 @@ def catchment(self, field, mv=0, overwrite=True):
361
448
field [valid_group ] = field [self .downstream_nodes [valid_group ]]
362
449
return field
363
450
364
- @mask_data
451
+ @mask_and_unmask_data
365
452
def subcatchment (self , field , mv = 0 ):
366
453
"""
367
454
Propagates a field upstream to find subcatchments.
0 commit comments