-
-
Notifications
You must be signed in to change notification settings - Fork 330
groundwork for V3 group tests #1743
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e492be2
b7f66c7
c7b333a
a64b342
3d11fc0
cf34afc
16cb226
b28eaee
0215b97
e33ca6e
f237172
4605446
6e5ded9
f81b6f3
c768ca8
c405506
5b96554
d77d55f
05f89c6
b2f9bee
88f4c46
624ff77
b762fa4
5574226
bcd5c7d
d634cbf
9b9d146
d264f71
0741bed
eb8a535
ee2e233
01eec6f
acae77a
ebe1548
3dce5e3
7f82fdf
1655ff8
e8514b1
dacacc8
5d2a532
06d8b04
d8749de
a2fe4e0
eed03c8
75f75b1
8a14e3b
459bb42
8ef3fec
b5a7698
49f1505
c62e686
b1767ce
8e61ff9
b996aff
923fc18
7c6cfb4
17997f8
1140b5c
22c0889
57f564f
a821808
6803f26
d236b24
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,21 +5,19 @@ | |
import asyncio | ||
import json | ||
import logging | ||
import numpy.typing as npt | ||
|
||
if TYPE_CHECKING: | ||
from typing import ( | ||
Any, | ||
AsyncGenerator, | ||
Literal, | ||
AsyncIterator, | ||
) | ||
from typing import Any, AsyncGenerator, Literal, Iterable | ||
from zarr.abc.codec import Codec | ||
from zarr.abc.metadata import Metadata | ||
|
||
from zarr.array import AsyncArray, Array | ||
from zarr.attributes import Attributes | ||
from zarr.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, ZGROUP_JSON | ||
from zarr.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON, ZGROUP_JSON, ChunkCoords | ||
from zarr.store import StoreLike, StorePath, make_store_path | ||
from zarr.sync import SyncMixin, sync | ||
from typing import overload | ||
|
||
logger = logging.getLogger("zarr.group") | ||
|
||
|
@@ -41,6 +39,26 @@ def parse_attributes(data: Any) -> dict[str, Any]: | |
raise TypeError(msg) | ||
|
||
|
||
@overload | ||
def _parse_async_node(node: AsyncArray) -> Array: ... | ||
|
||
|
||
@overload | ||
def _parse_async_node(node: AsyncGroup) -> Group: ... | ||
|
||
|
||
def _parse_async_node(node: AsyncArray | AsyncGroup) -> Array | Group: | ||
""" | ||
Wrap an AsyncArray in an Array, or an AsyncGroup in a Group. | ||
""" | ||
if isinstance(node, AsyncArray): | ||
return Array(node) | ||
elif isinstance(node, AsyncGroup): | ||
return Group(node) | ||
else: | ||
assert False | ||
d-v-b marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
@dataclass(frozen=True) | ||
class GroupMetadata(Metadata): | ||
attributes: dict[str, Any] = field(default_factory=dict) | ||
|
@@ -53,7 +71,7 @@ def to_bytes(self) -> dict[str, bytes]: | |
return {ZARR_JSON: json.dumps(self.to_dict()).encode()} | ||
else: | ||
return { | ||
ZGROUP_JSON: json.dumps({"zarr_format": 2}).encode(), | ||
ZGROUP_JSON: json.dumps({"zarr_format": self.zarr_format}).encode(), | ||
ZATTRS_JSON: json.dumps(self.attributes).encode(), | ||
} | ||
|
||
|
@@ -113,11 +131,11 @@ async def open( | |
(store_path / ZGROUP_JSON).get(), (store_path / ZATTRS_JSON).get() | ||
) | ||
if zgroup_bytes is None: | ||
raise KeyError(store_path) # filenotfounderror? | ||
raise FileNotFoundError(store_path) | ||
elif zarr_format == 3: | ||
zarr_json_bytes = await (store_path / ZARR_JSON).get() | ||
if zarr_json_bytes is None: | ||
raise KeyError(store_path) # filenotfounderror? | ||
raise FileNotFoundError(store_path) | ||
elif zarr_format is None: | ||
zarr_json_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( | ||
(store_path / ZARR_JSON).get(), | ||
|
@@ -168,17 +186,14 @@ async def getitem( | |
key: str, | ||
) -> AsyncArray | AsyncGroup: | ||
store_path = self.store_path / key | ||
logger.warning("key=%s, store_path=%s", key, store_path) | ||
|
||
# Note: | ||
# in zarr-python v2, we first check if `key` references an Array, else if `key` references | ||
# a group,using standalone `contains_array` and `contains_group` functions. These functions | ||
# are reusable, but for v3 they would perform redundant I/O operations. | ||
# Not clear how much of that strategy we want to keep here. | ||
|
||
# if `key` names an object in storage, it cannot be an array or group | ||
if await store_path.exists(): | ||
raise KeyError(key) | ||
|
||
if self.metadata.zarr_format == 3: | ||
zarr_json_bytes = await (store_path / ZARR_JSON).get() | ||
if zarr_json_bytes is None: | ||
|
@@ -248,16 +263,42 @@ def attrs(self): | |
def info(self): | ||
return self.metadata.info | ||
|
||
async def create_group(self, path: str, **kwargs) -> AsyncGroup: | ||
async def create_group( | ||
self, path: str, exists_ok: bool = False, attributes: dict[str, Any] = {} | ||
) -> AsyncGroup: | ||
return await type(self).create( | ||
self.store_path / path, | ||
**kwargs, | ||
attributes=attributes, | ||
exists_ok=exists_ok, | ||
zarr_format=self.metadata.zarr_format, | ||
) | ||
|
||
async def create_array(self, path: str, **kwargs) -> AsyncArray: | ||
async def create_array( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We need to redo that in #1857 |
||
self, | ||
path: str, | ||
shape: ChunkCoords, | ||
dtype: npt.DTypeLike, | ||
chunk_shape: ChunkCoords, | ||
fill_value: Any | None = None, | ||
chunk_key_encoding: tuple[Literal["default"], Literal[".", "/"]] | ||
| tuple[Literal["v2"], Literal[".", "/"]] = ("default", "/"), | ||
codecs: Iterable[Codec | dict[str, Any]] | None = None, | ||
dimension_names: Iterable[str] | None = None, | ||
attributes: dict[str, Any] | None = None, | ||
exists_ok: bool = False, | ||
) -> AsyncArray: | ||
return await AsyncArray.create( | ||
self.store_path / path, | ||
**kwargs, | ||
shape=shape, | ||
dtype=dtype, | ||
chunk_shape=chunk_shape, | ||
fill_value=fill_value, | ||
chunk_key_encoding=chunk_key_encoding, | ||
codecs=codecs, | ||
dimension_names=dimension_names, | ||
attributes=attributes, | ||
exists_ok=exists_ok, | ||
zarr_format=self.metadata.zarr_format, | ||
) | ||
|
||
async def update_attributes(self, new_attributes: dict[str, Any]): | ||
|
@@ -348,7 +389,7 @@ async def array_keys(self) -> AsyncGenerator[str, None]: | |
yield key | ||
|
||
# todo: decide if this method should be separate from `array_keys` | ||
async def arrays(self) -> AsyncIterator[AsyncArray]: | ||
async def arrays(self) -> AsyncGenerator[AsyncArray, None]: | ||
async for key, value in self.members(): | ||
if isinstance(value, AsyncArray): | ||
yield value | ||
|
@@ -472,19 +513,13 @@ def nmembers(self) -> int: | |
@property | ||
def members(self) -> tuple[tuple[str, Array | Group], ...]: | ||
""" | ||
Return the sub-arrays and sub-groups of this group as a `tuple` of (name, array | group) | ||
Return the sub-arrays and sub-groups of this group as a tuple of (name, array | group) | ||
pairs | ||
""" | ||
_members: list[tuple[str, AsyncArray | AsyncGroup]] = self._sync_iter( | ||
self._async_group.members() | ||
) | ||
ret: list[tuple[str, Array | Group]] = [] | ||
for key, value in _members: | ||
if isinstance(value, AsyncArray): | ||
ret.append((key, Array(value))) | ||
else: | ||
ret.append((key, Group(value))) | ||
return tuple(ret) | ||
_members = self._sync_iter(self._async_group.members()) | ||
|
||
result = tuple(map(lambda kv: (kv[0], _parse_async_node(kv[1])), _members)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return result | ||
|
||
def __contains__(self, member) -> bool: | ||
return self._sync(self._async_group.contains(member)) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,4 +88,4 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]: | |
else: | ||
for key in self._store_dict: | ||
if key.startswith(prefix + "/") and key != prefix: | ||
yield key.strip(prefix + "/").split("/")[0] | ||
yield key.removeprefix(prefix + "/").split("/")[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

This fixed a subtle bug in the v2 implementation: `str.strip` removes any characters from the given *set*, while `str.removeprefix` removes the exact leading substring.

>>> 'a/b/a/b/b/x'.strip('a/b')
'x'

vs

>>> 'a/b/a/b/b/x'.removeprefix('a/b')
'/a/b/b/x'
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.

This exists as something that can be used with `map` over an iterable of `AsyncArray` / `AsyncGroup` to transform each one to its synchronous counterpart. Previously we used a tuple around a generator expression, but mypy couldn't type-check this correctly.