Skip to content

Commit 3d6021a

Browse files
kolasaniv1996pre-commit-ci[bot]KumoLiu
authored
Fix: ZarrAvgMerger ValueError with zarr_format 3 (#8477)
This PR fixes the issue with ZarrAvgMerger when using zarr v3, where the compressor argument is not supported. Changes: - Added deprecated_arg decorator for compressor, value_compressor, and count_compressor (since 1.5.0, to be removed in 1.7.0) - Added codecs, value_codecs, and count_codecs support for zarr format 3 - Updated tests to handle both zarr v2 and zarr v3 compatibility - Fixed issue where tests would fail with zarr v3 due to compressor usage This addresses the issue reported in #8476 where using compressor with zarr v3 causes ValueError. Signed-off-by: kolasaniv1996 <[email protected]> --------- Signed-off-by: kolasaniv1996 <[email protected]> Signed-off-by: vivek kolasani <[email protected]> Signed-off-by: YunLiu <[email protected]> Co-authored-by: kolasaniv1996 <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: YunLiu <[email protected]>
1 parent d38c93f commit 3d6021a

File tree

2 files changed

+277
-49
lines changed

2 files changed

+277
-49
lines changed

monai/inferers/merger.py

Lines changed: 155 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,14 @@
2121
import numpy as np
2222
import torch
2323

24-
from monai.utils import ensure_tuple_size, get_package_version, optional_import, require_pkg, version_geq
24+
from monai.utils import (
25+
deprecated_arg,
26+
ensure_tuple_size,
27+
get_package_version,
28+
optional_import,
29+
require_pkg,
30+
version_geq,
31+
)
2532

2633
if TYPE_CHECKING:
2734
import zarr
@@ -218,15 +225,41 @@ class ZarrAvgMerger(Merger):
218225
store: the zarr store to save the final results. Default is "merged.zarr".
219226
value_store: the zarr store to save the value aggregating tensor. Default is a temporary store.
220227
count_store: the zarr store to save the sample counting tensor. Default is a temporary store.
221-
compressor: the compressor for final merged zarr array. Default is "default".
228+
compressor: the compressor for final merged zarr array. Default is None.
229+
Deprecated since 1.5.0 and will be removed in 1.7.0. Use codecs instead.
222230
value_compressor: the compressor for value aggregating zarr array. Default is None.
231+
Deprecated since 1.5.0 and will be removed in 1.7.0. Use value_codecs instead.
223232
count_compressor: the compressor for sample counting zarr array. Default is None.
233+
Deprecated since 1.5.0 and will be removed in 1.7.0. Use count_codecs instead.
234+
codecs: the codecs for final merged zarr array. Default is None.
235+
For zarr v3, this is a list of codec configurations. See zarr documentation for details.
236+
value_codecs: the codecs for value aggregating zarr array. Default is None.
237+
For zarr v3, this is a list of codec configurations. See zarr documentation for details.
238+
count_codecs: the codecs for sample counting zarr array. Default is None.
239+
For zarr v3, this is a list of codec configurations. See zarr documentation for details.
224240
chunks : int or tuple of ints that defines the chunk shape, or boolean. Default is True.
225241
If True, chunk shape will be guessed from `shape` and `dtype`.
226242
If False, it will be set to `shape`, i.e., single chunk for the whole array.
227243
If an int, the chunk size in each dimension will be given by the value of `chunks`.
228244
"""
229245

246+
@deprecated_arg(
247+
name="compressor", since="1.5.0", removed="1.7.0", new_name="codecs", msg_suffix="Please use 'codecs' instead."
248+
)
249+
@deprecated_arg(
250+
name="value_compressor",
251+
since="1.5.0",
252+
removed="1.7.0",
253+
new_name="value_codecs",
254+
msg_suffix="Please use 'value_codecs' instead.",
255+
)
256+
@deprecated_arg(
257+
name="count_compressor",
258+
since="1.5.0",
259+
removed="1.7.0",
260+
new_name="count_codecs",
261+
msg_suffix="Please use 'count_codecs' instead.",
262+
)
230263
def __init__(
231264
self,
232265
merged_shape: Sequence[int],
@@ -240,6 +273,9 @@ def __init__(
240273
compressor: str | None = None,
241274
value_compressor: str | None = None,
242275
count_compressor: str | None = None,
276+
codecs: list | None = None,
277+
value_codecs: list | None = None,
278+
count_codecs: list | None = None,
243279
chunks: Sequence[int] | bool = True,
244280
thread_locking: bool = True,
245281
) -> None:
@@ -251,7 +287,11 @@ def __init__(
251287
self.count_dtype = count_dtype
252288
self.store = store
253289
self.tmpdir: TemporaryDirectory | None
254-
if version_geq(get_package_version("zarr"), "3.0.0"):
290+
291+
# Handle zarr v3 vs older versions
292+
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
293+
294+
if is_zarr_v3:
255295
if value_store is None:
256296
self.tmpdir = TemporaryDirectory()
257297
self.value_store = zarr.storage.LocalStore(self.tmpdir.name) # type: ignore
@@ -266,34 +306,119 @@ def __init__(
266306
self.tmpdir = None
267307
self.value_store = zarr.storage.TempStore() if value_store is None else value_store # type: ignore
268308
self.count_store = zarr.storage.TempStore() if count_store is None else count_store # type: ignore
309+
269310
self.chunks = chunks
270-
self.compressor = compressor
271-
self.value_compressor = value_compressor
272-
self.count_compressor = count_compressor
273-
self.output = zarr.empty(
274-
shape=self.merged_shape,
275-
chunks=self.chunks,
276-
dtype=self.output_dtype,
277-
compressor=self.compressor,
278-
store=self.store,
279-
overwrite=True,
280-
)
281-
self.values = zarr.zeros(
282-
shape=self.merged_shape,
283-
chunks=self.chunks,
284-
dtype=self.value_dtype,
285-
compressor=self.value_compressor,
286-
store=self.value_store,
287-
overwrite=True,
288-
)
289-
self.counts = zarr.zeros(
290-
shape=self.merged_shape,
291-
chunks=self.chunks,
292-
dtype=self.count_dtype,
293-
compressor=self.count_compressor,
294-
store=self.count_store,
295-
overwrite=True,
296-
)
311+
312+
# Handle compressor/codecs based on zarr version
313+
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
314+
315+
# Initialize codecs/compressor attributes with proper types
316+
self.codecs: list | None = None
317+
self.value_codecs: list | None = None
318+
self.count_codecs: list | None = None
319+
320+
if is_zarr_v3:
321+
# For zarr v3, use codecs or convert compressor to codecs
322+
if codecs is not None:
323+
self.codecs = codecs
324+
elif compressor is not None:
325+
# Convert compressor to codec format
326+
if isinstance(compressor, (list, tuple)):
327+
self.codecs = compressor
328+
else:
329+
self.codecs = [compressor]
330+
else:
331+
self.codecs = None
332+
333+
if value_codecs is not None:
334+
self.value_codecs = value_codecs
335+
elif value_compressor is not None:
336+
if isinstance(value_compressor, (list, tuple)):
337+
self.value_codecs = value_compressor
338+
else:
339+
self.value_codecs = [value_compressor]
340+
else:
341+
self.value_codecs = None
342+
343+
if count_codecs is not None:
344+
self.count_codecs = count_codecs
345+
elif count_compressor is not None:
346+
if isinstance(count_compressor, (list, tuple)):
347+
self.count_codecs = count_compressor
348+
else:
349+
self.count_codecs = [count_compressor]
350+
else:
351+
self.count_codecs = None
352+
else:
353+
# For zarr v2, use compressors
354+
if codecs is not None:
355+
# If codecs are specified in v2, use the first codec as compressor
356+
self.codecs = codecs[0] if isinstance(codecs, (list, tuple)) else codecs
357+
else:
358+
self.codecs = compressor # type: ignore[assignment]
359+
360+
if value_codecs is not None:
361+
self.value_codecs = value_codecs[0] if isinstance(value_codecs, (list, tuple)) else value_codecs
362+
else:
363+
self.value_codecs = value_compressor # type: ignore[assignment]
364+
365+
if count_codecs is not None:
366+
self.count_codecs = count_codecs[0] if isinstance(count_codecs, (list, tuple)) else count_codecs
367+
else:
368+
self.count_codecs = count_compressor # type: ignore[assignment]
369+
370+
# Create zarr arrays with appropriate parameters based on version
371+
if is_zarr_v3:
372+
self.output = zarr.empty(
373+
shape=self.merged_shape,
374+
chunks=self.chunks,
375+
dtype=self.output_dtype,
376+
codecs=self.codecs,
377+
store=self.store,
378+
overwrite=True,
379+
)
380+
self.values = zarr.zeros(
381+
shape=self.merged_shape,
382+
chunks=self.chunks,
383+
dtype=self.value_dtype,
384+
codecs=self.value_codecs,
385+
store=self.value_store,
386+
overwrite=True,
387+
)
388+
self.counts = zarr.zeros(
389+
shape=self.merged_shape,
390+
chunks=self.chunks,
391+
dtype=self.count_dtype,
392+
codecs=self.count_codecs,
393+
store=self.count_store,
394+
overwrite=True,
395+
)
396+
else:
397+
self.output = zarr.empty(
398+
shape=self.merged_shape,
399+
chunks=self.chunks,
400+
dtype=self.output_dtype,
401+
compressor=self.codecs,
402+
store=self.store,
403+
overwrite=True,
404+
)
405+
self.values = zarr.zeros(
406+
shape=self.merged_shape,
407+
chunks=self.chunks,
408+
dtype=self.value_dtype,
409+
compressor=self.value_codecs,
410+
store=self.value_store,
411+
overwrite=True,
412+
)
413+
self.counts = zarr.zeros(
414+
shape=self.merged_shape,
415+
chunks=self.chunks,
416+
dtype=self.count_dtype,
417+
compressor=self.count_codecs,
418+
store=self.count_store,
419+
overwrite=True,
420+
)
421+
297422
self.lock: threading.Lock | nullcontext
298423
if thread_locking:
299424
# use lock to protect the in-place addition during aggregation

0 commit comments

Comments
 (0)