Improve write performance of shards #2977

Draft: wants to merge 6 commits into main
66 changes: 66 additions & 0 deletions bench/write_shard.py
@@ -0,0 +1,66 @@
import itertools
import os.path
import shutil
import sys
import tempfile
import timeit

import line_profiler
import numpy as np

import zarr
import zarr.codecs
import zarr.codecs.sharding

if __name__ == "__main__":
    sys.path.insert(0, "..")

    # setup
    with tempfile.TemporaryDirectory() as path:

        ndim = 3
        opt = {
            'shape': [1024]*ndim,
            'chunks': [128]*ndim,
            'shards': [512]*ndim,
            'dtype': np.float64,
        }

        store = zarr.storage.LocalStore(path)
        z = zarr.create_array(store, **opt)
        print(z)

        def cleanup() -> None:
            for elem in os.listdir(path):
                elem = os.path.join(path, elem)
                if not elem.endswith(".json"):
                    if os.path.isdir(elem):
                        shutil.rmtree(elem)
                    else:
                        os.remove(elem)

        def write() -> None:
            wchunk = [512]*ndim
            nwchunks = [n//s for n, s in zip(opt['shape'], wchunk, strict=True)]
            for shard in itertools.product(*(range(n) for n in nwchunks)):
                slicer = tuple(
                    slice(i*n, (i+1)*n)
                    for i, n in zip(shard, wchunk, strict=True)
                )
                d = np.random.rand(*wchunk).astype(opt['dtype'])
                z[slicer] = d

        print("*" * 79)

        # time
        vars = {"write": write, "cleanup": cleanup, "z": z, "opt": opt}
        t = timeit.repeat("write()", "cleanup()", repeat=2, number=1, globals=vars)
        print(t)
        print(min(t))
        print(z)

        # profile
        # f = zarr.codecs.sharding.ShardingCodec._encode_partial_single
        # profile = line_profiler.LineProfiler(f)
        # profile.run("write()")
        # profile.print_stats()
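
As a side note, the same harness can be adapted to compare different write-chunk sizes, which is where the shard write path matters most. A rough sketch, not part of the diff; it reuses the `z`, `opt`, `ndim`, and `cleanup` names defined above, and the sizes are illustrative:

    # sketch: compare write times for chunk-sized vs. shard-sized writes
    def make_write(wchunk: list[int]):
        def write() -> None:
            nwchunks = [n // s for n, s in zip(opt['shape'], wchunk, strict=True)]
            for shard in itertools.product(*(range(n) for n in nwchunks)):
                slicer = tuple(
                    slice(i * n, (i + 1) * n)
                    for i, n in zip(shard, wchunk, strict=True)
                )
                z[slicer] = np.random.rand(*wchunk).astype(opt['dtype'])
        return write

    for size in (128, 256, 512):
        t = timeit.repeat(
            "write()", "cleanup()", repeat=2, number=1,
            globals={"write": make_write([size] * ndim), "cleanup": cleanup},
        )
        print(size, min(t))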
4 changes: 2 additions & 2 deletions src/zarr/codecs/sharding.py
@@ -202,7 +202,7 @@ def create_empty(
             buffer_prototype = default_buffer_prototype()
         index = _ShardIndex.create_empty(chunks_per_shard)
         obj = cls()
-        obj.buf = buffer_prototype.buffer.create_zero_length()
+        obj.buf = buffer_prototype.buffer.Delayed.create_zero_length()
         obj.index = index
         return obj

@@ -251,7 +251,7 @@ def create_empty(
         if buffer_prototype is None:
             buffer_prototype = default_buffer_prototype()
         obj = cls()
-        obj.buf = buffer_prototype.buffer.create_zero_length()
+        obj.buf = buffer_prototype.buffer.Delayed.create_zero_length()
         obj.index = _ShardIndex.create_empty(chunks_per_shard)
         return obj
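
The intent of this substitution, as I read it (the diff itself does not explain it): an empty shard buffer that is later grown by repeated `+` concatenation copies the accumulated bytes on every append, whereas the `Delayed` variant only collects references and concatenates once when the bytes are materialized. A minimal standalone sketch of that difference, with no zarr involved:

    # sketch: eager vs. deferred accumulation of chunk bytes
    import numpy as np

    chunks = [np.random.bytes(64 * 1024) for _ in range(256)]

    # eager: each += copies everything accumulated so far, O(n^2) bytes moved in total
    eager = b""
    for c in chunks:
        eager += c

    # deferred: keep references and join once at the end, O(n) bytes moved
    parts: list[bytes] = []
    for c in chunks:
        parts.append(c)
    deferred = b"".join(parts)

    assert eager == deferred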

147 changes: 147 additions & 0 deletions src/zarr/core/buffer/core.py
@@ -502,6 +502,153 @@
    nd_buffer: type[NDBuffer]


class DelayedBuffer(Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.
    """

    _BufferImpl: type
    _concatenate: callable

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        if array is None:
            self._data_list = []
        elif isinstance(array, list):
            self._data_list = list(array)
        else:
            self._data_list = [array]
        for array in self._data_list:
            if array.ndim != 1:
                raise ValueError("array: only 1-dim allowed")
            if array.dtype != np.dtype("b"):
                raise ValueError("array: only byte dtype allowed")

    @property
    def _data(self) -> npt.NDArray[Any]:
        return type(self)._concatenate(self._data_list)

    @classmethod
    def from_buffer(cls, buffer: Buffer) -> Self:
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    def __add__(self, other: Buffer) -> Self:
        if isinstance(other, self.__class__):
            return self.__class__(self._data_list + other._data_list)
        else:
            return self.__class__(self._data_list + [other._data])

    def __radd__(self, other: Buffer) -> Self:
        if isinstance(other, self.__class__):
            return self.__class__(other._data_list + self._data_list)
        else:
            return self.__class__([other._data] + self._data_list)

    def __len__(self) -> int:
        return sum(map(len, self._data_list))

    def __getitem__(self, key: slice) -> Self:
        check_item_key_is_1d_contiguous(key)
        start, stop = key.start, key.stop
        this_len = len(self)
        if start is None:
            start = 0
        if start < 0:
            start = this_len + start
        if stop is None:
            stop = this_len
        if stop < 0:
            stop = this_len + stop
        if stop > this_len:
            stop = this_len
        if stop <= start:
            # empty slice: return an empty delayed buffer of the same type
            return self.__class__(None)

        new_list = []
        offset = 0
        found_last = False
        for chunk in self._data_list:
            chunk_size = len(chunk)
            skip = False
            if 0 <= start - offset < chunk_size:
                # first chunk
                if stop - offset <= chunk_size:
                    # also last chunk
                    chunk = chunk[start-offset:stop-offset]
                    found_last = True
                else:
                    chunk = chunk[start-offset:]
            elif 0 <= stop - offset <= chunk_size:
                # last chunk
                chunk = chunk[:stop-offset]
                found_last = True
            elif chunk_size <= start - offset:
                # before first chunk
                skip = True
            else:
                # middle chunk
                pass

            if not skip:
                new_list.append(chunk)
            if found_last:
                break
            offset += chunk_size
        assert sum(map(len, new_list)) == stop - start
        return self.__class__(new_list)

    def __setitem__(self, key: slice, value: Any) -> None:
        # This assumes that `value` is a broadcasted array
        check_item_key_is_1d_contiguous(key)
        start, stop = key.start, key.stop
        if start is None:
            start = 0
        if start < 0:
            start = len(self) + start
        if stop is None:
            stop = len(self)
        if stop < 0:
            stop = len(self) + stop
        if stop <= start:
            return

        offset = 0
        found_last = False
        value = memoryview(np.asanyarray(value))
        for chunk in self._data_list:
            chunk_size = len(chunk)
            skip = False
            if 0 <= start - offset < chunk_size:
                # first chunk
                if stop - offset <= chunk_size:
                    # also last chunk
                    chunk = chunk[start-offset:stop-offset]
                    found_last = True
                else:
                    chunk = chunk[start-offset:]
            elif 0 <= stop - offset <= chunk_size:
                # last chunk
                chunk = chunk[:stop-offset]
                found_last = True
            elif chunk_size <= start - offset:
                # before first chunk
                skip = True
            else:
                # middle chunk
                pass

            if not skip:
                chunk[:] = value[:len(chunk)]
                value = value[len(chunk):]
                if len(value) == 0:
                    # nothing left to write
                    break
            if found_last:
                break
            offset += chunk_size
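
For intuition, the behaviour the class above implements can be mimicked with a few lines of plain NumPy. This is a standalone sketch with hypothetical names, not zarr API; it mirrors `__add__`, `__len__`, and the deferred `_data` concatenation:

    # sketch: virtual concatenation with a single deferred copy (hypothetical names)
    import numpy as np

    class LazyConcat:
        def __init__(self, parts: list) -> None:
            self.parts = list(parts)          # store references only, no copying

        def __add__(self, other: "LazyConcat") -> "LazyConcat":
            return LazyConcat(self.parts + other.parts)   # O(1) bookkeeping per append

        def __len__(self) -> int:
            return sum(len(p) for p in self.parts)

        def materialize(self) -> np.ndarray:
            return np.concatenate(self.parts)  # the only copy, performed once

    a = LazyConcat([np.frombuffer(b"abc", dtype="b")])
    b = LazyConcat([np.frombuffer(b"defg", dtype="b")])
    c = a + b
    assert len(c) == 7
    assert c.materialize().tobytes() == b"abcdefg"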


# The default buffer prototype used throughout the Zarr codebase.
def default_buffer_prototype() -> BufferPrototype:
    from zarr.registry import (
33 changes: 33 additions & 0 deletions src/zarr/core/buffer/cpu.py
@@ -185,6 +185,39 @@
        self._data.__setitem__(key, value)


class DelayedBuffer(core.DelayedBuffer, Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.
    """

    _BufferImpl = Buffer
    _concatenate = np.concatenate

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        core.DelayedBuffer.__init__(self, array)
        self._data_list = list(map(np.asanyarray, self._data_list))

    @classmethod
    def create_zero_length(cls) -> Self:
        return cls(np.array([], dtype="b"))

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    @classmethod
    def from_bytes(cls, bytes_like: BytesLike) -> Self:
        return cls(np.asarray(bytes_like, dtype="b"))

    def as_numpy_array(self) -> npt.NDArray[Any]:
        return np.asanyarray(self._data)


Buffer.Delayed = DelayedBuffer
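
With `Delayed` attached to the CPU `Buffer` class, code that only holds a `BufferPrototype` (such as the sharding codec above) can reach the lazy variant as `prototype.buffer.Delayed`. A rough usage sketch, assuming this branch is installed and the default prototype is the CPU one; the variable names are illustrative:

    # sketch: growing a shard-like byte buffer without intermediate copies
    import numpy as np
    from zarr.core.buffer import default_buffer_prototype

    proto = default_buffer_prototype()
    body = proto.buffer.Delayed.create_zero_length()      # empty delayed buffer

    for _ in range(4):
        chunk_bytes = proto.buffer.from_bytes(np.random.bytes(1024))
        body = body + chunk_bytes                          # records a reference, no copy yet

    data = body.as_numpy_array()                           # single concatenation happens here
    assert len(data) == 4 * 1024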


def as_numpy_array_wrapper(
    func: Callable[[npt.NDArray[Any]], bytes], buf: core.Buffer, prototype: core.BufferPrototype
) -> core.Buffer:
33 changes: 33 additions & 0 deletions src/zarr/core/buffer/gpu.py
@@ -218,6 +218,39 @@
        self._data.__setitem__(key, value)


class DelayedBuffer(core.DelayedBuffer, Buffer):
    """
    A Buffer that is the virtual concatenation of other buffers.
    """

    _BufferImpl = Buffer
    _concatenate = getattr(cp, 'concatenate', None)

    def __init__(self, array: NDArrayLike | list[NDArrayLike] | None) -> None:
        core.DelayedBuffer.__init__(self, array)
        self._data_list = list(map(cp.asarray, self._data_list))

    @classmethod
    def create_zero_length(cls) -> Self:
        return cls(np.array([], dtype="b"))

    @classmethod
    def from_buffer(cls, buffer: core.Buffer) -> Self:
        if isinstance(buffer, cls):
            return cls(buffer._data_list)
        else:
            return cls(buffer._data)

    @classmethod
    def from_bytes(cls, bytes_like: BytesLike) -> Self:
        return cls(np.asarray(bytes_like, dtype="b"))

    def as_numpy_array(self) -> npt.NDArray[Any]:
        return np.asanyarray(self._data)


Buffer.Delayed = DelayedBuffer


buffer_prototype = BufferPrototype(buffer=Buffer, nd_buffer=NDBuffer)

register_buffer(Buffer)