Skip to content

Commit 06f7796

Browse files
authored
Return scalar when accessing zero dimensional array (#2718)
* Return scalar when accessing zero dimensional array * returning npt.ArrayLike instead of NDArrayLike because of scalar return values * returning npt.ArrayLike instead of NDArrayLike because of scalar return values * fix mypy in tests * fix mypy in tests * fix mypy in tests * improve test_scalar_array * fix typo * add ScalarWrapper * use ScalarWrapper as NDArrayLike * Revert "fix mypy in tests" * Revert "fix mypy in tests" This reverts commit 75d6cdf. * Revert "fix mypy in tests" This reverts commit 34bf260. * format * Revert "returning npt.ArrayLike instead of NDArrayLike because of scalar return values" This reverts commit 1a290c7. * Revert "returning npt.ArrayLike instead of NDArrayLike because of scalar return values" This reverts commit 3348439 * fix mypy for ScalarWrapper * add missing import NDArrayLike * ignore unavoidable mypy error * format * fix __array__ * extend tests * format * fix typing in test_scalar_array * add dtype to ScalarWrapper * correct dtype type * fix test_basic_indexing * fix test_basic_indexing * fix test_basic_indexing for dtype=datetime64[Y] * increase codecov * fix typing * document changes * move test_scalar_wrapper to test_buffer.py * remove ScalarWrapper usage * create NDArrayOrScalarLike * fix NDArrayOrScalarLike * fix mypy * fix mypy * fix mypy * fix mypy in asynchronous.py * fix mypy in test_api.py * fix mypy in test_api.py and synchronous.py * fix mypy in test_api.py and test_array.py * fix mypy in test_array.py * fix mypy in test_array.py * fix mypy in test_array.py * fix mypy in test_array.py * fix mypy in test_array.py, test_api.py, test_buffer.py, test_sharding.py * add bytes, str and datetime to ScalarType * only support numpy datetime64 in ScalarType * remove ScalarWrapper and update changes * undo wrong code changes * rename ``NDArrayOrScalarLike`` to ``NDArrayLikeOrScalar`` * rename ``NDArrayOrScalarLike`` to ``NDArrayLikeOrScalar`` * fix mypy in test_array.py * fix mypy in test_array.py * handle datetype scalars for different units * fix mypy * fix mypy * format
1 parent 50abf3d commit 06f7796

File tree

12 files changed

+142
-67
lines changed

12 files changed

+142
-67
lines changed

changes/2718.bugfix.rst

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
0-dimensional arrays are now returning a scalar. Therefore, the return type of ``__getitem__`` changed
2+
to NDArrayLikeOrScalar. This change is to make the behavior of 0-dimensional arrays consistent with
3+
``numpy`` scalars.

src/zarr/api/asynchronous.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from collections.abc import Iterable
3939

4040
from zarr.abc.codec import Codec
41+
from zarr.core.buffer import NDArrayLikeOrScalar
4142
from zarr.core.chunk_key_encodings import ChunkKeyEncoding
4243
from zarr.storage import StoreLike
4344

@@ -238,7 +239,7 @@ async def load(
238239
path: str | None = None,
239240
zarr_format: ZarrFormat | None = None,
240241
zarr_version: ZarrFormat | None = None,
241-
) -> NDArrayLike | dict[str, NDArrayLike]:
242+
) -> NDArrayLikeOrScalar | dict[str, NDArrayLikeOrScalar]:
242243
"""Load data from an array or group into memory.
243244
244245
Parameters

src/zarr/api/synchronous.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
ShardsLike,
2828
)
2929
from zarr.core.array_spec import ArrayConfigLike
30-
from zarr.core.buffer import NDArrayLike
30+
from zarr.core.buffer import NDArrayLike, NDArrayLikeOrScalar
3131
from zarr.core.chunk_key_encodings import ChunkKeyEncoding, ChunkKeyEncodingLike
3232
from zarr.core.common import (
3333
JSON,
@@ -121,7 +121,7 @@ def load(
121121
path: str | None = None,
122122
zarr_format: ZarrFormat | None = None,
123123
zarr_version: ZarrFormat | None = None,
124-
) -> NDArrayLike | dict[str, NDArrayLike]:
124+
) -> NDArrayLikeOrScalar | dict[str, NDArrayLikeOrScalar]:
125125
"""Load data from an array or group into memory.
126126
127127
Parameters

src/zarr/core/array.py

