Skip to content

Commit e46e978

Browse files
committed
Add array and axis testing to repeat()
Still need to add values testing.
1 parent b4c0823 commit e46e978

File tree

1 file changed

+39
-14
lines changed

1 file changed

+39
-14
lines changed

array_api_tests/test_manipulation_functions.py

+39-14
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,45 @@ def test_permute_dims(x, axes):
287287
out_indices=permuted_indices)
288288

289289

290+
@pytest.mark.min_version("2023.12")
291+
@given(
292+
x=hh.arrays(dtype=hh.all_dtypes, shape=shared_shapes(min_dims=1)),
293+
kw=hh.kwargs(
294+
axis=st.none() | shared_shapes(min_dims=1).flatmap(
295+
lambda s: st.integers(-len(s), len(s) - 1)
296+
)
297+
),
298+
data=st.data(),
299+
)
300+
def test_repeat(x, kw, data):
301+
shape = x.shape
302+
axis = kw.get("axis", None)
303+
dim = math.prod(shape) if axis is None else shape[axis]
304+
repeat_strat = st.integers(1, 4)
305+
repeats = data.draw(repeat_strat
306+
| hh.arrays(dtype=hh.int_dtypes, elements=repeat_strat,
307+
shape=st.sampled_from([(1,), (dim,)])),
308+
label="repeats")
309+
if isinstance(repeats, int):
310+
n_repitions = dim*repeats
311+
else:
312+
if repeats.shape == (1,):
313+
n_repitions = dim*repeats[0]
314+
else:
315+
n_repitions = int(xp.sum(repeats))
316+
317+
out = xp.repeat(x, repeats, **kw)
318+
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
319+
if axis is None:
320+
expected_shape = (n_repitions,)
321+
else:
322+
expected_shape = list(shape)
323+
expected_shape[axis] = n_repitions
324+
expected_shape = tuple(expected_shape)
325+
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
326+
# TODO: values testing
327+
328+
290329
@st.composite
291330
def reshape_shapes(draw, shape):
292331
size = 1 if len(shape) == 0 else math.prod(shape)
@@ -298,20 +337,6 @@ def reshape_shapes(draw, shape):
298337
return tuple(rshape)
299338

300339

301-
@pytest.mark.min_version("2023.12")
302-
@given(
303-
x=hh.arrays(dtype=hh.all_dtypes, shape=hh.shapes(min_dims=1)),
304-
repeats=st.integers(1, 4),
305-
)
306-
def test_repeat(x, repeats):
307-
# TODO: test array repeats and non-None axis, adjust shape and value testing accordingly
308-
out = xp.repeat(x, repeats)
309-
ph.assert_dtype("repeat", in_dtype=x.dtype, out_dtype=out.dtype)
310-
expected_shape = (math.prod(x.shape) * repeats,)
311-
ph.assert_shape("repeat", out_shape=out.shape, expected=expected_shape)
312-
# TODO: values testing
313-
314-
315340
@pytest.mark.unvectorized
316341
@pytest.mark.skip("flaky") # TODO: fix!
317342
@given(

0 commit comments

Comments
 (0)