@@ -491,7 +491,7 @@ def solve_args():
491
491
Strategy for the x1 and x2 arguments to test_solve()
492
492
493
493
solve() takes x1, x2, where x1 is any stack of square invertible matrices
494
- of shape (..., M, M), and x2 is either shape (..., M ) or (..., M, K),
494
+ of shape (..., M, M), and x2 is either shape (M, ) or (..., M, K),
495
495
where the ... parts of x1 and x2 are broadcast compatible.
496
496
"""
497
497
stack_shapes = shared (two_mutually_broadcastable_shapes )
@@ -501,26 +501,18 @@ def solve_args():
501
501
pair [0 ])))
502
502
503
503
@composite
504
- def x2_shapes (draw ):
505
- end = draw (xps .array_shapes (min_dims = 0 , max_dims = 1 , min_side = 0 ,
506
- max_side = SQRT_MAX_ARRAY_SIZE ))
507
- return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + end
504
+ def _x2_shapes (draw ):
505
+ end = draw (integers (0 , SQRT_MAX_ARRAY_SIZE ))
506
+ return draw (stack_shapes )[1 ] + draw (x1 ).shape [- 1 :] + (end ,)
508
507
509
- x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes ())
508
+ x2_shapes = one_of (x1 .map (lambda x : (x .shape [- 1 ],)), _x2_shapes ())
509
+ x2 = xps .arrays (dtype = xps .floating_dtypes (), shape = x2_shapes )
510
510
return x1 , x2
511
511
512
512
@pytest .mark .xp_extension ('linalg' )
513
513
@given (* solve_args ())
514
514
def test_solve (x1 , x2 ):
515
- # TODO: solve() is currently ambiguous, in that some inputs can be
516
- # interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
517
- # and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
518
- # of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
519
- # broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
520
- # (2, 2, 2, 2).
521
-
522
- # res = linalg.solve(x1, x2)
523
- pass
515
+ res = linalg .solve (x1 , x2 )
524
516
525
517
@pytest .mark .xp_extension ('linalg' )
526
518
@given (
0 commit comments