|
3 | 3 | import warnings
|
4 | 4 | from typing import TYPE_CHECKING, Any
|
5 | 5 |
|
6 |
| -from zarr.abc.store import ByteRangeRequest, Store |
| 6 | +from zarr.abc.store import ( |
| 7 | + ByteRequest, |
| 8 | + OffsetByteRequest, |
| 9 | + RangeByteRequest, |
| 10 | + Store, |
| 11 | + SuffixByteRequest, |
| 12 | +) |
7 | 13 | from zarr.storage._common import _dereference_path
|
8 | 14 |
|
9 | 15 | if TYPE_CHECKING:
|
@@ -199,31 +205,34 @@ async def get(
|
199 | 205 | self,
|
200 | 206 | key: str,
|
201 | 207 | prototype: BufferPrototype,
|
202 |
| - byte_range: ByteRangeRequest | None = None, |
| 208 | + byte_range: ByteRequest | None = None, |
203 | 209 | ) -> Buffer | None:
|
204 | 210 | # docstring inherited
|
205 | 211 | if not self._is_open:
|
206 | 212 | await self._open()
|
207 | 213 | path = _dereference_path(self.path, key)
|
208 | 214 |
|
209 | 215 | try:
|
210 |
| - if byte_range: |
211 |
| - # fsspec uses start/end, not start/length |
212 |
| - start, length = byte_range |
213 |
| - if start is not None and length is not None: |
214 |
| - end = start + length |
215 |
| - elif length is not None: |
216 |
| - end = length |
217 |
| - else: |
218 |
| - end = None |
219 |
| - value = prototype.buffer.from_bytes( |
220 |
| - await ( |
221 |
| - self.fs._cat_file(path, start=byte_range[0], end=end) |
222 |
| - if byte_range |
223 |
| - else self.fs._cat_file(path) |
| 216 | + if byte_range is None: |
| 217 | + value = prototype.buffer.from_bytes(await self.fs._cat_file(path)) |
| 218 | + elif isinstance(byte_range, RangeByteRequest): |
| 219 | + value = prototype.buffer.from_bytes( |
| 220 | + await self.fs._cat_file( |
| 221 | + path, |
| 222 | + start=byte_range.start, |
| 223 | + end=byte_range.end, |
| 224 | + ) |
224 | 225 | )
|
225 |
| - ) |
226 |
| - |
| 226 | + elif isinstance(byte_range, OffsetByteRequest): |
| 227 | + value = prototype.buffer.from_bytes( |
| 228 | + await self.fs._cat_file(path, start=byte_range.offset, end=None) |
| 229 | + ) |
| 230 | + elif isinstance(byte_range, SuffixByteRequest): |
| 231 | + value = prototype.buffer.from_bytes( |
| 232 | + await self.fs._cat_file(path, start=-byte_range.suffix, end=None) |
| 233 | + ) |
| 234 | + else: |
| 235 | + raise ValueError(f"Unexpected byte_range, got {byte_range}.") |
227 | 236 | except self.allowed_exceptions:
|
228 | 237 | return None
|
229 | 238 | except OSError as e:
|
@@ -270,25 +279,35 @@ async def exists(self, key: str) -> bool:
|
270 | 279 | async def get_partial_values(
|
271 | 280 | self,
|
272 | 281 | prototype: BufferPrototype,
|
273 |
| - key_ranges: Iterable[tuple[str, ByteRangeRequest]], |
| 282 | + key_ranges: Iterable[tuple[str, ByteRequest | None]], |
274 | 283 | ) -> list[Buffer | None]:
|
275 | 284 | # docstring inherited
|
276 | 285 | if key_ranges:
|
277 |
| - paths, starts, stops = zip( |
278 |
| - *( |
279 |
| - ( |
280 |
| - _dereference_path(self.path, k[0]), |
281 |
| - k[1][0], |
282 |
| - ((k[1][0] or 0) + k[1][1]) if k[1][1] is not None else None, |
283 |
| - ) |
284 |
| - for k in key_ranges |
285 |
| - ), |
286 |
| - strict=False, |
287 |
| - ) |
| 286 | + # _cat_ranges expects a list of paths, start, and end ranges, so we need to reformat each ByteRequest. |
| 287 | + key_ranges = list(key_ranges) |
| 288 | + paths: list[str] = [] |
| 289 | + starts: list[int | None] = [] |
| 290 | + stops: list[int | None] = [] |
| 291 | + for key, byte_range in key_ranges: |
| 292 | + paths.append(_dereference_path(self.path, key)) |
| 293 | + if byte_range is None: |
| 294 | + starts.append(None) |
| 295 | + stops.append(None) |
| 296 | + elif isinstance(byte_range, RangeByteRequest): |
| 297 | + starts.append(byte_range.start) |
| 298 | + stops.append(byte_range.end) |
| 299 | + elif isinstance(byte_range, OffsetByteRequest): |
| 300 | + starts.append(byte_range.offset) |
| 301 | + stops.append(None) |
| 302 | + elif isinstance(byte_range, SuffixByteRequest): |
| 303 | + starts.append(-byte_range.suffix) |
| 304 | + stops.append(None) |
| 305 | + else: |
| 306 | + raise ValueError(f"Unexpected byte_range, got {byte_range}.") |
288 | 307 | else:
|
289 | 308 | return []
|
290 | 309 | # TODO: expectations for exceptions or missing keys?
|
291 |
| - res = await self.fs._cat_ranges(list(paths), starts, stops, on_error="return") |
| 310 | + res = await self.fs._cat_ranges(paths, starts, stops, on_error="return") |
292 | 311 | # the following is an s3-specific condition we probably don't want to leak
|
293 | 312 | res = [b"" if (isinstance(r, OSError) and "not satisfiable" in str(r)) else r for r in res]
|
294 | 313 | for r in res:
|
|
0 commit comments