Skip to content

Commit 713b449

Browse files
asmeurerhonno
authored andcommitted
Fix the solve() inputs strategy with the new, unambiguous input shapes
1 parent 577a4fe commit 713b449

File tree

1 file changed

+7
-15
lines changed

1 file changed

+7
-15
lines changed

array_api_tests/test_linalg.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ def solve_args():
491491
Strategy for the x1 and x2 arguments to test_solve()
492492
493493
solve() takes x1, x2, where x1 is any stack of square invertible matrices
494-
of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K),
494+
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
495495
where the ... parts of x1 and x2 are broadcast compatible.
496496
"""
497497
stack_shapes = shared(two_mutually_broadcastable_shapes)
@@ -501,26 +501,18 @@ def solve_args():
501501
pair[0])))
502502

503503
@composite
504-
def x2_shapes(draw):
505-
end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0,
506-
max_side=SQRT_MAX_ARRAY_SIZE))
507-
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end
504+
def _x2_shapes(draw):
505+
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
506+
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)
508507

509-
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes())
508+
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
509+
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes)
510510
return x1, x2
511511

512512
@pytest.mark.xp_extension('linalg')
513513
@given(*solve_args())
514514
def test_solve(x1, x2):
515-
# TODO: solve() is currently ambiguous, in that some inputs can be
516-
# interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
517-
# and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
518-
# of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
519-
# broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
520-
# (2, 2, 2, 2).
521-
522-
# res = linalg.solve(x1, x2)
523-
pass
515+
res = linalg.solve(x1, x2)
524516

525517
@pytest.mark.xp_extension('linalg')
526518
@given(

0 commit comments

Comments
 (0)