Skip to content

Commit d974716

Browse files
asmeurerhonno
authored andcommitted
Add an rtols strategy
Start implementing tests for pinv and matrix_norm.
1 parent d03206e commit d974716

File tree

2 files changed

+26
-16
lines changed

2 files changed

+26
-16
lines changed

array_api_tests/hypothesis_helpers.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,21 @@ def matrix_shapes(draw, stack_shapes=shapes()):
158158

159159
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
160160

161-
finite_matrices = xps.arrays(dtype=xps.floating_dtypes(),
162-
shape=matrix_shapes(),
163-
elements=dict(allow_nan=False,
164-
allow_infinity=False))
161+
@composite
162+
def finite_matrices(draw, shape=matrix_shapes()):
163+
return draw(xps.arrays(dtype=xps.floating_dtypes(),
164+
shape=shape,
165+
elements=dict(allow_nan=False,
166+
allow_infinity=False)))
167+
168+
rtol_shared_matrix_shapes = shared(matrix_shapes())
169+
# Should we set a max_value here?
170+
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
171+
rtols = one_of(floats(**_rtol_float_kw),
172+
xps.arrays(dtype=xps.floating_dtypes(),
173+
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
174+
elements=_rtol_float_kw))
175+
165176

166177
def mutually_broadcastable_shapes(
167178
num_shapes: int,

array_api_tests/test_linalg.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
invertible_matrices, two_mutual_arrays,
2727
mutually_promotable_dtypes, one_d_shapes,
2828
two_mutually_broadcastable_shapes,
29-
SQRT_MAX_ARRAY_SIZE, finite_matrices)
29+
SQRT_MAX_ARRAY_SIZE, finite_matrices,
30+
rtol_shared_matrix_shapes, rtols)
3031
from . import dtype_helpers as dh
3132
from . import pytest_helpers as ph
3233

@@ -311,7 +312,7 @@ def test_matmul(x1, x2):
311312

312313
@pytest.mark.xp_extension('linalg')
313314
@given(
314-
x=finite_matrices,
315+
x=finite_matrices(),
315316
kw=kwargs(keepdims=booleans(),
316317
ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc']))
317318
)
@@ -357,12 +358,11 @@ def test_matrix_power(x, n):
357358

358359
@pytest.mark.xp_extension('linalg')
359360
@given(
360-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
361-
kw=kwargs(rtol=todo)
361+
x=finite_matrices(shape=rtol_shared_matrix_shapes),
362+
kw=kwargs(rtol=rtols)
362363
)
363364
def test_matrix_rank(x, kw):
364-
# res = linalg.matrix_rank(x, **kw)
365-
pass
365+
res = linalg.matrix_rank(x, **kw)
366366

367367
@given(
368368
x=xps.arrays(dtype=dtypes, shape=matrix_shapes()),
@@ -407,12 +407,11 @@ def test_outer(x1, x2):
407407

408408
@pytest.mark.xp_extension('linalg')
409409
@given(
410-
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
411-
kw=kwargs(rtol=todo)
410+
x=finite_matrices(shape=rtol_shared_matrix_shapes),
411+
kw=kwargs(rtol=rtols)
412412
)
413413
def test_pinv(x, kw):
414-
# res = linalg.pinv(x, **kw)
415-
pass
414+
res = linalg.pinv(x, **kw)
416415

417416
@pytest.mark.xp_extension('linalg')
418417
@given(
@@ -525,7 +524,7 @@ def test_solve(x1, x2):
525524

526525
@pytest.mark.xp_extension('linalg')
527526
@given(
528-
x=finite_matrices,
527+
x=finite_matrices(),
529528
kw=kwargs(full_matrices=booleans())
530529
)
531530
def test_svd(x, kw):
@@ -561,7 +560,7 @@ def test_svd(x, kw):
561560

562561
@pytest.mark.xp_extension('linalg')
563562
@given(
564-
x=finite_matrices,
563+
x=finite_matrices(),
565564
)
566565
def test_svdvals(x):
567566
res = linalg.svdvals(x)

0 commit comments

Comments
 (0)