Skip to content

Inplace shapes #91

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions array_api_tests/meta/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from hypothesis import strategies as st

from .. import _array_module as xp
from .. import xps
from .. import dtype_helpers as dh
from .. import shape_helpers as sh
from .. import xps
from ..test_creation_functions import frange
from ..test_manipulation_functions import roll_ndindex
from ..test_operators_and_elementwise_functions import mock_int_dtype
from ..test_operators_and_elementwise_functions import (
mock_int_dtype,
oneway_broadcastable_shapes,
oneway_promotable_dtypes,
)
from ..test_signatures import extension_module


Expand Down Expand Up @@ -115,3 +120,13 @@ def test_int_to_dtype(x, dtype):
except OverflowError:
reject()
assert mock_int_dtype(x, dtype) == d


@given(oneway_promotable_dtypes(dh.all_dtypes))
def test_oneway_promotable_dtypes(D):
assert D.result_dtype == dh.result_type(*D)


@given(oneway_broadcastable_shapes())
def test_oneway_broadcastable_shapes(S):
assert S.result_shape == sh.broadcast_shapes(*S)
111 changes: 78 additions & 33 deletions array_api_tests/test_operators_and_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,46 @@ def boolean_and_all_integer_dtypes() -> st.SearchStrategy[DataType]:
return xps.boolean_dtypes() | all_integer_dtypes()


class OnewayPromotableDtypes(NamedTuple):
input_dtype: DataType
result_dtype: DataType


@st.composite
def oneway_promotable_dtypes(
draw, dtypes: List[DataType]
) -> st.SearchStrategy[OnewayPromotableDtypes]:
"""Return a strategy for input dtypes that promote to result dtypes."""
d1, d2 = draw(hh.mutually_promotable_dtypes(dtypes=dtypes))
result_dtype = dh.result_type(d1, d2)
if d1 == result_dtype:
return OnewayPromotableDtypes(d2, d1)
elif d2 == result_dtype:
return OnewayPromotableDtypes(d1, d2)
else:
reject()


class OnewayBroadcastableShapes(NamedTuple):
input_shape: Shape
result_shape: Shape


@st.composite
def oneway_broadcastable_shapes(draw) -> st.SearchStrategy[OnewayBroadcastableShapes]:
"""Return a strategy for input shapes that broadcast to result shapes."""
result_shape = draw(hh.shapes(min_side=1))
input_shape = draw(
xps.broadcastable_shapes(
result_shape,
# Override defaults so bad shapes are less likely to be generated.
max_side=None if result_shape == () else max(result_shape),
max_dims=len(result_shape),
).filter(lambda s: sh.broadcast_shapes(result_shape, s) == result_shape)
)
return OnewayBroadcastableShapes(input_shape, result_shape)


def mock_int_dtype(n: int, dtype: DataType) -> int:
"""Returns equivalent of `n` that mocks `dtype` behaviour."""
nbits = dh.dtype_nbits[dtype]
Expand Down Expand Up @@ -306,8 +346,14 @@ def __repr__(self):


def make_binary_params(
elwise_func_name: str, dtypes_strat: st.SearchStrategy[DataType]
elwise_func_name: str, dtypes: List[DataType]
) -> List[Param[BinaryParamContext]]:
if hh.FILTER_UNDEFINED_DTYPES:
dtypes = [d for d in dtypes if not isinstance(d, xp._UndefinedStub)]
shared_oneway_dtypes = st.shared(oneway_promotable_dtypes(dtypes))
left_dtypes = shared_oneway_dtypes.map(lambda D: D.result_dtype)
right_dtypes = shared_oneway_dtypes.map(lambda D: D.input_dtype)

def make_param(
func_name: str, func_type: FuncType, right_is_scalar: bool
) -> Param[BinaryParamContext]:
Expand All @@ -318,26 +364,29 @@ def make_param(
left_sym = "x1"
right_sym = "x2"

shared_dtypes = st.shared(dtypes_strat)
if right_is_scalar:
left_strat = xps.arrays(dtype=shared_dtypes, shape=hh.shapes(**shapes_kw))
right_strat = shared_dtypes.flatmap(
lambda d: xps.from_dtype(d, **finite_kw)
)
left_strat = xps.arrays(dtype=left_dtypes, shape=hh.shapes(**shapes_kw))
right_strat = right_dtypes.flatmap(lambda d: xps.from_dtype(d, **finite_kw))
else:
if func_type is FuncType.IOP:
shared_shapes = st.shared(hh.shapes(**shapes_kw))
left_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
right_strat = xps.arrays(dtype=shared_dtypes, shape=shared_shapes)
shared_oneway_shapes = st.shared(oneway_broadcastable_shapes())
left_strat = xps.arrays(
dtype=left_dtypes,
shape=shared_oneway_shapes.map(lambda S: S.result_shape),
)
right_strat = xps.arrays(
dtype=right_dtypes,
shape=shared_oneway_shapes.map(lambda S: S.input_shape),
)
else:
mutual_shapes = st.shared(
hh.mutually_broadcastable_shapes(2, **shapes_kw)
)
left_strat = xps.arrays(
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
dtype=left_dtypes, shape=mutual_shapes.map(lambda pair: pair[0])
)
right_strat = xps.arrays(
dtype=shared_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
dtype=right_dtypes, shape=mutual_shapes.map(lambda pair: pair[1])
)

if func_type is FuncType.FUNC:
Expand Down Expand Up @@ -514,7 +563,7 @@ def test_acosh(x):
)