+26-22
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from zarr.core.buffer import (
3636
BufferPrototype,
3737
NDArrayLike,
38+
NDArrayLikeOrScalar,
3839
NDBuffer,
3940
default_buffer_prototype,
4041
)
@@ -1256,7 +1257,7 @@ async def _get_selection(
12561257
prototype: BufferPrototype,
12571258
out: NDBuffer | None = None,
12581259
fields: Fields | None = None,
1259-
) -> NDArrayLike:
1260+
) -> NDArrayLikeOrScalar:
12601261
# check fields are sensible
12611262
out_dtype = check_fields(fields, self.dtype)
12621263

@@ -1298,14 +1299,16 @@ async def _get_selection(
12981299
out_buffer,
12991300
drop_axes=indexer.drop_axes,
13001301
)
1302+
if isinstance(indexer, BasicIndexer) and indexer.shape == ():
1303+
return out_buffer.as_scalar()
13011304
return out_buffer.as_ndarray_like()
13021305

13031306
async def getitem(
13041307
self,
13051308
selection: BasicSelection,
13061309
*,
13071310
prototype: BufferPrototype | None = None,
1308-
) -> NDArrayLike:
1311+
) -> NDArrayLikeOrScalar:
13091312
"""
13101313
Asynchronous function that retrieves a subset of the array's data based on the provided selection.
13111314
@@ -1318,7 +1321,7 @@ async def getitem(
13181321
13191322
Returns
13201323
-------
1321-
NDArrayLike
1324+
NDArrayLikeOrScalar
13221325
The retrieved subset of the array's data.
13231326
13241327
Examples
@@ -2268,14 +2271,15 @@ def __array__(
22682271
msg = "`copy=False` is not supported. This method always creates a copy."
22692272
raise ValueError(msg)
22702273

2271-
arr_np = self[...]
2274+
arr = self[...]
2275+
arr_np: NDArrayLike = np.array(arr, dtype=dtype)
22722276

22732277
if dtype is not None:
22742278
arr_np = arr_np.astype(dtype)
22752279

22762280
return arr_np
22772281

2278-
def __getitem__(self, selection: Selection) -> NDArrayLike:
2282+
def __getitem__(self, selection: Selection) -> NDArrayLikeOrScalar:
22792283
"""Retrieve data for an item or region of the array.
22802284
22812285
Parameters
@@ -2286,8 +2290,8 @@ def __getitem__(self, selection: Selection) -> NDArrayLike:
22862290
22872291
Returns
22882292
-------
2289-
NDArrayLike
2290-
An array-like containing the data for the requested region.
2293+
NDArrayLikeOrScalar
2294+
An array-like or scalar containing the data for the requested region.
22912295
22922296
Examples
22932297
--------
@@ -2533,7 +2537,7 @@ def get_basic_selection(
25332537
out: NDBuffer | None = None,
25342538
prototype: BufferPrototype | None = None,
25352539
fields: Fields | None = None,
2536-
) -> NDArrayLike:
2540+
) -> NDArrayLikeOrScalar:
25372541
"""Retrieve data for an item or region of the array.
25382542
25392543
Parameters
@@ -2551,8 +2555,8 @@ def get_basic_selection(
25512555
25522556
Returns
25532557
-------
2554-
NDArrayLike
2555-
An array-like containing the data for the requested region.
2558+
NDArrayLikeOrScalar
2559+
An array-like or scalar containing the data for the requested region.
25562560
25572561
Examples
25582562
--------
@@ -2753,7 +2757,7 @@ def get_orthogonal_selection(
27532757
out: NDBuffer | None = None,
27542758
fields: Fields | None = None,
27552759
prototype: BufferPrototype | None = None,
2756-
) -> NDArrayLike:
2760+
) -> NDArrayLikeOrScalar:
27572761
"""Retrieve data by making a selection for each dimension of the array. For
27582762
example, if an array has 2 dimensions, allows selecting specific rows and/or
27592763
columns. The selection for each dimension can be either an integer (indexing a
@@ -2775,8 +2779,8 @@ def get_orthogonal_selection(
27752779
27762780
Returns
27772781
-------
2778-
NDArrayLike
2779-
An array-like containing the data for the requested selection.
2782+
NDArrayLikeOrScalar
2783+
An array-like or scalar containing the data for the requested selection.
27802784
27812785
Examples
27822786
--------
@@ -2989,7 +2993,7 @@ def get_mask_selection(
29892993
out: NDBuffer | None = None,
29902994
fields: Fields | None = None,
29912995
prototype: BufferPrototype | None = None,
2992-
) -> NDArrayLike:
2996+
) -> NDArrayLikeOrScalar:
29932997
"""Retrieve a selection of individual items, by providing a Boolean array of the
29942998
same shape as the array against which the selection is being made, where True
29952999
values indicate a selected item.
@@ -3009,8 +3013,8 @@ def get_mask_selection(
30093013
30103014
Returns
30113015
-------
3012-
NDArrayLike
3013-
An array-like containing the data for the requested selection.
3016+
NDArrayLikeOrScalar
3017+
An array-like or scalar containing the data for the requested selection.
30143018
30153019
Examples
30163020
--------
@@ -3151,7 +3155,7 @@ def get_coordinate_selection(
31513155
out: NDBuffer | None = None,
31523156
fields: Fields | None = None,
31533157
prototype: BufferPrototype | None = None,
3154-
) -> NDArrayLike:
3158+
) -> NDArrayLikeOrScalar:
31553159
"""Retrieve a selection of individual items, by providing the indices
31563160
(coordinates) for each selected item.
31573161
@@ -3169,8 +3173,8 @@ def get_coordinate_selection(
31693173
31703174
Returns
31713175
-------
3172-
NDArrayLike
3173-
An array-like containing the data for the requested coordinate selection.
3176+
NDArrayLikeOrScalar
3177+
An array-like or scalar containing the data for the requested coordinate selection.
31743178
31753179
Examples
31763180
--------
@@ -3339,7 +3343,7 @@ def get_block_selection(
33393343
out: NDBuffer | None = None,
33403344
fields: Fields | None = None,
33413345
prototype: BufferPrototype | None = None,
3342-
) -> NDArrayLike:
3346+
) -> NDArrayLikeOrScalar:
33433347
"""Retrieve a selection of individual items, by providing the indices
33443348
(coordinates) for each selected item.
33453349
@@ -3357,8 +3361,8 @@ def get_block_selection(
33573361
33583362
Returns
33593363
-------
3360-
NDArrayLike
3361-
An array-like containing the data for the requested block selection.
3364+
NDArrayLikeOrScalar
3365+
An array-like or scalar containing the data for the requested block selection.
33623366
33633367
Examples
33643368
--------

src/zarr/core/buffer/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Buffer,
44
BufferPrototype,
55
NDArrayLike,
6+
NDArrayLikeOrScalar,
67
NDBuffer,
78
default_buffer_prototype,
89
)
@@ -13,6 +14,7 @@
1314
"Buffer",
1415
"BufferPrototype",
1516
"NDArrayLike",
17+
"NDArrayLikeOrScalar",
1618
"NDBuffer",
1719
"default_buffer_prototype",
1820
"numpy_buffer_prototype",

src/zarr/core/buffer/core.py

+19
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ def __eq__(self, other: object) -> Self: # type: ignore[explicit-override, over
105105
"""
106106

107107

108+
ScalarType = int | float | complex | bytes | str | bool | np.generic
109+
NDArrayLikeOrScalar = ScalarType | NDArrayLike
110+
111+
108112
def check_item_key_is_1d_contiguous(key: Any) -> None:
109113
"""Raises error if `key` isn't a 1d contiguous slice"""
110114
if not isinstance(key, slice):
@@ -419,6 +423,21 @@ def as_numpy_array(self) -> npt.NDArray[Any]:
419423
"""
420424
...
421425

426+
def as_scalar(self) -> ScalarType:
427+
"""Returns the buffer as a scalar value"""
428+
if self._data.size != 1:
429+
raise ValueError("Buffer does not contain a single scalar value")
430+
item = self.as_numpy_array().item()
431+
scalar: ScalarType
432+
433+
if np.issubdtype(self.dtype, np.datetime64):
434+
unit: str = np.datetime_data(self.dtype)[0] # Extract the unit (e.g., 'Y', 'D', etc.)
435+
scalar = np.datetime64(item, unit)
436+
else:
437+
scalar = self.dtype.type(item) # Regular conversion for non-datetime types
438+
439+
return scalar
440+
422441
@property
423442
def dtype(self) -> np.dtype[Any]:
424443
return self._data.dtype

src/zarr/core/indexing.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
if TYPE_CHECKING:
3131
from zarr.core.array import Array
32-
from zarr.core.buffer import NDArrayLike
32+
from zarr.core.buffer import NDArrayLikeOrScalar
3333
from zarr.core.chunk_grids import ChunkGrid
3434
from zarr.core.common import ChunkCoords
3535

@@ -937,7 +937,7 @@ class OIndex:
937937
array: Array
938938

939939
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
940-
def __getitem__(self, selection: OrthogonalSelection | Array) -> NDArrayLike:
940+
def __getitem__(self, selection: OrthogonalSelection | Array) -> NDArrayLikeOrScalar:
941941
from zarr.core.array import Array
942942

943943
# if input is a Zarr array, we materialize it now.
@@ -1046,7 +1046,7 @@ def __iter__(self) -> Iterator[ChunkProjection]:
10461046
class BlockIndex:
10471047
array: Array
10481048

1049-
def __getitem__(self, selection: BasicSelection) -> NDArrayLike:
1049+
def __getitem__(self, selection: BasicSelection) -> NDArrayLikeOrScalar:
10501050
fields, new_selection = pop_fields(selection)
10511051
new_selection = ensure_tuple(new_selection)
10521052
new_selection = replace_lists(new_selection)
@@ -1236,7 +1236,9 @@ class VIndex:
12361236
array: Array
12371237

12381238
# TODO: develop Array generic and move zarr.Array[np.intp] | zarr.Array[np.bool_] to ArrayOfIntOrBool
1239-
def __getitem__(self, selection: CoordinateSelection | MaskSelection | Array) -> NDArrayLike:
1239+
def __getitem__(
1240+
self, selection: CoordinateSelection | MaskSelection | Array
1241+
) -> NDArrayLikeOrScalar:
12401242
from zarr.core.array import Array
12411243

12421244
# if input is a Zarr array, we materialize it now.

tests/test_api.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
save_array,
3333
save_group,
3434
)
35+
from zarr.core.buffer import NDArrayLike
3536
from zarr.errors import MetadataValidationError
3637
from zarr.storage import MemoryStore
3738
from zarr.storage._utils import normalize_path
@@ -244,7 +245,9 @@ def test_open_with_mode_r(tmp_path: pathlib.Path) -> None:
244245
z2 = zarr.open(store=tmp_path, mode="r")
245246
assert isinstance(z2, Array)
246247
assert z2.fill_value == 1
247-
assert (z2[:] == 1).all()
248+
result = z2[:]
249+
assert isinstance(result, NDArrayLike)
250+
assert (result == 1).all()
248251
with pytest.raises(ValueError):
249252
z2[:] = 3
250253

@@ -256,7 +259,9 @@ def test_open_with_mode_r_plus(tmp_path: pathlib.Path) -> None:
256259
zarr.ones(store=tmp_path, shape=(3, 3))
257260
z2 = zarr.open(store=tmp_path, mode="r+")
258261
assert isinstance(z2, Array)
259-
assert (z2[:] == 1).all()
262+
result = z2[:]
263+
assert isinstance(result, NDArrayLike)
264+
assert (result == 1).all()
260265
z2[:] = 3
261266

262267

@@ -272,7 +277,9 @@ async def test_open_with_mode_a(tmp_path: pathlib.Path) -> None:
272277
arr[...] = 1
273278
z2 = zarr.open(store=tmp_path, mode="a")
274279
assert isinstance(z2, Array)
275-
assert (z2[:] == 1).all()
280+
result = z2[:]
281+
assert isinstance(result, NDArrayLike)
282+
assert (result == 1).all()
276283
z2[:] = 3
277284

278285

@@ -284,7 +291,9 @@ def test_open_with_mode_w(tmp_path: pathlib.Path) -> None:
284291
arr[...] = 3
285292
z2 = zarr.open(store=tmp_path, mode="w", shape=(3, 3))
286293
assert isinstance(z2, Array)
287-
assert not (z2[:] == 3).all()
294+
result = z2[:]
295+
assert isinstance(result, NDArrayLike)
296+
assert not (result == 3).all()
288297
z2[:] = 3
289298

290299

@@ -1134,7 +1143,9 @@ def test_open_array_with_mode_r_plus(store: Store) -> None:
11341143
zarr.ones(store=store, shape=(3, 3))
11351144
z2 = zarr.open_array(store=store, mode="r+")
11361145
assert isinstance(z2, Array)
1137-
assert (z2[:] == 1).all()
1146+
result = z2[:]
1147+
assert isinstance(result, NDArrayLike)
1148+
assert (result == 1).all()
11381149
z2[:] = 3
11391150

11401151

0 commit comments

Comments
 (0)