Skip to content

Fix AsyncGroup.create_dataset() dtype handling and optimize tests #3050 #3059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/3050.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- Fixed a potential error in `AsyncGroup.create_dataset()` where the `dtype` argument could be missing when calling `create_array()`.
29 changes: 21 additions & 8 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,8 +1155,11 @@
# create_dataset in zarr 2.x requires shape but not dtype if data is
# provided. Allow this configuration by inferring dtype from data if
# necessary and passing it to create_array
if "dtype" not in kwargs and data is not None:
kwargs["dtype"] = data.dtype
if "dtype" not in kwargs:
if data is not None:
kwargs["dtype"] = data.dtype
else:
raise ValueError("dtype must be provided if data is None")

Check warning on line 1162 in src/zarr/core/group.py

View check run for this annotation

Codecov / codecov/patch

src/zarr/core/group.py#L1162

Added line #L1162 was not covered by tests
array = await self.create_array(name, shape=shape, **kwargs)
if data is not None:
await array.setitem(slice(None), data)
Expand Down Expand Up @@ -2544,12 +2547,17 @@
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_dataset`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
exact : bool, optional
If True, require `dtype` to match exactly. If False, require only that
`dtype` can be cast from the array dtype.

Returns
-------
a : Array
a : AsyncArray
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))

Expand All @@ -2562,12 +2570,17 @@
----------
name : str
Array name.
**kwargs :
See :func:`zarr.Group.create_array`.
shape : int or tuple of ints
Array shape.
dtype : str or dtype, optional
NumPy dtype.
exact : bool, optional
If True, require `dtype` to match exactly. If False, require only that
`dtype` can be cast from the array dtype.

Returns
-------
a : Array
a : AsyncArray
"""
return Array(self._sync(self._async_group.require_array(name, shape=shape, **kwargs)))

Expand Down
11 changes: 7 additions & 4 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
from hypothesis import assume, given, settings
from hypothesis import HealthCheck, assume, given, settings

from zarr.abc.store import Store
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZATTRS_JSON
Expand Down Expand Up @@ -76,6 +76,7 @@ def deep_equal(a: Any, b: Any) -> bool:
return a == b


@settings(deadline=None) # Increased from default 200ms to None
@given(data=st.data(), zarr_format=zarr_formats)
def test_array_roundtrip(data: st.DataObject, zarr_format: int) -> None:
nparray = data.draw(numpy_arrays(zarr_formats=st.just(zarr_format)))
Expand Down Expand Up @@ -117,10 +118,11 @@ def test_basic_indexing(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(data=st.data())
def test_oindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
nparray = zarray[:]

zindexer, npindexer = data.draw(orthogonal_indices(shape=nparray.shape))
Expand All @@ -138,15 +140,16 @@ def test_oindex(data: st.DataObject) -> None:
assert_array_equal(nparray, zarray[:])


@settings(deadline=None, suppress_health_check=[HealthCheck.too_slow])
@given(data=st.data())
def test_vindex(data: st.DataObject) -> None:
# integer_array_indices can't handle 0-size dimensions.
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=4, min_side=1)))
zarray = data.draw(simple_arrays(shapes=npst.array_shapes(max_dims=3, min_side=1, max_side=8)))
nparray = zarray[:]

indexer = data.draw(
npst.integer_array_indices(
shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=None)
shape=nparray.shape, result_shape=npst.array_shapes(min_side=1, max_dims=2, max_side=8)
)
)
actual = zarray.vindex[indexer]
Expand Down