Skip to content

Commit a29c8da

Browse files
authored
Merge pull request #100 from Quansight-Labs/linalg
add torch_np.linalg module
2 parents 0251227 + f50cb75 commit a29c8da

File tree

5 files changed

+395
-135
lines changed

5 files changed

+395
-135
lines changed

torch_np/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import random
1+
from . import linalg, random
22
from ._binary_ufuncs import *
33
from ._detail._util import AxisError, UFuncTypeError
44
from ._dtypes import *

torch_np/_funcs.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -911,10 +911,6 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike):
911911
return _tensor_equal(a1_t, a2_t)
912912

913913

914-
def common_type():
915-
raise NotImplementedError
916-
917-
918914
def mintypecode():
919915
raise NotImplementedError
920916

@@ -927,6 +923,10 @@ def asfarray():
927923
raise NotImplementedError
928924

929925

926+
def block(*args, **kwds):
927+
raise NotImplementedError
928+
929+
930930
# ### put/take_along_axis ###
931931

932932

@@ -1358,8 +1358,12 @@ def reshape(a: ArrayLike, newshape, order="C"):
13581358
@normalizer
13591359
def transpose(a: ArrayLike, axes=None):
13601360
# numpy allows both .tranpose(sh) and .transpose(*sh)
1361+
# also older code uses axes being a list
13611362
if axes in [(), None, (None,)]:
13621363
axes = tuple(range(a.ndim))[::-1]
1364+
elif len(axes) == 1:
1365+
axes = axes[0]
1366+
13631367
try:
13641368
result = a.permute(axes)
13651369
except RuntimeError:
@@ -1908,3 +1912,45 @@ def blackman(M):
19081912
def bartlett(M):
19091913
dtype = _dtypes_impl.default_float_dtype
19101914
return torch.bartlett_window(M, periodic=False, dtype=dtype)
1915+
1916+
1917+
# ### Dtype routines ###
1918+
1919+
# vendored from https://git.1-hub.cnnumpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
1920+
1921+
1922+
array_type = [
1923+
[torch.float16, torch.float32, torch.float64],
1924+
[None, torch.complex64, torch.complex128],
1925+
]
1926+
array_precision = {
1927+
torch.float16: 0,
1928+
torch.float32: 1,
1929+
torch.float64: 2,
1930+
torch.complex64: 1,
1931+
torch.complex128: 2,
1932+
}
1933+
1934+
1935+
@normalizer
1936+
def common_type(*tensors: ArrayLike):
1937+
1938+
import builtins
1939+
1940+
is_complex = False
1941+
precision = 0
1942+
for a in tensors:
1943+
t = a.dtype
1944+
if iscomplexobj(a):
1945+
is_complex = True
1946+
if not (t.is_floating_point or t.is_complex):
1947+
p = 2 # array_precision[_nx.double]
1948+
else:
1949+
p = array_precision.get(t, None)
1950+
if p is None:
1951+
raise TypeError("can't get common type for non-numeric array")
1952+
precision = builtins.max(precision, p)
1953+
if is_complex:
1954+
return array_type[1][precision]
1955+
else:
1956+
return array_type[0][precision]

torch_np/_ndarray.py

+8
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,14 @@ def copy(self, order="C"):
152152
tensor = self.tensor.clone()
153153
return ndarray(tensor)
154154

155+
def view(self, dtype):
156+
torch_dtype = _dtypes.dtype(dtype).torch_dtype
157+
tview = self.tensor.view(torch_dtype)
158+
return ndarray(tview)
159+
160+
def fill(self, value):
161+
self.tensor.fill_(value)
162+
155163
def tolist(self):
156164
return self.tensor.tolist()
157165

torch_np/linalg.py

