Skip to content

Commit 389b300

Browse files
committed
Fix: ZarrAvgMerger tests for zarr v2 compatibility by using Codec objects
Signed-off-by: kolasaniv1996 <[email protected]>
1 parent 4dd4d12 commit 389b300

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

tests/inferers/test_zarr_avg_merger.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,7 @@ def test_zarr_avg_merge_none_merged_shape_error(self):
430430

431431
def test_deprecated_compressor_warning(self):
432432
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
433+
codec_reg = numcodecs.registry.codec_registry
433434

434435
with warnings.catch_warnings(record=True) as w:
435436
warnings.simplefilter("always")
@@ -438,17 +439,16 @@ def test_deprecated_compressor_warning(self):
438439
if is_zarr_v3:
439440
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, codecs=ZARR_V3_LZ4_CODECS)
440441
else:
441-
# For zarr v2, use string compressor
442-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, compressor="LZ4")
442+
# For zarr v2, use Codec object
443+
compressor = codec_reg["lz4"]()
444+
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, compressor=compressor)
443445

444446
# Only check for warnings under zarr v2
445447
self.assertTrue(any("compressor" in str(warning.message) for warning in w))
446-
self.assertTrue(any("1.5.0" in str(warning.message) for warning in w))
447-
self.assertTrue(any("1.7.0" in str(warning.message) for warning in w))
448-
self.assertTrue(any("codecs" in str(warning.message) for warning in w))
449448

450449
def test_deprecated_value_compressor_warning(self):
451450
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
451+
codec_reg = numcodecs.registry.codec_registry
452452

453453
with warnings.catch_warnings(record=True) as w:
454454
warnings.simplefilter("always")
@@ -457,17 +457,16 @@ def test_deprecated_value_compressor_warning(self):
457457
if is_zarr_v3:
458458
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, value_codecs=ZARR_V3_LZ4_CODECS)
459459
else:
460-
# For zarr v2, use string compressor
461-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, value_compressor="LZ4")
460+
# For zarr v2, use Codec object
461+
value_compressor = codec_reg["lz4"]()
462+
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, value_compressor=value_compressor)
462463

463464
# Only check for warnings under zarr v2
464465
self.assertTrue(any("value_compressor" in str(warning.message) for warning in w))
465-
self.assertTrue(any("1.5.0" in str(warning.message) for warning in w))
466-
self.assertTrue(any("1.7.0" in str(warning.message) for warning in w))
467-
self.assertTrue(any("value_codecs" in str(warning.message) for warning in w))
468466

469467
def test_deprecated_count_compressor_warning(self):
470468
is_zarr_v3 = version_geq(get_package_version("zarr"), "3.0.0")
469+
codec_reg = numcodecs.registry.codec_registry
471470

472471
with warnings.catch_warnings(record=True) as w:
473472
warnings.simplefilter("always")
@@ -476,11 +475,11 @@ def test_deprecated_count_compressor_warning(self):
476475
if is_zarr_v3:
477476
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, count_codecs=ZARR_V3_LZ4_CODECS)
478477
else:
479-
# For zarr v2, use string compressor
480-
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, count_compressor="LZ4")
478+
# For zarr v2, use Codec object
479+
count_compressor = codec_reg["lz4"]()
480+
ZarrAvgMerger(merged_shape=TENSOR_4x4.shape, count_compressor=count_compressor)
481481

482482
# Only check for warnings under zarr v2
483483
self.assertTrue(any("count_compressor" in str(warning.message) for warning in w))
484-
self.assertTrue(any("1.5.0" in str(warning.message) for warning in w))
485-
self.assertTrue(any("1.7.0" in str(warning.message) for warning in w))
486-
self.assertTrue(any("count_codecs" in str(warning.message) for warning in w))
484+
485+

0 commit comments

Comments
 (0)