21
21
import numpy as np
22
22
import torch
23
23
24
- from monai .utils import ensure_tuple_size , get_package_version , optional_import , require_pkg , version_geq
24
+ from monai .utils import (
25
+ deprecated_arg ,
26
+ ensure_tuple_size ,
27
+ get_package_version ,
28
+ optional_import ,
29
+ require_pkg ,
30
+ version_geq ,
31
+ )
25
32
26
33
if TYPE_CHECKING :
27
34
import zarr
@@ -218,15 +225,41 @@ class ZarrAvgMerger(Merger):
218
225
store: the zarr store to save the final results. Default is "merged.zarr".
219
226
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
220
227
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
221
- compressor: the compressor for final merged zarr array. Default is "default".
228
+ compressor: the compressor for final merged zarr array. Default is None.
229
+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use codecs instead.
222
230
value_compressor: the compressor for value aggregating zarr array. Default is None.
231
+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use value_codecs instead.
223
232
count_compressor: the compressor for sample counting zarr array. Default is None.
233
+ Deprecated since 1.5.0 and will be removed in 1.7.0. Use count_codecs instead.
234
+ codecs: the codecs for final merged zarr array. Default is None.
235
+ For zarr v3, this is a list of codec configurations. With zarr v2, only the first entry of a list/tuple is used, as the compressor. See zarr documentation for details.
236
+ value_codecs: the codecs for value aggregating zarr array. Default is None.
237
+ For zarr v3, this is a list of codec configurations. With zarr v2, only the first entry of a list/tuple is used, as the compressor. See zarr documentation for details.
238
+ count_codecs: the codecs for sample counting zarr array. Default is None.
239
+ For zarr v3, this is a list of codec configurations. With zarr v2, only the first entry of a list/tuple is used, as the compressor. See zarr documentation for details.
224
240
chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
225
241
If True, chunk shape will be guessed from `shape` and `dtype`.
226
242
If False, it will be set to `shape`, i.e., single chunk for the whole array.
227
243
If an int, the chunk size in each dimension will be given by the value of `chunks`.
228
244
"""
229
245
246
+ @deprecated_arg (
247
+ name = "compressor" , since = "1.5.0" , removed = "1.7.0" , new_name = "codecs" , msg_suffix = "Please use 'codecs' instead."
248
+ )
249
+ @deprecated_arg (
250
+ name = "value_compressor" ,
251
+ since = "1.5.0" ,
252
+ removed = "1.7.0" ,
253
+ new_name = "value_codecs" ,
254
+ msg_suffix = "Please use 'value_codecs' instead." ,
255
+ )
256
+ @deprecated_arg (
257
+ name = "count_compressor" ,
258
+ since = "1.5.0" ,
259
+ removed = "1.7.0" ,
260
+ new_name = "count_codecs" ,
261
+ msg_suffix = "Please use 'count_codecs' instead." ,
262
+ )
230
263
def __init__ (
231
264
self ,
232
265
merged_shape : Sequence [int ],
@@ -240,6 +273,9 @@ def __init__(
240
273
compressor : str | None = None ,
241
274
value_compressor : str | None = None ,
242
275
count_compressor : str | None = None ,
276
+ codecs : list | None = None ,
277
+ value_codecs : list | None = None ,
278
+ count_codecs : list | None = None ,
243
279
chunks : Sequence [int ] | bool = True ,
244
280
thread_locking : bool = True ,
245
281
) -> None :
@@ -251,7 +287,11 @@ def __init__(
251
287
self .count_dtype = count_dtype
252
288
self .store = store
253
289
self .tmpdir : TemporaryDirectory | None
254
- if version_geq (get_package_version ("zarr" ), "3.0.0" ):
290
+
291
+ # Handle zarr v3 vs older versions
292
+ is_zarr_v3 = version_geq (get_package_version ("zarr" ), "3.0.0" )
293
+
294
+ if is_zarr_v3 :
255
295
if value_store is None :
256
296
self .tmpdir = TemporaryDirectory ()
257
297
self .value_store = zarr .storage .LocalStore (self .tmpdir .name ) # type: ignore
@@ -266,34 +306,119 @@ def __init__(
266
306
self .tmpdir = None
267
307
self .value_store = zarr .storage .TempStore () if value_store is None else value_store # type: ignore
268
308
self .count_store = zarr .storage .TempStore () if count_store is None else count_store # type: ignore
309
+
269
310
self .chunks = chunks
270
- self .compressor = compressor
271
- self .value_compressor = value_compressor
272
- self .count_compressor = count_compressor
273
- self .output = zarr .empty (
274
- shape = self .merged_shape ,
275
- chunks = self .chunks ,
276
- dtype = self .output_dtype ,
277
- compressor = self .compressor ,
278
- store = self .store ,
279
- overwrite = True ,
280
- )
281
- self .values = zarr .zeros (
282
- shape = self .merged_shape ,
283
- chunks = self .chunks ,
284
- dtype = self .value_dtype ,
285
- compressor = self .value_compressor ,
286
- store = self .value_store ,
287
- overwrite = True ,
288
- )
289
- self .counts = zarr .zeros (
290
- shape = self .merged_shape ,
291
- chunks = self .chunks ,
292
- dtype = self .count_dtype ,
293
- compressor = self .count_compressor ,
294
- store = self .count_store ,
295
- overwrite = True ,
296
- )
311
+
312
+ # Handle compressor/codecs based on zarr version
313
+ is_zarr_v3 = version_geq (get_package_version ("zarr" ), "3.0.0" )
314
+
315
+ # Initialize codecs/compressor attributes with proper types
316
+ self .codecs : list | None = None
317
+ self .value_codecs : list | None = None
318
+ self .count_codecs : list | None = None
319
+
320
+ if is_zarr_v3 :
321
+ # For zarr v3, use codecs or convert compressor to codecs
322
+ if codecs is not None :
323
+ self .codecs = codecs
324
+ elif compressor is not None :
325
+ # Convert compressor to codec format
326
+ if isinstance (compressor , (list , tuple )):
327
+ self .codecs = compressor
328
+ else :
329
+ self .codecs = [compressor ]
330
+ else :
331
+ self .codecs = None
332
+
333
+ if value_codecs is not None :
334
+ self .value_codecs = value_codecs
335
+ elif value_compressor is not None :
336
+ if isinstance (value_compressor , (list , tuple )):
337
+ self .value_codecs = value_compressor
338
+ else :
339
+ self .value_codecs = [value_compressor ]
340
+ else :
341
+ self .value_codecs = None
342
+
343
+ if count_codecs is not None :
344
+ self .count_codecs = count_codecs
345
+ elif count_compressor is not None :
346
+ if isinstance (count_compressor , (list , tuple )):
347
+ self .count_codecs = count_compressor
348
+ else :
349
+ self .count_codecs = [count_compressor ]
350
+ else :
351
+ self .count_codecs = None
352
+ else :
353
+ # For zarr v2, use compressors
354
+ if codecs is not None :
355
+ # If codecs are specified in v2, use the first codec as compressor
356
+ self .codecs = codecs [0 ] if isinstance (codecs , (list , tuple )) else codecs
357
+ else :
358
+ self .codecs = compressor # type: ignore[assignment]
359
+
360
+ if value_codecs is not None :
361
+ self .value_codecs = value_codecs [0 ] if isinstance (value_codecs , (list , tuple )) else value_codecs
362
+ else :
363
+ self .value_codecs = value_compressor # type: ignore[assignment]
364
+
365
+ if count_codecs is not None :
366
+ self .count_codecs = count_codecs [0 ] if isinstance (count_codecs , (list , tuple )) else count_codecs
367
+ else :
368
+ self .count_codecs = count_compressor # type: ignore[assignment]
369
+
370
+ # Create zarr arrays with appropriate parameters based on version
371
+ if is_zarr_v3 :
372
+ self .output = zarr .empty (
373
+ shape = self .merged_shape ,
374
+ chunks = self .chunks ,
375
+ dtype = self .output_dtype ,
376
+ codecs = self .codecs ,
377
+ store = self .store ,
378
+ overwrite = True ,
379
+ )
380
+ self .values = zarr .zeros (
381
+ shape = self .merged_shape ,
382
+ chunks = self .chunks ,
383
+ dtype = self .value_dtype ,
384
+ codecs = self .value_codecs ,
385
+ store = self .value_store ,
386
+ overwrite = True ,
387
+ )
388
+ self .counts = zarr .zeros (
389
+ shape = self .merged_shape ,
390
+ chunks = self .chunks ,
391
+ dtype = self .count_dtype ,
392
+ codecs = self .count_codecs ,
393
+ store = self .count_store ,
394
+ overwrite = True ,
395
+ )
396
+ else :
397
+ self .output = zarr .empty (
398
+ shape = self .merged_shape ,
399
+ chunks = self .chunks ,
400
+ dtype = self .output_dtype ,
401
+ compressor = self .codecs ,
402
+ store = self .store ,
403
+ overwrite = True ,
404
+ )
405
+ self .values = zarr .zeros (
406
+ shape = self .merged_shape ,
407
+ chunks = self .chunks ,
408
+ dtype = self .value_dtype ,
409
+ compressor = self .value_codecs ,
410
+ store = self .value_store ,
411
+ overwrite = True ,
412
+ )
413
+ self .counts = zarr .zeros (
414
+ shape = self .merged_shape ,
415
+ chunks = self .chunks ,
416
+ dtype = self .count_dtype ,
417
+ compressor = self .count_codecs ,
418
+ store = self .count_store ,
419
+ overwrite = True ,
420
+ )
421
+
297
422
self .lock : threading .Lock | nullcontext
298
423
if thread_locking :
299
424
# use lock to protect the in-place addition during aggregation
0 commit comments