+240
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
import functools
2+
import math
3+
from typing import Sequence
4+
5+
import torch
6+
7+
from ._detail import _dtypes_impl, _util
8+
from ._normalizations import ArrayLike, normalizer
9+
10+
11+
class LinAlgError(Exception):
12+
pass
13+
14+
15+
def _atleast_float_1(a):
16+
if not (a.dtype.is_floating_point or a.dtype.is_complex):
17+
a = a.to(_dtypes_impl.default_float_dtype)
18+
return a
19+
20+
21+
def _atleast_float_2(a, b):
22+
dtyp = _dtypes_impl.result_type_impl((a.dtype, b.dtype))
23+
if not (dtyp.is_floating_point or dtyp.is_complex):
24+
dtyp = _dtypes_impl.default_float_dtype
25+
26+
a = _util.cast_if_needed(a, dtyp)
27+
b = _util.cast_if_needed(b, dtyp)
28+
return a, b
29+
30+
31+
def linalg_errors(func):
32+
@functools.wraps(func)
33+
def wrapped(*args, **kwds):
34+
try:
35+
return func(*args, **kwds)
36+
except torch._C._LinAlgError as e:
37+
raise LinAlgError(*e.args)
38+
39+
return wrapped
40+
41+
42+
# ### Matrix and vector products ###
43+
44+
45+
@normalizer
46+
@linalg_errors
47+
def matrix_power(a: ArrayLike, n):
48+
a = _atleat_float_1(a)
49+
return torch.linalg.matrix_power(a, n)
50+
51+
52+
@normalizer
53+
@linalg_errors
54+
def multi_dot(inputs: Sequence[ArrayLike], *, out=None):
55+
return torch.linalg.multi_dot(inputs)
56+
57+
58+
# ### Solving equations and inverting matrices ###
59+
60+
61+
@normalizer
62+
@linalg_errors
63+
def solve(a: ArrayLike, b: ArrayLike):
64+
a, b = _atleast_float_2(a, b)
65+
return torch.linalg.solve(a, b)
66+
67+
68+
@normalizer
69+
@linalg_errors
70+
def lstsq(a: ArrayLike, b: ArrayLike, rcond=None):
71+
a, b = _atleast_float_2(a, b)
72+
# NumPy is using gelsd: https://git.1-hub.cnnumpy/numpy/blob/v1.24.0/numpy/linalg/umath_linalg.cpp#L3991
73+
# on CUDA, only `gels` is available though, so use it instead
74+
driver = "gels" if a.is_cuda or b.is_cuda else "gelsd"
75+
return torch.linalg.lstsq(a, b, rcond=rcond, driver=driver)
76+
77+
78+
@normalizer
79+
@linalg_errors
80+
def inv(a: ArrayLike):
81+
a = _atleast_float_1(a)
82+
result = torch.linalg.inv(a)
83+
return result
84+
85+
86+
@normalizer
87+
@linalg_errors
88+
def pinv(a: ArrayLike, rcond=1e-15, hermitian=False):
89+
a = _atleast_float_1(a)
90+
return torch.linalg.pinv(a, rtol=rcond, hermitian=hermitian)
91+
92+
93+
@normalizer
94+
@linalg_errors
95+
def tensorsolve(a: ArrayLike, b: ArrayLike, axes=None):
96+
a, b = _atleast_float_2(a, b)
97+
return torch.linalg.tensorsolve(a, b, dims=axes)
98+
99+
100+
@normalizer
101+
@linalg_errors
102+
def tensorinv(a: ArrayLike, ind=2):
103+
a = _atleast_float_1(a)
104+
return torch.linalg.tensorinv(a, ind=ind)
105+
106+
107+
# ### Norms and other numbers ###
108+
109+
110+
@normalizer
111+
@linalg_errors
112+
def det(a: ArrayLike):
113+
a = _atleast_float_1(a)
114+
return torch.linalg.det(a)
115+
116+
117+
@normalizer
118+
@linalg_errors
119+
def slogdet(a: ArrayLike):
120+
a = _atleast_float_1(a)
121+
return torch.linalg.slogdet(a)
122+
123+
124+
@normalizer
125+
@linalg_errors
126+
def cond(x: ArrayLike, p=None):
127+
x = _atleast_float_1(x)
128+
129+
# check if empty
130+
# cf: https://git.1-hub.cnnumpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
131+
if x.numel() == 0 and math.prod(x.shape[-2:]) == 0:
132+
raise LinAlgError("cond is not defined on empty arrays")
133+
134+
result = torch.linalg.cond(x, p=p)
135+
136+
# Convert nans to infs (numpy does it in a data-dependent way, depending on
137+
# whether the input array has nans or not)
138+
# XXX: NumPy does this: https://git.1-hub.cnnumpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1744
139+
return torch.where(torch.isnan(result), float("inf"), result)
140+
141+
142+
@normalizer
143+
@linalg_errors
144+
def matrix_rank(a: ArrayLike, tol=None, hermitian=False):
145+
a = _atleast_float_1(a)
146+
147+
if a.ndim < 2:
148+
return int((a != 0).any())
149+
150+
if tol is None:
151+
# follow https://git.1-hub.cnnumpy/numpy/blob/v1.24.0/numpy/linalg/linalg.py#L1885
152+
atol = 0
153+
rtol = max(a.shape[-2:]) * torch.finfo(a.dtype).eps
154+
else:
155+
atol, rtol = tol, 0
156+
return torch.linalg.matrix_rank(a, atol=atol, rtol=rtol, hermitian=hermitian)
157+
158+
159+
@normalizer
160+
@linalg_errors
161+
def norm(x: ArrayLike, ord=None, axis=None, keepdims=False):
162+
x = _atleast_float_1(x)
163+
result = torch.linalg.norm(x, ord=ord, dim=axis)
164+
if keepdims:
165+
result = _util.apply_keepdims(result, axis, x.ndim)
166+
return result
167+
168+
169+
# ### Decompositions ###
170+
171+
172+
@normalizer
173+
@linalg_errors
174+
def cholesky(a: ArrayLike):
175+
a = _atleast_float_1(a)
176+
return torch.linalg.cholesky(a)
177+
178+
179+
@normalizer
180+
@linalg_errors
181+
def qr(a: ArrayLike, mode="reduced"):
182+
a = _atleast_float_1(a)
183+
result = torch.linalg.qr(a, mode=mode)
184+
if mode == "r":
185+
# match NumPy
186+
result = result.R
187+
return result
188+
189+
190+
@normalizer
191+
@linalg_errors
192+
def svd(a: ArrayLike, full_matrices=True, compute_uv=True, hermitian=False):
193+
a = _atleast_float_1(a)
194+
if not compute_uv:
195+
return torch.linalg.svdvals(a)
196+
197+
# NB: ignore the hermitian= argument (no pytorch equivalent)
198+
result = torch.linalg.svd(a, full_matrices=full_matrices)
199+
return result
200+
201+
202+
# ### Eigenvalues and eigenvectors ###
203+
204+
205+
@normalizer
206+
@linalg_errors
207+
def eig(a: ArrayLike):
208+
a = _atleast_float_1(a)
209+
w, vt = torch.linalg.eig(a)
210+
211+
if not a.is_complex():
212+
if w.is_complex() and (w.imag == 0).all():
213+
w = w.real
214+
vt = vt.real
215+
return w, vt
216+
217+
218+
@normalizer
219+
@linalg_errors
220+
def eigh(a: ArrayLike, UPLO="L"):
221+
a = _atleast_float_1(a)
222+
return torch.linalg.eigh(a, UPLO=UPLO)
223+
224+
225+
@normalizer
226+
@linalg_errors
227+
def eigvals(a: ArrayLike):
228+
a = _atleast_float_1(a)
229+
result = torch.linalg.eigvals(a)
230+
if not a.is_complex():
231+
if result.is_complex() and (result.imag == 0).all():
232+
result = result.real
233+
return result
234+
235+
236+
@normalizer
237+
@linalg_errors
238+
def eigvalsh(a: ArrayLike, UPLO="L"):
239+
a = _atleast_float_1(a)
240+
return torch.linalg.eigvalsh(a, UPLO=UPLO)

0 commit comments

Comments
 (0)