diff --git a/pyproject.toml b/pyproject.toml index 873415a24..77a5a7447 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dependencies = [ "pyyaml", "semantic-version", "tqdm", - "zarr<=2.18.4", + "zarr", ] optional-dependencies.all = [ @@ -94,6 +94,7 @@ optional-dependencies.docs = [ optional-dependencies.remote = [ "boto3", + "obstore", "requests", ] diff --git a/src/anemoi/datasets/commands/copy.py b/src/anemoi/datasets/commands/copy.py index ea08c8aab..da2076ffa 100644 --- a/src/anemoi/datasets/commands/copy.py +++ b/src/anemoi/datasets/commands/copy.py @@ -21,6 +21,7 @@ from anemoi.utils.remote import TransferMethodNotImplementedError from anemoi.datasets.check import check_zarr +from anemoi.datasets.zarr_versions import zarr_2_or_3 from . import Command @@ -51,8 +52,6 @@ class ZarrCopier: Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. - nested : bool - Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. """ @@ -66,7 +65,6 @@ def __init__( overwrite: bool, resume: bool, verbosity: int, - nested: bool, rechunk: str, **kwargs: Any, ) -> None: @@ -88,8 +86,6 @@ def __init__( Flag to resume copying an existing dataset. verbosity : int Verbosity level of logging. - nested : bool - Flag to use ZARR's nested directory backend. rechunk : str Rechunk size for the target data array. **kwargs : Any @@ -102,7 +98,6 @@ def __init__( self.overwrite = overwrite self.resume = resume self.verbosity = verbosity - self.nested = nested self.rechunk = rechunk self.rechunking = rechunk.split(",") if rechunk else [] @@ -115,27 +110,6 @@ def __init__( raise NotImplementedError("Rechunking with SSH not implemented.") assert NotImplementedError("SSH not implemented.") - def _store(self, path: str, nested: bool = False) -> Any: - """Get the storage path. - - Parameters - ---------- - path : str - Path to the storage. - nested : bool, optional - Flag to use nested directory storage. - - Returns - ------- - Any - Storage path. - """ - if nested: - import zarr - - return zarr.storage.NestedDirectoryStore(path) - return path - def copy_chunk(self, n: int, m: int, source: Any, target: Any, _copy: Any, verbosity: int) -> Optional[slice]: """Copy a chunk of data from source to target. @@ -239,7 +213,8 @@ def copy_data(self, source: Any, target: Any, _copy: Any, verbosity: int) -> Non target_data = ( target["data"] if "data" in target - else target.create_dataset( + else zarr_2_or_3.create_array( + target, "data", shape=source_data.shape, chunks=self.data_chunks, @@ -319,7 +294,6 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No verbosity : int Verbosity level of logging. 
""" - import zarr if self.verbosity > 0: LOG.info(f"Copying group {source} to {target}") @@ -345,7 +319,7 @@ def copy_group(self, source: Any, target: Any, _copy: Any, verbosity: int) -> No LOG.info(f"Skipping {name}") continue - if isinstance(source[name], zarr.hierarchy.Group): + if zarr_2_or_3.is_zarr_group(source[name]): group = target[name] if name in target else target.create_group(name) self.copy_group( source[name], @@ -403,13 +377,13 @@ def run(self) -> None: def target_exists() -> bool: try: - zarr.open(self._store(self.target), mode="r") + zarr.open(self.target, mode="r") return True except ValueError: return False def target_finished() -> bool: - target = zarr.open(self._store(self.target), mode="r") + target = zarr.open(self.target, mode="r") if "_copy" in target: done = sum(1 if x else 0 for x in target["_copy"]) todo = len(target["_copy"]) @@ -427,11 +401,11 @@ def target_finished() -> bool: def open_target() -> Any: if not target_exists(): - return zarr.open(self._store(self.target, self.nested), mode="w") + return zarr.open(self.target, mode="w") if self.overwrite: LOG.error("Target already exists, overwriting.") - return zarr.open(self._store(self.target, self.nested), mode="w") + return zarr.open(self.target, mode="w") if self.resume: if target_finished(): @@ -439,7 +413,7 @@ def open_target() -> Any: sys.exit(0) LOG.error("Target already exists, resuming copy.") - return zarr.open(self._store(self.target, self.nested), mode="w+") + return zarr.open(self.target, mode=zarr_2_or_3.open_mode_append) LOG.error("Target already exists, use either --overwrite or --resume.") sys.exit(1) @@ -495,7 +469,6 @@ def add_arguments(self, command_parser: Any) -> None: help="Verbosity level. 0 is silent, 1 is normal, 2 is verbose.", default=1, ) - command_parser.add_argument("--nested", action="store_true", help="Use ZARR's nested directpry backend.") command_parser.add_argument( "--rechunk", help="Rechunk the target data array. Rechunk size should be a diviser of the block size." ) diff --git a/src/anemoi/datasets/commands/create.py b/src/anemoi/datasets/commands/create.py index 86332cfcc..219a977cb 100644 --- a/src/anemoi/datasets/commands/create.py +++ b/src/anemoi/datasets/commands/create.py @@ -86,6 +86,7 @@ def add_arguments(self, command_parser: Any) -> None: group.add_argument("--threads", help="Use `n` parallel thread workers.", type=int, default=0) group.add_argument("--processes", help="Use `n` parallel process workers.", type=int, default=0) command_parser.add_argument("--trace", action="store_true") + command_parser.add_argument("--force-zarr3", action="store_true") def run(self, args: Any) -> None: """Execute the create command. diff --git a/src/anemoi/datasets/commands/init.py b/src/anemoi/datasets/commands/init.py index 0ca540b86..73152facb 100644 --- a/src/anemoi/datasets/commands/init.py +++ b/src/anemoi/datasets/commands/init.py @@ -63,6 +63,7 @@ def add_arguments(self, subparser: Any) -> None: subparser.add_argument("--cache", help="Location to store the downloaded data.", metavar="DIR") subparser.add_argument("--trace", action="store_true") + subparser.add_argument("--force-zarr3", action="store_true", help="Force the use of Zarr v3 format.") def run(self, args: Any) -> None: """Execute the command with the provided arguments. 
diff --git a/src/anemoi/datasets/commands/inspect.py b/src/anemoi/datasets/commands/inspect.py index 400cdcf98..112eb5407 100644 --- a/src/anemoi/datasets/commands/inspect.py +++ b/src/anemoi/datasets/commands/inspect.py @@ -655,7 +655,7 @@ def ready(self) -> bool: if "_build_flags" not in self.zarr: return False - build_flags = self.zarr["_build_flags"] + build_flags = self.zarr["_build_flags"][:] return all(build_flags) @property @@ -711,7 +711,7 @@ def build_flags(self) -> Optional[NDArray]: if "_build" not in self.zarr: return None build = self.zarr["_build"] - return build.get("flags") + return build.get("flags")[:] @property def build_lengths(self) -> Optional[NDArray]: @@ -719,7 +719,7 @@ def build_lengths(self) -> Optional[NDArray]: if "_build" not in self.zarr: return None build = self.zarr["_build"] - return build.get("lengths") + return build.get("lengths")[:] VERSIONS = { diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index 8d74975c5..5a2f7564f 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -38,6 +38,7 @@ from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date from anemoi.datasets.dates.groups import Groups +from anemoi.datasets.zarr_versions import zarr_2_or_3 from .check import DatasetName from .check import check_data_values @@ -156,7 +157,7 @@ def _path_readable(path: str) -> bool: try: zarr.open(path, "r") return True - except zarr.errors.PathNotFoundError: + except zarr_2_or_3.FileNotFoundException: return False @@ -173,6 +174,11 @@ def __init__(self, path: str): """ self.path = path + # if zarr_2_or_3.version != 2: + # raise ValueError( + # f"Only zarr version 2 is supported when creating datasets, found version: {zarr.__version__}" + # ) + _, ext = os.path.splitext(self.path) if ext != ".zarr": raise ValueError(f"Unsupported extension={ext} for path={self.path}") @@ -192,10 +198,9 @@ def add_dataset(self, mode: str = "r+", **kwargs: Any) -> zarr.Array: zarr.Array The added dataset. """ - import zarr z = zarr.open(self.path, mode=mode) - from .zarr import add_zarr_dataset + from .misc import add_zarr_dataset return add_zarr_dataset(zarr_root=z, **kwargs) @@ -210,7 +215,7 @@ def update_metadata(self, **kwargs: Any) -> None: import zarr LOG.debug(f"Updating metadata {kwargs}") - z = zarr.open(self.path, mode="w+") + z = zarr.open(self.path, mode=zarr_2_or_3.open_mode_append) for k, v in kwargs.items(): if isinstance(v, np.datetime64): v = v.astype(datetime.datetime) @@ -445,7 +450,7 @@ def check_missing_dates(expected: list[np.datetime64]) -> None: """ import zarr - z = zarr.open(path, "r") + z = zarr.open(path, mode="r") missing_dates = z.attrs.get("missing_dates", []) missing_dates = sorted([np.datetime64(d) for d in missing_dates]) if missing_dates != expected: @@ -517,7 +522,7 @@ class HasRegistryMixin: @cached_property def registry(self) -> Any: """Get the registry.""" - from .zarr import ZarrBuiltRegistry + from .misc import ZarrBuiltRegistry return ZarrBuiltRegistry(self.path, use_threads=self.use_threads) @@ -581,6 +586,7 @@ def __init__( progress: Any = None, test: bool = False, cache: Optional[str] = None, + force_zarr3: bool = False, **kwargs: Any, ): """Initialize an Init instance. @@ -609,6 +615,32 @@ def __init__( if _path_readable(path) and not overwrite: raise Exception(f"{path} already exists. 
Use overwrite=True to overwrite.")
 
+        version = zarr_2_or_3.version
+        if not zarr_2_or_3.supports_datetime64():
+            LOG.warning("⚠️" * 80)
+            LOG.warning(f"This version of Zarr ({zarr.__version__}) does not support datetime64.")
+            LOG.warning("⚠️" * 80)
+
+        if version != 2:
+
+            pytesting = "PYTEST_CURRENT_TEST" in os.environ
+
+            if pytesting or force_zarr3:
+                LOG.warning("⚠️" * 80)
+                LOG.warning("Zarr version 3 is used, but this is an unsupported feature.")
+                LOG.warning("⚠️" * 80)
+            else:
+                LOG.warning("⚠️" * 80)
+                LOG.warning(
+                    f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}"
+                )
+                LOG.warning("If you want to use Zarr version 3, please use the --force-zarr3 option.")
+                LOG.warning("Please note that this is an unsupported feature.")
+                LOG.warning("⚠️" * 80)
+                raise ValueError(
+                    f"Only Zarr version 2 is supported when creating datasets, found version: {zarr.__version__}"
+                )
+
         super().__init__(path, cache=cache)
         self.config = config
         self.check_name = check_name
@@ -1520,7 +1552,7 @@ def run(self) -> None:
 
         LOG.info(stats)
 
-        if not all(self.registry.get_flags(sync=False)):
+        if not all(self.registry.get_flags()):
             raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")
 
         for k in [
diff --git a/src/anemoi/datasets/create/zarr.py b/src/anemoi/datasets/create/misc.py
similarity index 78%
rename from src/anemoi/datasets/create/zarr.py
rename to src/anemoi/datasets/create/misc.py
index 574f603d0..c48b67c9f 100644
--- a/src/anemoi/datasets/create/zarr.py
+++ b/src/anemoi/datasets/create/misc.py
@@ -8,8 +8,6 @@
 # nor does it submit to any jurisdiction.
 
 import datetime
-import logging
-import shutil
 from typing import Any
 from typing import Optional
 
@@ -17,7 +15,10 @@
 import zarr
 from numpy.typing import NDArray
 
-LOG = logging.getLogger(__name__)
+from anemoi.datasets.zarr_versions import zarr_2_or_3
+
+from .synchronise import NoSynchroniser
+from .synchronise import Synchroniser
 
 
 def add_zarr_dataset(
@@ -72,8 +73,11 @@
     shape = array.shape
 
     if array is not None:
+        array, dtype = zarr_2_or_3.cast_dtype_datetime64(array, dtype)
+
         assert array.shape == shape, (array.shape, shape)
-        a = zarr_root.create_dataset(
+        a = zarr_2_or_3.create_array(
+            zarr_root,
             name,
             shape=shape,
             dtype=dtype,
@@ -100,7 +104,9 @@
     else:
         raise ValueError(f"No fill_value for dtype={dtype}")
 
-    a = zarr_root.create_dataset(
+    dtype = zarr_2_or_3.change_dtype_datetime64(dtype)
+    a = zarr_2_or_3.create_array(
+        zarr_root,
         name,
         shape=shape,
         dtype=dtype,
@@ -132,33 +138,19 @@ def __init__(self, path: str, synchronizer_path: Optional[str] = None, use_threa
         use_threads : bool
             Whether to use thread-based synchronization.
""" - import zarr assert isinstance(path, str), path self.zarr_path = path - if use_threads: - self.synchronizer = zarr.ThreadSynchronizer() - self.synchronizer_path = None - else: - if synchronizer_path is None: - synchronizer_path = self.zarr_path + ".sync" - self.synchronizer_path = synchronizer_path - self.synchronizer = zarr.ProcessSynchronizer(self.synchronizer_path) + self.synchronizer = Synchroniser(synchronizer_path) if synchronizer_path else NoSynchroniser() def clean(self) -> None: """Clean up the synchronizer path.""" - if self.synchronizer_path is not None: - try: - shutil.rmtree(self.synchronizer_path) - except FileNotFoundError: - pass + self.synchronizer.clean() def _open_write(self) -> zarr.Group: """Open the Zarr store in write mode.""" - import zarr - - return zarr.open(self.zarr_path, mode="r+", synchronizer=self.synchronizer) + return zarr.open(self.zarr_path, mode="r+") def _open_read(self, sync: bool = True) -> zarr.Group: """Open the Zarr store in read mode. @@ -173,12 +165,7 @@ def _open_read(self, sync: bool = True) -> zarr.Group: zarr.Group The opened Zarr group. """ - import zarr - - if sync: - return zarr.open(self.zarr_path, mode="r", synchronizer=self.synchronizer) - else: - return zarr.open(self.zarr_path, mode="r") + return zarr.open(self.zarr_path, mode="r") def new_dataset(self, *args, **kwargs) -> None: """Create a new dataset in the Zarr store. @@ -190,9 +177,11 @@ def new_dataset(self, *args, **kwargs) -> None: **kwargs Keyword arguments for dataset creation. """ - z = self._open_write() - zarr_root = z["_build"] - add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + with self.synchronizer: + z = self._open_write() + zarr_root = z["_build"] + add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, dimensions=("tmp",), **kwargs) + del z def add_to_history(self, action: str, **kwargs) -> None: """Add an action to the history attribute of the Zarr store. @@ -210,10 +199,12 @@ def add_to_history(self, action: str, **kwargs) -> None: ) new.update(kwargs) - z = self._open_write() - history = z.attrs.get("history", []) - history.append(new) - z.attrs["history"] = history + with self.synchronizer: + z = self._open_write() + history = z.attrs.get("history", []) + history.append(new) + z.attrs["history"] = history + del z def get_lengths(self) -> list[int]: """Get the lengths dataset. @@ -223,8 +214,11 @@ def get_lengths(self) -> list[int]: list[int] The lengths dataset. """ - z = self._open_read() - return list(z["_build"][self.name_lengths][:]) + with self.synchronizer: + z = self._open_read() + lengths = list(z["_build"][self.name_lengths][:]) + del z + return lengths def get_flags(self, **kwargs) -> list[bool]: """Get the flags dataset. @@ -239,8 +233,11 @@ def get_flags(self, **kwargs) -> list[bool]: list[bool] The flags dataset. """ - z = self._open_read(**kwargs) - return list(z["_build"][self.name_flags][:]) + with self.synchronizer: + z = self._open_read(**kwargs) + flags = list(z["_build"][self.name_flags][:]) + del z + return flags def get_flag(self, i: int) -> bool: """Get a specific flag. @@ -255,8 +252,11 @@ def get_flag(self, i: int) -> bool: bool The flag value. """ - z = self._open_read() - return z["_build"][self.name_flags][i] + with self.synchronizer: + z = self._open_read() + flag = z["_build"][self.name_flags][i] + del z + return flag def set_flag(self, i: int, value: bool = True) -> None: """Set a specific flag. 
@@ -268,11 +268,13 @@ def set_flag(self, i: int, value: bool = True) -> None: value : bool Value to set the flag to. """ - z = self._open_write() - z.attrs["latest_write_timestamp"] = ( - datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() - ) - z["_build"][self.name_flags][i] = value + with self.synchronizer: + z = self._open_write() + z.attrs["latest_write_timestamp"] = ( + datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + ) + z["_build"][self.name_flags][i] = value + del z def ready(self) -> bool: """Check if all flags are set. @@ -316,11 +318,13 @@ def add_provenance(self, name: str) -> None: name : str Name of the provenance attribute. """ - z = self._open_write() + with self.synchronizer: + z = self._open_write() - if name in z.attrs: - return + if name in z.attrs: + return - from anemoi.utils.provenance import gather_provenance_info + from anemoi.utils.provenance import gather_provenance_info - z.attrs[name] = gather_provenance_info() + z.attrs[name] = gather_provenance_info() + del z diff --git a/src/anemoi/datasets/create/patch.py b/src/anemoi/datasets/create/patch.py index e8de85851..3a96d5324 100755 --- a/src/anemoi/datasets/create/patch.py +++ b/src/anemoi/datasets/create/patch.py @@ -14,6 +14,8 @@ import zarr +from anemoi.datasets.zarr_versions import zarr_2_or_3 + LOG = logging.getLogger(__name__) @@ -134,7 +136,7 @@ def apply_patch(path: str, verbose: bool = True, dry_run: bool = False) -> None: try: attrs = zarr.open(path, mode="r").attrs.asdict() - except zarr.errors.PathNotFoundError as e: + except zarr_2_or_3.get_not_found_exception() as e: LOG.error(f"Failed to open {path}") LOG.error(e) exit(0) diff --git a/src/anemoi/datasets/create/sources/planetary_computer.py b/src/anemoi/datasets/create/sources/planetary_computer.py new file mode 100644 index 000000000..b710bcbbe --- /dev/null +++ b/src/anemoi/datasets/create/sources/planetary_computer.py @@ -0,0 +1,44 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +from . 
import source_registry +from .xarray import XarraySourceBase + + +@source_registry.register("planetary_computer") +class PlanetaryComputerSource(XarraySourceBase): + """An Xarray data source for the planetary_computer.""" + + emoji = "🪐" + + def __init__(self, context, data_catalog_id, version="v1", *args, **kwargs: dict): + + import planetary_computer + import pystac_client + + self.data_catalog_id = data_catalog_id + self.flavour = kwargs.pop("flavour", None) + self.patch = kwargs.pop("patch", None) + self.options = kwargs.pop("options", {}) + + catalog = pystac_client.Client.open( + f"https://planetarycomputer.microsoft.com/api/stac/{version}/", + modifier=planetary_computer.sign_inplace, + ) + collection = catalog.get_collection(self.data_catalog_id) + + asset = collection.assets["zarr-abfs"] + + if "xarray:storage_options" in asset.extra_fields: + self.options["storage_options"] = asset.extra_fields["xarray:storage_options"] + + self.options.update(asset.extra_fields["xarray:open_kwargs"]) + + super().__init__(context, url=asset.href, *args, **kwargs) diff --git a/src/anemoi/datasets/create/sources/xarray_support/__init__.py b/src/anemoi/datasets/create/sources/xarray_support/__init__.py index 4f4edb46f..665cfdad3 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/__init__.py +++ b/src/anemoi/datasets/create/sources/xarray_support/__init__.py @@ -20,7 +20,6 @@ from earthkit.data.core.fieldlist import MultiFieldList from anemoi.datasets.create.sources.patterns import iterate_patterns -from anemoi.datasets.data.stores import name_to_zarr_store from ..legacy import legacy_source from .fieldlist import XarrayFieldList @@ -89,37 +88,22 @@ def load_one( The loaded dataset. """ - """ - We manage the S3 client ourselves, bypassing fsspec and s3fs layers, because sometimes something on the stack - zarr/fsspec/s3fs/boto3 (?) seem to flags files as missing when they actually are not (maybe when S3 reports some sort of - connection error). In that case, Zarr will silently fill the chunks that could not be downloaded with NaNs. - See https://github.com/pydata/xarray/issues/8842 - - We have seen this bug triggered when we run many clients in parallel, for example, when we create a new dataset using `xarray-zarr`. 
- """ - if options is None: options = {} context.trace(emoji, dataset, options, kwargs) - if isinstance(dataset, str) and ".zarr" in dataset: - data = xr.open_zarr(name_to_zarr_store(dataset), **options) - elif "planetarycomputer" in dataset: - store = name_to_zarr_store(dataset) - if "store" in store: - data = xr.open_zarr(**store) - if "filename_or_obj" in store: - data = xr.open_dataset(**store) - else: - data = xr.open_dataset(dataset, **options) + if isinstance(dataset, str) and dataset.endswith(".zarr"): + # If the dataset is a zarr store, we need to use the zarr engine + options["engine"] = "zarr" + + data = xr.open_dataset(dataset, **options) fs = XarrayFieldList.from_xarray(data, flavour=flavour, patch=patch) if len(dates) == 0: result = fs.sel(**kwargs) else: - print("dates", dates, kwargs) result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) if len(result) == 0: @@ -130,7 +114,7 @@ def load_one( a = ["valid_datetime", k.metadata("valid_datetime", default=None)] for n in kwargs.keys(): a.extend([n, k.metadata(n, default=None)]) - print([str(x) for x in a]) + LOG.warning(f"{[str(x) for x in a]}") if i > 16: break diff --git a/src/anemoi/datasets/create/sources/xarray_support/field.py b/src/anemoi/datasets/create/sources/xarray_support/field.py index 663aeab54..9fdd93246 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/field.py +++ b/src/anemoi/datasets/create/sources/xarray_support/field.py @@ -87,13 +87,10 @@ def __init__(self, owner: Any, selection: Any) -> None: coordinate = owner.by_name[coord_name] self._md[coord_name] = coordinate.normalise(extract_single_value(coord_value)) - # print(values.ndim, values.shape, selection.dims) # By now, the only dimensions should be latitude and longitude self._shape = tuple(list(self.selection.shape)[-2:]) if math.prod(self._shape) != math.prod(self.selection.shape): - print(self.selection.ndim, self.selection.shape) - print(self.selection) - raise ValueError("Invalid shape for selection") + raise ValueError(f"Invalid shape for selection {self._shape=}, {self.selection.shape=} {self.selection=}") @property def shape(self) -> Tuple[int, int]: diff --git a/src/anemoi/datasets/create/sources/xarray_support/flavour.py b/src/anemoi/datasets/create/sources/xarray_support/flavour.py index 4df374148..02b30d7bb 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/flavour.py +++ b/src/anemoi/datasets/create/sources/xarray_support/flavour.py @@ -308,9 +308,9 @@ def _x_y_provided(self, x: Any, y: Any, variable: Any) -> Any: return self._grid_cache[(x.name, y.name, dim_vars)] grid_mapping = variable.attrs.get("grid_mapping", None) - if grid_mapping is not None: - print(f"grid_mapping: {grid_mapping}") - print(self.ds[grid_mapping]) + # if grid_mapping is not None: + # print(f"grid_mapping: {grid_mapping}") + # print(self.ds[grid_mapping]) if grid_mapping is None: LOG.warning(f"No 'grid_mapping' attribute provided for '{variable.name}'") diff --git a/src/anemoi/datasets/create/sources/xarray_support/patch.py b/src/anemoi/datasets/create/sources/xarray_support/patch.py index 29ea620dd..a84fccc14 100644 --- a/src/anemoi/datasets/create/sources/xarray_support/patch.py +++ b/src/anemoi/datasets/create/sources/xarray_support/patch.py @@ -61,9 +61,28 @@ def patch_coordinates(ds: xr.Dataset, coordinates: List[str]) -> Any: return ds +def patch_rename(ds: xr.Dataset, renames: dict[str, str]) -> Any: + """Rename variables in the dataset. + + Parameters + ---------- + ds : xr.Dataset + The dataset to patch. 
+    renames : dict[str, str]
+        Mapping from old variable names to new variable names.
+
+    Returns
+    -------
+    Any
+        The patched dataset.
+    """
+    return ds.rename(renames)
+
+
 PATCHES = {
     "attributes": patch_attributes,
     "coordinates": patch_coordinates,
+    "rename": patch_rename,
 }
 
@@ -82,7 +101,9 @@ def patch_dataset(ds: xr.Dataset, patch: Dict[str, Dict[str, Any]]) -> Any:
     Any
         The patched dataset.
     """
-    for what, values in patch.items():
+
+    ORDER = ["coordinates", "attributes", "rename"]
+    for what, values in sorted(patch.items(), key=lambda x: ORDER.index(x[0]) if x[0] in ORDER else len(ORDER)):
         if what not in PATCHES:
             raise ValueError(f"Unknown patch type {what!r}")
 
diff --git a/src/anemoi/datasets/create/synchronise.py b/src/anemoi/datasets/create/synchronise.py
new file mode 100644
index 000000000..b0990f4af
--- /dev/null
+++ b/src/anemoi/datasets/create/synchronise.py
@@ -0,0 +1,82 @@
+# (C) Copyright 2025 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
+import logging
+import os
+import time
+
+from filelock import FileLock
+from filelock import Timeout
+
+LOG = logging.getLogger(__name__)
+
+
+class NoSynchroniser:
+    def __init__(self):
+        pass
+
+    def __enter__(self):
+        pass
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        pass
+
+    def clean(self):
+        pass
+
+
+class Synchroniser:
+    def __init__(self, lock_file_path, timeout=10):
+        """Initialize the Synchroniser with the path to the lock file and an optional timeout.
+        Parameters
+        ----------
+        lock_file_path
+            Path to the lock file on a shared filesystem.
+        timeout
+            Timeout for acquiring the lock in seconds.
+        """
+        self.lock_file_path = lock_file_path
+        self.timeout = timeout
+        self.lock = FileLock(lock_file_path)
+
+    def __enter__(self):
+        """Acquire the lock when entering the context."""
+        try:
+            self.lock.acquire(timeout=self.timeout)
+            LOG.debug("Lock acquired.")
+        except Timeout:
+            LOG.warning("Could not acquire lock, another process might be holding it.")
+            raise
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Release the lock when exiting the context."""
+        self.lock.release()
+        LOG.debug("Lock released.")
+
+    def clean(self):
+        try:
+            os.remove(self.lock_file_path)
+        except FileNotFoundError:
+            pass
+
+
+# Example usage
+if __name__ == "__main__":
+
+    def example_operation():
+        print("Performing operation...")
+        time.sleep(2)  # Simulate some work
+        print("Operation complete.")
+
+    lock_path = "/path/to/shared/lockfile.lock"
+
+    # Use the Synchroniser as a context manager
+    with Synchroniser(lock_path):
+        example_operation()
diff --git a/src/anemoi/datasets/data/misc.py b/src/anemoi/datasets/data/misc.py
index b5523ef85..54b6ba382 100644
--- a/src/anemoi/datasets/data/misc.py
+++ b/src/anemoi/datasets/data/misc.py
@@ -27,6 +27,8 @@
 from anemoi.utils.config import load_config as load_settings
 from numpy.typing import NDArray
 
+from anemoi.datasets.zarr_versions import zarr_2_or_3
+
 if TYPE_CHECKING:
     from .dataset import Dataset
 
@@ -371,7 +373,7 @@ def _open(a: Union[str, PurePath, Dict[str, Any], List[Any], Tuple[Any, ...]]) -
     if isinstance(a, Dataset):
         return a.mutate()
 
-    if isinstance(a, zarr.hierarchy.Group):
+    if isinstance(a, zarr_2_or_3.Group):
         return Zarr(a).mutate()
 
     if isinstance(a, str):
diff --git a/src/anemoi/datasets/data/stores.py b/src/anemoi/datasets/data/stores.py
index e31d4cfb9..eff1fbd0b 100644
--- a/src/anemoi/datasets/data/stores.py
+++ b/src/anemoi/datasets/data/stores.py
@@ -12,7 +12,6 @@
 import logging
 import os
 import tempfile
-import warnings
 from functools import cached_property
 from typing import Any
 from typing import Dict
@@ -27,6 +26,7 @@
 from anemoi.utils.dates import frequency_to_timedelta
 from numpy.typing import NDArray
 
+from ..zarr_versions import zarr_2_or_3
 from . import MissingDateError
 from .dataset import Dataset
 from .dataset import FullIndex
@@ -42,151 +42,12 @@
 LOG = logging.getLogger(__name__)
 
 
-class ReadOnlyStore(zarr.storage.BaseStore):
-    """A base class for read-only stores."""
-
-    def __delitem__(self, key: str) -> None:
-        """Prevent deletion of items."""
-        raise NotImplementedError()
-
-    def __setitem__(self, key: str, value: bytes) -> None:
-        """Prevent setting of items."""
-        raise NotImplementedError()
-
-    def __len__(self) -> int:
-        """Return the number of items in the store."""
-        raise NotImplementedError()
-
-    def __iter__(self) -> iter:
-        """Return an iterator over the store."""
-        raise NotImplementedError()
-
-
-class HTTPStore(ReadOnlyStore):
-    """A read-only store for HTTP(S) resources."""
-
-    def __init__(self, url: str) -> None:
-        """Initialize the HTTPStore with a URL."""
-        self.url = url
-
-    def __getitem__(self, key: str) -> bytes:
-        """Retrieve an item from the store."""
-        import requests
-
-        r = requests.get(self.url + "/" + key)
-
-        if r.status_code == 404:
-            raise KeyError(key)
-
-        r.raise_for_status()
-        return r.content
-
-
-class S3Store(ReadOnlyStore):
-    """A read-only store for S3 resources."""
-
-    """We write our own S3Store because the one used by zarr (s3fs)
-    does not play well with fork(). We also get to control the s3 client
-    options using the anemoi configs.
- """ - - def __init__(self, url: str, region: Optional[str] = None) -> None: - """Initialize the S3Store with a URL and optional region.""" - from anemoi.utils.remote.s3 import s3_client - - _, _, self.bucket, self.key = url.split("/", 3) - self.s3 = s3_client(self.bucket, region=region) - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - try: - response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) - except self.s3.exceptions.NoSuchKey: - raise KeyError(key) - - return response["Body"].read() - - -class PlanetaryComputerStore(ReadOnlyStore): - """We write our own Store to access catalogs on Planetary Computer, - as it requires some extra arguments to use xr.open_zarr. - """ - - def __init__(self, data_catalog_id: str) -> None: - """Initialize the PlanetaryComputerStore with a data catalog ID. - - Parameters - ---------- - data_catalog_id : str - The data catalog ID. - """ - self.data_catalog_id = data_catalog_id - - import planetary_computer - import pystac_client - - catalog = pystac_client.Client.open( - "https://planetarycomputer.microsoft.com/api/stac/v1/", - modifier=planetary_computer.sign_inplace, - ) - collection = catalog.get_collection(self.data_catalog_id) - - asset = collection.assets["zarr-abfs"] - - if "xarray:storage_options" in asset.extra_fields: - store = { - "store": asset.href, - "storage_options": asset.extra_fields["xarray:storage_options"], - **asset.extra_fields["xarray:open_kwargs"], - } - else: - store = { - "filename_or_obj": asset.href, - **asset.extra_fields["xarray:open_kwargs"], - } - - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store.""" - raise NotImplementedError() - - -class DebugStore(ReadOnlyStore): - """A store to debug the zarr loading.""" - - def __init__(self, store: ReadOnlyStore) -> None: - """Initialize the DebugStore with another store.""" - assert not isinstance(store, DebugStore) - self.store = store - - def __getitem__(self, key: str) -> bytes: - """Retrieve an item from the store and print debug information.""" - # print() - print("GET", key, self) - # traceback.print_stack(file=sys.stdout) - return self.store[key] - - def __len__(self) -> int: - """Return the number of items in the store.""" - return len(self.store) - - def __iter__(self) -> iter: - """Return an iterator over the store.""" - warnings.warn("DebugStore: iterating over the store") - return iter(self.store) - - def __contains__(self, key: str) -> bool: - """Check if the store contains a key.""" - return key in self.store - - -def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: +def name_to_zarr_store(path_or_url: str) -> Any: """Convert a path or URL to a zarr store.""" store = path_or_url if store.startswith("s3://"): - return S3Store(store) + return zarr_2_or_3.S3Store(store) if store.startswith("http://") or store.startswith("https://"): @@ -213,17 +74,14 @@ def name_to_zarr_store(path_or_url: str) -> ReadOnlyStore: bits = parsed.netloc.split(".") if len(bits) == 5 and (bits[1], bits[3], bits[4]) == ("s3", "amazonaws", "com"): s3_url = f"s3://{bits[0]}{parsed.path}" - store = S3Store(s3_url, region=bits[2]) - elif store.startswith("https://planetarycomputer.microsoft.com/"): - data_catalog_id = store.rsplit("/", 1)[-1] - store = PlanetaryComputerStore(data_catalog_id).store + store = zarr_2_or_3.S3Store(s3_url, region=bits[2]) else: - store = HTTPStore(store) + store = zarr_2_or_3.HTTPStore(store) return store -def open_zarr(path: str, dont_fail: bool = 
False, cache: int = None) -> zarr.hierarchy.Group: +def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> Any: """Open a zarr store from a path.""" try: store = name_to_zarr_store(path) @@ -237,24 +95,25 @@ def open_zarr(path: str, dont_fail: bool = False, cache: int = None) -> zarr.hie "DEBUG_ZARR_LOADING is only implemented for DirectoryStore. " "Please disable it for other backends." ) - store = zarr.storage.DirectoryStore(store) - store = DebugStore(store) + store = zarr_2_or_3.DirectoryStore(store) + store = zarr_2_or_3.DebugStore(store) if cache is not None: - store = zarr.LRUStoreCache(store, max_size=cache) + store = zarr_2_or_3.LRUStoreCache(store, max_size=cache) + + return zarr.open(store, mode="r") - return zarr.convenience.open(store, "r") - except zarr.errors.PathNotFoundError: + except zarr_2_or_3.FileNotFoundException: if not dont_fail: - raise zarr.errors.PathNotFoundError(path) + raise FileNotFoundError(f"Zarr store not found: {path}") class Zarr(Dataset): """A zarr dataset.""" - def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: + def __init__(self, path: Union[str, Any]) -> None: """Initialize the Zarr dataset with a path or zarr group.""" - if isinstance(path, zarr.hierarchy.Group): + if isinstance(path, zarr_2_or_3.Group): self.was_zarr = True self.path = str(id(path)) self.z = path @@ -264,7 +123,7 @@ def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: self.z = open_zarr(self.path) # This seems to speed up the reading of the data a lot - self.data = self.z.data + self.data = self.z["data"] self._missing = set() @property @@ -315,7 +174,7 @@ def _unwind(self, index: Union[int, slice, list, tuple], rest: list, shape: tupl @cached_property def chunks(self) -> TupleIndex: """Return the chunks of the dataset.""" - return self.z.data.chunks + return self.data.chunks @cached_property def shape(self) -> Shape: @@ -325,39 +184,45 @@ def shape(self) -> Shape: @cached_property def dtype(self) -> np.dtype: """Return the data type of the dataset.""" - return self.z.data.dtype + return self.data.dtype @cached_property def dates(self) -> NDArray[np.datetime64]: """Return the dates of the dataset.""" - return self.z.dates[:] # Convert to numpy + dates = self.z["dates"][:] + if not dates.dtype == np.dtype("datetime64[s]"): + # The datasets created with zarr3 will have the dates as int64 as long + # as zarr3 does not support datetime64 + LOG.warning("Converting dates to 'datetime64[s]'") + dates = dates.astype("datetime64[s]") + return dates @property def latitudes(self) -> NDArray[Any]: """Return the latitudes of the dataset.""" try: - return self.z.latitudes[:] + return self.z["latitudes"][:] except AttributeError: LOG.warning("No 'latitudes' in %r, trying 'latitude'", self) - return self.z.latitude[:] + return self.z["latitude"][:] @property def longitudes(self) -> NDArray[Any]: """Return the longitudes of the dataset.""" try: - return self.z.longitudes[:] + return self.z["longitudes"][:] except AttributeError: LOG.warning("No 'longitudes' in %r, trying 'longitude'", self) - return self.z.longitude[:] + return self.z["longitude"][:] @property def statistics(self) -> Dict[str, NDArray[Any]]: """Return the statistics of the dataset.""" return dict( - mean=self.z.mean[:], - stdev=self.z.stdev[:], - maximum=self.z.maximum[:], - minimum=self.z.minimum[:], + mean=self.z["mean"][:], + stdev=self.z["stdev"][:], + maximum=self.z["maximum"][:], + minimum=self.z["minimum"][:], ) def statistics_tendencies(self, delta: 
Optional[datetime.timedelta] = None) -> Dict[str, NDArray[Any]]: @@ -486,7 +351,7 @@ def collect_input_sources(self, collected: set) -> None: class ZarrWithMissingDates(Zarr): """A zarr dataset with missing dates.""" - def __init__(self, path: Union[str, zarr.hierarchy.Group]) -> None: + def __init__(self, path: Union[str, Any]) -> None: """Initialize the ZarrWithMissingDates dataset with a path or zarr group.""" super().__init__(path) @@ -602,7 +467,7 @@ def zarr_lookup(name: str, fail: bool = True) -> Optional[str]: LOG.info("Opening `%s` as `%s`", name, full) QUIET.add(name) return full - except zarr.errors.PathNotFoundError: + except zarr_2_or_3.FileNotFoundException: pass if fail: diff --git a/src/anemoi/datasets/zarr_versions/__init__.py b/src/anemoi/datasets/zarr_versions/__init__.py new file mode 100644 index 000000000..39fd71d91 --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/__init__.py @@ -0,0 +1,23 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import zarr + +version = zarr.__version__.split(".")[0] + +if version == "2": + from . import zarr2 as zarr_2_or_3 + +elif version == "3": + from . import zarr3 as zarr_2_or_3 +else: + raise ImportError(f"Unsupported Zarr version: {zarr.__version__}. Supported versions are 2 and 3.") + +__all__ = ["zarr_2_or_3"] diff --git a/src/anemoi/datasets/zarr_versions/zarr2.py b/src/anemoi/datasets/zarr_versions/zarr2.py new file mode 100644 index 000000000..69980e31b --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/zarr2.py @@ -0,0 +1,138 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +import logging +import warnings +from typing import Any +from typing import Optional + +import zarr + +LOG = logging.getLogger(__name__) + + +version = 2 + +FileNotFoundException = zarr.errors.PathNotFoundError +Group = zarr.hierarchy.Group +open_mode_append = "w+" + + +class ReadOnlyStore(zarr.storage.BaseStore): + """A base class for read-only stores.""" + + def __delitem__(self, key: str) -> None: + """Prevent deletion of items.""" + raise NotImplementedError() + + def __setitem__(self, key: str, value: bytes) -> None: + """Prevent setting of items.""" + raise NotImplementedError() + + def __len__(self) -> int: + """Return the number of items in the store.""" + raise NotImplementedError() + + def __iter__(self) -> iter: + """Return an iterator over the store.""" + raise NotImplementedError() + + +class S3Store(ReadOnlyStore): + """A read-only store for S3 resources.""" + + """We write our own S3Store because the one used by zarr (s3fs) + does not play well with fork(). We also get to control the s3 client + options using the anemoi configs. 
+ """ + + def __init__(self, url: str, region: Optional[str] = None) -> None: + """Initialize the S3Store with a URL and optional region.""" + from anemoi.utils.remote.s3 import s3_client + + super().__init__() + + _, _, self.bucket, self.key = url.split("/", 3) + self.s3 = s3_client(self.bucket, region=region) + + # Version 2 + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store.""" + try: + response = self.s3.get_object(Bucket=self.bucket, Key=self.key + "/" + key) + except self.s3.exceptions.NoSuchKey: + raise KeyError(key) + + return response["Body"].read() + + +class HTTPStore(ReadOnlyStore): + """A read-only store for HTTP(S) resources.""" + + def __init__(self, url: str) -> None: + """Initialize the HTTPStore with a URL.""" + super().__init__() + self.url = url + + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store.""" + import requests + + r = requests.get(self.url + "/" + key) + + if r.status_code == 404: + raise KeyError(key) + + r.raise_for_status() + return r.content + + +class DebugStore(ReadOnlyStore): + """A store to debug the zarr loading.""" + + def __init__(self, store: Any) -> None: + super().__init__() + """Initialize the DebugStore with another store.""" + assert not isinstance(store, DebugStore) + self.store = store + + def __getitem__(self, key: str) -> bytes: + """Retrieve an item from the store and print debug information.""" + # print() + print("GET", key, self) + # traceback.print_stack(file=sys.stdout) + return self.store[key] + + def __len__(self) -> int: + """Return the number of items in the store.""" + return len(self.store) + + def __iter__(self) -> iter: + """Return an iterator over the store.""" + warnings.warn("DebugStore: iterating over the store") + return iter(self.store) + + def __contains__(self, key: str) -> bool: + """Check if the store contains a key.""" + return key in self.store + + +def create_array(zarr_root, *args, **kwargs): + return zarr_root.create_dataset(*args, **kwargs) + + +def change_dtype_datetime64(dtype): + return dtype + + +def cast_dtype_datetime64(array, dtype): + return array, dtype + + +def supports_datetime64(): + return True diff --git a/src/anemoi/datasets/zarr_versions/zarr3.py b/src/anemoi/datasets/zarr_versions/zarr3.py new file mode 100644 index 000000000..c7d18b121 --- /dev/null +++ b/src/anemoi/datasets/zarr_versions/zarr3.py @@ -0,0 +1,123 @@ +# (C) Copyright 2025 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+import logging + +import zarr + +LOG = logging.getLogger(__name__) + +version = 3 +FileNotFoundException = FileNotFoundError +Group = zarr.Group +open_mode_append = "a" + + +class S3Store(zarr.storage.ObjectStore): + """We use our class to manage per bucket credentials""" + + def __init__(self, url): + + import boto3 + from anemoi.utils.remote.s3 import s3_options + from obstore.auth.boto3 import Boto3CredentialProvider + from obstore.store import from_url + + options = s3_options(url) + + credential_provider = Boto3CredentialProvider( + session=boto3.session.Session( + aws_access_key_id=options["aws_access_key_id"], + aws_secret_access_key=options["aws_secret_access_key"], + ), + ) + + objectstore = from_url( + url, + credential_provider=credential_provider, + endpoint=options["endpoint_url"], + ) + + super().__init__(objectstore, read_only=True) + + +class HTTPStore(zarr.storage.ObjectStore): + + def __init__(self, url): + + from obstore.store import from_url + + objectstore = from_url(url) + + super().__init__(objectstore, read_only=True) + + +DebugStore = zarr.storage.LoggingStore + + +def create_array(zarr_root, *args, **kwargs): + if "compressor" in kwargs and kwargs["compressor"] is None: + # compressor is deprecated, use compressors instead + kwargs.pop("compressor") + kwargs["compressors"] = () + + data = kwargs.pop("data", None) + if data is not None: + kwargs.setdefault("dtype", change_dtype_datetime64(data.dtype)) + kwargs.setdefault("shape", data.shape) + + try: + z = zarr_root.create_array(*args, **kwargs) + if data is not None: + z[:] = data + return z + except Exception: + LOG.exception("Failed to create array in Zarr store") + LOG.error( + "Failed to create array in Zarr store with args: %s, kwargs: %s", + args, + kwargs, + ) + raise + + +def change_dtype_datetime64(dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == "datetime64[s]": + dtype = np.dtype("int64") + return dtype + + +def cast_dtype_datetime64(array, dtype): + # remove this flag (and the relevant code) when Zarr 3 supports datetime64 + # https://github.com/zarr-developers/zarr-python/issues/2616 + import numpy as np + + if dtype == np.dtype("datetime64[s]"): + dtype = "int64" + array = array.astype(dtype) + + return array, dtype + + +def supports_datetime64(): + store = zarr.storage.MemoryStore() + try: + zarr.create_array(store=store, shape=(10,), dtype="datetime64[s]") + return True + except KeyError: + # If a KeyError is raised, it means datetime64 is not supported + return False + + +if __name__ == "__main__": + print("Zarr version:", version) + print("Zarr supports datetime64:", supports_datetime64()) diff --git a/tests/create/test_create.py b/tests/create/test_create.py index 668ffa217..70b49a91c 100755 --- a/tests/create/test_create.py +++ b/tests/create/test_create.py @@ -176,6 +176,7 @@ def compare_dot_zattrs(a: dict, b: dict, path: str, errors: list) -> None: "description", "config_path", "total_size", + "total_number_of_files", # expected to differ when comparing datasets generated with zarr 2 vs zarr 3 ]: if type(a[k]) is not type(b[k]): errors.append(f"❌ {path}.{k} : type differs {type(a[k])} != {type(b[k])}") diff --git a/tests/create/test_sources.py b/tests/create/test_sources.py index 6e6098800..05281f16e 100644 --- a/tests/create/test_sources.py +++ b/tests/create/test_sources.py @@ -7,7 +7,6 @@ # granted to it by virtue of its status as an 
intergovernmental organisation # nor does it submit to any jurisdiction. -import logging import os import sys @@ -245,12 +244,44 @@ def test_kerchunk(get_test_data: callable) -> None: assert ds.shape == (4, 1, 1, 1038240) +@skip_if_offline +@skip_missing_packages("planetary_computer", "adlfs") +def test_planetary_computer_conus404() -> None: + """Test loading and validating the planetary_computer_conus404 dataset.""" + + config = { + "dates": { + "start": "2022-01-01", + "end": "2022-01-02", + "frequency": "1d", + }, + "input": { + "planetary_computer": { + "data_catalog_id": "conus404", + "param": ["Z"], + "level": [1], + "patch": { + "coordinates": ["bottom_top_stag"], + "rename": { + "bottom_top_stag": "level", + }, + "attributes": { + "lon": {"standard_name": "longitude", "long_name": "Longitude"}, + "lat": {"standard_name": "latitude", "long_name": "Latitude"}, + }, + }, + } + }, + } + + created = create_dataset(config=config, output=None) + ds = open_dataset(created) + assert ds.shape == (2, 1, 1, 1387505), ds.shape + + if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) - test_kerchunk() - exit() - """Run all test functions that start with 'test_'.""" - for name, obj in list(globals().items()): - if name.startswith("test_") and callable(obj): - print(f"Running {name}...") - obj() + test_planetary_computer_conus404() + exit(0) + from anemoi.utils.testing import run_tests + + run_tests(globals()) diff --git a/tests/test_data.py b/tests/test_data.py index 07b35887e..72b57dd60 100644 --- a/tests/test_data.py +++ b/tests/test_data.py @@ -37,6 +37,7 @@ from anemoi.datasets.data.statistics import Statistics from anemoi.datasets.data.stores import Zarr from anemoi.datasets.data.subset import Subset +from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 10 @@ -57,7 +58,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) @@ -104,7 +105,7 @@ def create_zarr( ensemble: Optional[int] = None, grids: Optional[int] = None, missing: bool = False, -) -> zarr.Group: +) -> zarr_2_or_3.Group: """Create a Zarr dataset. Parameters @@ -142,7 +143,7 @@ def create_zarr( dates.append(date) date += frequency - dates = np.array(dates, dtype="datetime64") + dates = np.array(dates, dtype="datetime64[s]") ensembles = ensemble if ensemble is not None else 1 values = grids if grids is not None else VALUES @@ -154,28 +155,42 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - root.create_dataset( + zarr_2_or_3.create_array( + root, "data", - data=data, dtype=data.dtype, chunks=data.shape, compressor=None, - ) - root.create_dataset( + shape=data.shape, + )[...] = data + + dates, dtype_ = zarr_2_or_3.cast_dtype_datetime64(dates, dates.dtype) + del dtype_ + zarr_2_or_3.create_array( + root, "dates", - data=dates, compressor=None, - ) - root.create_dataset( + dtype=dates.dtype, + shape=dates.shape, + )[...] = dates + + latitudes = np.array([x + values for x in range(values)]) + zarr_2_or_3.create_array( + root, "latitudes", - data=np.array([x + values for x in range(values)]), compressor=None, - ) - root.create_dataset( + dtype=latitudes.dtype, + shape=latitudes.shape, + )[...] 
= latitudes + + longitudes = np.array([x + values for x in range(values)]) + zarr_2_or_3.create_array( + root, "longitudes", - data=np.array([x + values for x in range(values)]), compressor=None, - ) + dtype=longitudes.dtype, + shape=longitudes.shape, + )[...] = longitudes root.attrs["frequency"] = frequency_to_string(frequency) root.attrs["resolution"] = resolution @@ -196,26 +211,42 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - root.create_dataset( + zarr_2_or_3.create_array( + root, "mean", - data=np.mean(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.mean(data, axis=0) + zarr_2_or_3.create_array( + root, "stdev", - data=np.std(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.std(data, axis=0) + zarr_2_or_3.create_array( + root, "maximum", - data=np.max(data, axis=0), compressor=None, - ) - root.create_dataset( + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.max(data, axis=0) + zarr_2_or_3.create_array( + root, "minimum", - data=np.min(data, axis=0), compressor=None, - ) + shape=data.shape[1:], + dtype=data.dtype, + )[ + ... + ] = np.min(data, axis=0) return root diff --git a/tests/test_data_gridded.py b/tests/test_data_gridded.py index 5e4738980..1a19ddf40 100644 --- a/tests/test_data_gridded.py +++ b/tests/test_data_gridded.py @@ -24,6 +24,7 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.datasets import open_dataset +from anemoi.datasets.zarr_versions import zarr_2_or_3 VALUES = 20 @@ -44,7 +45,7 @@ def mockup_open_zarr(func: Callable) -> Callable: @wraps(func) def wrapper(*args, **kwargs): - with patch("zarr.convenience.open", zarr_from_str): + with patch("zarr.open", zarr_from_str): with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name): return func(*args, **kwargs) @@ -144,24 +145,29 @@ def create_zarr( for e in range(ensembles): data[i, j, e] = _(date.astype(object), var, k, e, values) - root.create_dataset( + zarr_2_or_3.create_array( + root, "data", data=data, dtype=data.dtype, chunks=data.shape, compressor=None, ) - root.create_dataset( + # Store dates as ISO strings to avoid unsupported dtype in Zarr v3 + zarr_2_or_3.create_array( + root, "dates", - data=dates, + data=np.array([str(d) for d in dates], dtype="U32"), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "latitudes", data=np.array([x + values for x in range(values)]), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "longitudes", data=np.array([x + values for x in range(values)]), compressor=None, @@ -186,22 +192,26 @@ def create_zarr( root.attrs["missing_dates"] = [d.isoformat() for d in missing_dates] - root.create_dataset( + zarr_2_or_3.create_array( + root, "mean", data=np.mean(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "stdev", data=np.std(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "maximum", data=np.max(data, axis=0), compressor=None, ) - root.create_dataset( + zarr_2_or_3.create_array( + root, "minimum", data=np.min(data, axis=0), compressor=None, diff --git a/tests/xarray/test_zarr.py b/tests/xarray/test_zarr.py index c0a8f8ed2..8bd718101 100644 --- a/tests/xarray/test_zarr.py +++ b/tests/xarray/test_zarr.py @@ -13,7 +13,6 @@ from anemoi.utils.testing import skip_missing_packages from 
anemoi.datasets.create.sources.xarray import XarrayFieldList -from anemoi.datasets.data.stores import name_to_zarr_store from anemoi.datasets.testing import assert_field_list @@ -134,29 +133,35 @@ def test_noaa_replay() -> None: @skip_if_offline -@skip_missing_packages("planetary_computer", "adlfs") -def test_planetary_computer_conus404() -> None: - """Test loading and validating the planetary_computer_conus404 dataset.""" - url = "https://planetarycomputer.microsoft.com/api/stac/v1/collections/conus404" - ds = xr.open_zarr(**name_to_zarr_store(url)) +@skip_missing_packages("s3fs") +def test_aws_s3() -> None: + """Test loading and validating an AWS S3 dataset.""" + url = "s3://aodn-cloud-optimised/model_sea_level_anomaly_gridded_realtime.zarr" + ds = xr.open_zarr(url, consolidated=True, storage_options={"anon": True}) - flavour = { - "rules": { - "latitude": {"name": "lat"}, - "longitude": {"name": "lon"}, - "x": {"name": "west_east"}, - "y": {"name": "south_north"}, - "time": {"name": "time"}, - }, - } + fs = XarrayFieldList.from_xarray(ds) - fs = XarrayFieldList.from_xarray(ds, flavour=flavour) + assert_field_list( + fs, + 400, + "2011-09-01T00:00:00", + "2011-12-12T00:00:00", + ) + + +@skip_if_offline +def test_aws_s3_https() -> None: + """Test loading and validating an AWS S3 dataset via HTTPS.""" + url = "https://aodn-cloud-optimised.s3.amazonaws.com/model_sea_level_anomaly_gridded_realtime.zarr" + ds = xr.open_zarr(url, consolidated=True) + + fs = XarrayFieldList.from_xarray(ds) assert_field_list( fs, - 74634912, - "1979-10-01T00:00:00", - "2022-09-30T23:00:00", + 400, + "2011-09-01T00:00:00", + "2011-12-12T00:00:00", )
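
A note on the datetime64 workaround that appears throughout: until Zarr 3 supports datetime64 (zarr-developers/zarr-python#2616), the shim stores dates as int64 and `Zarr.dates` converts them back on read. A self-contained round-trip sketch (numpy only, illustrative values):

    import numpy as np

    dates = np.array(["2022-01-01T00:00:00", "2022-01-02T00:00:00"], dtype="datetime64[s]")
    stored = dates.astype("int64")             # what cast_dtype_datetime64() writes to a Zarr 3 store
    restored = stored.astype("datetime64[s]")  # what Zarr.dates returns (with a warning)
    assert (restored == dates).all()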
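
Likewise for the new "rename" patch in xarray_support/patch.py: patches are now applied in the fixed order coordinates, attributes, rename, so the first two must refer to the original variable names. A sketch mirroring the conus404 test configuration above:

    patch = {
        "rename": {"bottom_top_stag": "level"},
        "coordinates": ["bottom_top_stag"],  # original name: the rename runs last
        "attributes": {"lat": {"standard_name": "latitude", "long_name": "Latitude"}},
    }
    # patch_dataset(ds, patch) sorts the keys into that order before
    # dispatching each entry to PATCHES[what](ds, values).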