Skip to content

Fix indexing with bools #1968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
38578dd
test z[selection] for orthogonal selection
brokkoli71 Jun 15, 2024
7b6470f
include boolean indexing in is_pure_orthogonal_indexing
brokkoli71 Jun 15, 2024
badf818
Revert "test z[selection] for orthogonal selection"
brokkoli71 Jun 15, 2024
dd764e2
add test_indexing_equals_numpy
brokkoli71 Jun 15, 2024
26b920c
extend _test_get_mask_selection for square bracket notation
brokkoli71 Jun 15, 2024
782a712
fix is_pure_fancy_indexing for mask selection
brokkoli71 Jun 15, 2024
a94b995
add test_orthogonal_bool_indexing_like_numpy_ix
brokkoli71 Jun 15, 2024
97a06f0
fix for mypy
brokkoli71 Jun 15, 2024
85ca73f
ruff format
brokkoli71 Jun 17, 2024
46a2d4b
fix is_pure_orthogonal_indexing
brokkoli71 Jun 17, 2024
9e7b53c
fix is_pure_orthogonal_indexing
brokkoli71 Jun 17, 2024
b1a2ccf
replace deprecated ~ by not
brokkoli71 Jun 17, 2024
7849f41
restrict is_integer to not bool
brokkoli71 Jun 17, 2024
d52b9d0
correct typing
brokkoli71 Jun 19, 2024
b18697f
Merge branch 'zarr-developers:v3' into fix-indexing-with-bools
brokkoli71 Jun 19, 2024
1b27e65
correct typing
brokkoli71 Jun 19, 2024
ea6eddb
check if bool list has only bools
brokkoli71 Jun 19, 2024
31c1e7a
check if bool list has only bools
brokkoli71 Jun 19, 2024
d525d7f
fix list unpacking in test for python3.10
brokkoli71 Jun 19, 2024
ef7492b
Apply spelling suggestions from code review
brokkoli71 Jun 24, 2024
3355a02
Merge branch 'v3' into fix-indexing-with-bools
brokkoli71 Jun 24, 2024
4f0321f
fix mypy
brokkoli71 Jun 27, 2024
6812f73
Merge branch 'refs/heads/master' into fix-indexing-with-bools
brokkoli71 Jun 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/zarr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ def attrs(self) -> dict[str, JSON]:

@property
def read_only(self) -> bool:
return bool(~self.store_path.store.writeable)
return bool(not self.store_path.store.writeable)

@property
def path(self) -> str:
Expand Down
55 changes: 34 additions & 21 deletions src/zarr/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,21 +87,23 @@ def ceildiv(a: float, b: float) -> int:


def is_integer(x: Any) -> TypeGuard[int]:
"""True if x is an integer (both pure Python or NumPy).
"""True if x is an integer (both pure Python or NumPy)."""
return isinstance(x, numbers.Integral) and not is_bool(x)

Note that Python's bool is considered an integer too.
"""
return isinstance(x, numbers.Integral)

def is_bool(x: Any) -> TypeGuard[bool | np.bool_]:
"""True if x is a boolean (both pure Python or NumPy)."""
return type(x) in [bool, np.bool_]


def is_integer_list(x: Any) -> TypeGuard[list[int]]:
"""True if x is a list of integers.
"""True if x is a list of integers."""
return isinstance(x, list) and len(x) > 0 and all(is_integer(i) for i in x)

This function assumes ie *does not check* that all elements of the list
have the same type. Mixed type lists will result in other errors that will
bubble up anyway.
"""
return isinstance(x, list) and len(x) > 0 and is_integer(x[0])

def is_bool_list(x: Any) -> TypeGuard[list[bool | np.bool_]]:
"""True if x is a list of boolean."""
return isinstance(x, list) and len(x) > 0 and all(is_bool(i) for i in x)


