Skip to content

Commit d03206e

Browse files
asmeurerhonno
authored andcommitted
Test the output dtype and shape in matrix_norm
1 parent b33e556 commit d03206e

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

array_api_tests/test_linalg.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,15 @@ def test_matrix_norm(x, kw):
319319
res = linalg.matrix_norm(x, **kw)
320320

321321
keepdims = kw.get('keepdims', False)
322-
ord = kw.get('ord', 'fro')
322+
# TODO: Check that the ord values give the correct norms.
323+
# ord = kw.get('ord', 'fro')
324+
325+
if keepdims:
326+
expected_shape = x.shape[:-2] + (1, 1)
327+
else:
328+
expected_shape = x.shape[:-2]
329+
assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape"
330+
assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype"
323331

324332
_test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
325333
res=res)

0 commit comments

Comments
 (0)