@@ -1911,19 +1911,23 @@ def bartlett(M):
1911
1911
return torch .bartlett_window (M , periodic = False , dtype = dtype )
1912
1912
1913
1913
1914
-
1915
1914
# ### Dtype routines ###
1916
1915
1917
1916
# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
1918
1917
1919
1918
1920
- array_type = [[torch .float16 , torch .float32 , torch .float64 ],
1921
- [None , torch .complex64 , torch .complex128 ]]
1922
- array_precision = {torch .float16 : 0 ,
1923
- torch .float32 : 1 ,
1924
- torch .float64 : 2 ,
1925
- torch .complex64 : 1 ,
1926
- torch .complex128 : 2 ,}
1919
+ array_type = [
1920
+ [torch .float16 , torch .float32 , torch .float64 ],
1921
+ [None , torch .complex64 , torch .complex128 ],
1922
+ ]
1923
+ array_precision = {
1924
+ torch .float16 : 0 ,
1925
+ torch .float32 : 1 ,
1926
+ torch .float64 : 2 ,
1927
+ torch .complex64 : 1 ,
1928
+ torch .complex128 : 2 ,
1929
+ }
1930
+
1927
1931
1928
1932
@normalizer
1929
1933
def common_type (* tensors : ArrayLike ):
@@ -1936,7 +1940,7 @@ def common_type(*tensors: ArrayLike):
1936
1940
t = a .dtype
1937
1941
if iscomplexobj (a ):
1938
1942
is_complex = True
1939
- if not (t .is_floating_point or t .is_complex ):
1943
+ if not (t .is_floating_point or t .is_complex ):
1940
1944
p = 2 # array_precision[_nx.double]
1941
1945
else :
1942
1946
p = array_precision .get (t , None )
@@ -1947,5 +1951,3 @@ def common_type(*tensors: ArrayLike):
1947
1951
return array_type [1 ][precision ]
1948
1952
else :
1949
1953
return array_type [0 ][precision ]
1950
-
1951
-
0 commit comments