@pytest.mark.parametrize("ctx,", make_binary_params("add", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx,", make_binary_params("add", dh.numeric_dtypes))
@given(data=st.data())
def test_add(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -579,7 +628,7 @@ def test_atanh(x):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_and", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_and", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_and(ctx, data):
Expand All @@ -598,7 +647,7 @@ def test_bitwise_and(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_left_shift", all_integer_dtypes())
"ctx", make_binary_params("bitwise_left_shift", dh.all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_left_shift(ctx, data):
Expand Down Expand Up @@ -638,7 +687,7 @@ def test_bitwise_invert(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_or", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_or", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_or(ctx, data):
Expand All @@ -657,7 +706,7 @@ def test_bitwise_or(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_right_shift", all_integer_dtypes())
"ctx", make_binary_params("bitwise_right_shift", dh.all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_right_shift(ctx, data):
Expand All @@ -678,7 +727,7 @@ def test_bitwise_right_shift(ctx, data):


@pytest.mark.parametrize(
"ctx", make_binary_params("bitwise_xor", boolean_and_all_integer_dtypes())
"ctx", make_binary_params("bitwise_xor", dh.bool_and_all_int_dtypes)
)
@given(data=st.data())
def test_bitwise_xor(ctx, data):
Expand Down Expand Up @@ -720,7 +769,7 @@ def test_cosh(x):
unary_assert_against_refimpl("cosh", x, out, math.cosh)


@pytest.mark.parametrize("ctx", make_binary_params("divide", xps.floating_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("divide", dh.float_dtypes))
@given(data=st.data())
def test_divide(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -743,7 +792,7 @@ def test_divide(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("equal", xps.scalar_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("equal", dh.all_dtypes))
@given(data=st.data())
def test_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -795,9 +844,7 @@ def test_floor(x):
unary_assert_against_refimpl("floor", x, out, math.floor, strict_check=True)


@pytest.mark.parametrize(
"ctx", make_binary_params("floor_divide", xps.numeric_dtypes())
)
@pytest.mark.parametrize("ctx", make_binary_params("floor_divide", dh.numeric_dtypes))
@given(data=st.data())
def test_floor_divide(ctx, data):
left = data.draw(
Expand All @@ -816,7 +863,7 @@ def test_floor_divide(ctx, data):
binary_param_assert_against_refimpl(ctx, left, right, res, "//", operator.floordiv)


@pytest.mark.parametrize("ctx", make_binary_params("greater", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("greater", dh.numeric_dtypes))
@given(data=st.data())
def test_greater(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -836,9 +883,7 @@ def test_greater(ctx, data):
)


@pytest.mark.parametrize(
"ctx", make_binary_params("greater_equal", xps.numeric_dtypes())
)
@pytest.mark.parametrize("ctx", make_binary_params("greater_equal", dh.numeric_dtypes))
@given(data=st.data())
def test_greater_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -882,7 +927,7 @@ def test_isnan(x):
unary_assert_against_refimpl("isnan", x, out, math.isnan, res_stype=bool)


@pytest.mark.parametrize("ctx", make_binary_params("less", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("less", dh.numeric_dtypes))
@given(data=st.data())
def test_less(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -902,7 +947,7 @@ def test_less(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("less_equal", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("less_equal", dh.numeric_dtypes))
@given(data=st.data())
def test_less_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1014,7 +1059,7 @@ def test_logical_xor(x1, x2):
)


@pytest.mark.parametrize("ctx", make_binary_params("multiply", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("multiply", dh.numeric_dtypes))
@given(data=st.data())
def test_multiply(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1047,7 +1092,7 @@ def test_negative(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("not_equal", xps.scalar_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("not_equal", dh.all_dtypes))
@given(data=st.data())
def test_not_equal(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1079,7 +1124,7 @@ def test_positive(ctx, data):
ph.assert_array(ctx.func_name, out, x)


@pytest.mark.parametrize("ctx", make_binary_params("pow", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("pow", dh.numeric_dtypes))
@given(data=st.data())
def test_pow(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand All @@ -1103,7 +1148,7 @@ def test_pow(ctx, data):
)


@pytest.mark.parametrize("ctx", make_binary_params("remainder", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("remainder", dh.numeric_dtypes))
@given(data=st.data())
def test_remainder(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down Expand Up @@ -1174,7 +1219,7 @@ def test_sqrt(x):
)


@pytest.mark.parametrize("ctx", make_binary_params("subtract", xps.numeric_dtypes()))
@pytest.mark.parametrize("ctx", make_binary_params("subtract", dh.numeric_dtypes))
@given(data=st.data())
def test_subtract(ctx, data):
left = data.draw(ctx.left_strat, label=ctx.left_sym)
Expand Down