def is_integer_array(x: Any, ndim: int | None = None) -> TypeGuard[npt.NDArray[np.intp]]:
Expand All @@ -118,6 +120,10 @@ def is_bool_array(x: Any, ndim: int | None = None) -> TypeGuard[npt.NDArray[np.b
return t


def is_int_or_bool_iterable(x: Any) -> bool:
return is_integer_list(x) or is_integer_array(x) or is_bool_array(x) or is_bool_list(x)


def is_scalar(value: Any, dtype: np.dtype[Any]) -> bool:
if np.isscalar(value):
return True
Expand All @@ -129,7 +135,7 @@ def is_scalar(value: Any, dtype: np.dtype[Any]) -> bool:


def is_pure_fancy_indexing(selection: Any, ndim: int) -> bool:
"""Check whether a selection contains only scalars or integer array-likes.
"""Check whether a selection contains only scalars or integer/bool array-likes.

Parameters
----------
Expand All @@ -142,9 +148,14 @@ def is_pure_fancy_indexing(selection: Any, ndim: int) -> bool:
True if the selection is a pure fancy indexing expression (ie not mixed
with boolean or slices).
"""
if is_bool_array(selection):
# is mask selection
return True

if ndim == 1:
if is_integer_list(selection) or is_integer_array(selection):
if is_integer_list(selection) or is_integer_array(selection) or is_bool_list(selection):
return True

# if not, we go through the normal path below, because a 1-tuple
# of integers is also allowed.
no_slicing = (
Expand All @@ -166,19 +177,21 @@ def is_pure_orthogonal_indexing(selection: Selection, ndim: int) -> TypeGuard[Or
if not ndim:
return False

# Case 1: Selection is a single iterable of integers
if is_integer_list(selection) or is_integer_array(selection, ndim=1):
selection_normalized = (selection,) if not isinstance(selection, tuple) else selection

# Case 1: Selection contains of iterable of integers or boolean
if len(selection_normalized) == ndim and all(
is_int_or_bool_iterable(s) for s in selection_normalized
):
return True

# Case two: selection contains either zero or one integer iterables.
# Case 2: selection contains either zero or one integer iterables.
# All other selection elements are slices or integers
return (
isinstance(selection, tuple)
and len(selection) == ndim
and sum(is_integer_list(elem) or is_integer_array(elem) for elem in selection) <= 1
len(selection_normalized) <= ndim
and sum(is_int_or_bool_iterable(s) for s in selection_normalized) <= 1
and all(
is_integer_list(elem) or is_integer_array(elem) or isinstance(elem, int | slice)
for elem in selection
is_int_or_bool_iterable(s) or isinstance(s, int | slice) for s in selection_normalized
)
)

Expand Down Expand Up @@ -1023,7 +1036,7 @@ def __init__(self, selection: CoordinateSelection, shape: ChunkCoords, chunk_gri
# flatten selection
selection_broadcast = tuple(dim_sel.reshape(-1) for dim_sel in selection_broadcast)
chunks_multi_index_broadcast = tuple(
dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast
[dim_chunks.reshape(-1) for dim_chunks in chunks_multi_index_broadcast]
)

# ravel chunk indices
Expand Down
55 changes: 53 additions & 2 deletions tests/v3/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def test_get_basic_selection_0d(store: StorePath, use_out: bool, value: Any, dty
slice(50, 150, 10),
]


basic_selections_1d_bad = [
# only positive step supported
slice(None, None, -1),
Expand Down Expand Up @@ -305,7 +304,6 @@ def test_get_basic_selection_1d(store: StorePath):
(Ellipsis, slice(None), slice(None)),
]


basic_selections_2d_bad = [
# bad stuff
2.3,
Expand Down Expand Up @@ -1272,6 +1270,8 @@ def _test_get_mask_selection(a, z, selection):
assert_array_equal(expect, actual)
actual = z.vindex[selection]
assert_array_equal(expect, actual)
actual = z[selection]
assert_array_equal(expect, actual)


mask_selections_1d_bad = [
Expand Down Expand Up @@ -1344,6 +1344,9 @@ def _test_set_mask_selection(v, a, z, selection):
z[:] = 0
z.vindex[selection] = v[selection]
assert_array_equal(a, z[:])
z[:] = 0
z[selection] = v[selection]
assert_array_equal(a, z[:])


def test_set_mask_selection_1d(store: StorePath):
Expand Down Expand Up @@ -1726,3 +1729,51 @@ def test_accessed_chunks(shape, chunks, ops):
) == 1
# Check that no other chunks were accessed
assert len(delta_counts) == 0


@pytest.mark.parametrize(
"selection",
[
# basic selection
[...],
[1, ...],
[slice(None)],
[1, 3],
[[1, 2, 3], 9],
[np.arange(1000)],
[slice(5, 15)],
[slice(2, 4), 4],
[[1, 3]],
# mask selection
[np.tile([True, False], (1000, 5))],
[np.full((1000, 10), False)],
# coordinate selection
[[1, 2, 3, 4], [5, 6, 7, 8]],
[[100, 200, 300], [4, 5, 6]],
],
)
def test_indexing_equals_numpy(store, selection):
a = np.arange(10000, dtype=int).reshape(1000, 10)
z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3))
# note: in python 3.10 a[*selection] is not valid unpacking syntax
expected = a[(*selection,)]
actual = z[(*selection,)]
assert_array_equal(expected, actual, err_msg=f"selection: {selection}")


@pytest.mark.parametrize(
"selection",
[
[np.tile([True, False], 500), np.tile([True, False], 5)],
[np.full(1000, False), np.tile([True, False], 5)],
[np.full(1000, True), np.full(10, True)],
[np.full(1000, True), [True, False] * 5],
],
)
def test_orthogonal_bool_indexing_like_numpy_ix(store, selection):
a = np.arange(10000, dtype=int).reshape(1000, 10)
z = zarr_array_from_numpy_array(store, a, chunk_shape=(300, 3))
expected = a[np.ix_(*selection)]
# note: in python 3.10 z[*selection] is not valid unpacking syntax
actual = z[(*selection,)]
assert_array_equal(expected, actual, err_msg=f"{selection=}")