Skip to content
Snippets Groups Projects
Commit da619a05 authored by Joseph Weston's avatar Joseph Weston
Browse files

switch Xgetrs to fused types

parent 1e30f9c1
No related branches found
No related tags found
No related merge requests found
......@@ -73,21 +73,7 @@ def lu_solve(matrix_factorization, b):
ltype, lu, b = lapack.prepare_for_lapack(False, lu, b)
ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype)
if b.ndim > 2:
raise ValueError("lu_solve: b must be a vector or matrix")
if lu.shape[0] != b.shape[0]:
raise ValueError("lu_solve: incompatible dimensions of b")
if ltype == 'd':
return lapack.dgetrs(lu, ipiv, b)
elif ltype == 'z':
return lapack.zgetrs(lu, ipiv, b)
elif ltype == 's':
return lapack.sgetrs(lu, ipiv, b)
else:
return lapack.cgetrs(lu, ipiv, b)
return lapack.getrs(lu, ipiv, b)
def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
......
......@@ -9,7 +9,7 @@
"""Low-level access to LAPACK functions. """
__all__ = ['getrf',
'sgetrs', 'dgetrs', 'cgetrs', 'zgetrs',
'getrs',
'sgecon', 'dgecon', 'cgecon', 'zgecon',
'sggev', 'dggev', 'cggev', 'zggev',
'sgees', 'dgees', 'cgees', 'zgees',
......@@ -102,127 +102,56 @@ def getrf(np.ndarray[scalar, ndim=2] A):
return (A, ipiv, info > 0 or M != N)
# Wrappers for xGETRS
def sgetrs(np.ndarray[np.float32_t, ndim=2] LU,
np.ndarray[l_int, ndim=1] IPIV, B):
def getrs(np.ndarray[scalar, ndim=2] LU, np.ndarray[l_int] IPIV,
np.ndarray B):
cdef l_int N, NRHS, info
cdef np.ndarray b
assert_fortran_mat(LU)
# again: workaround for 1x1-Fortran bug in NumPy < v2.0
if (not isinstance(B, np.ndarray) or
(B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
not B.flags["F_CONTIGUOUS"])):
raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array")
b = B
N = LU.shape[0]
if b.ndim == 1:
NRHS = 1
elif b.ndim == 2:
NRHS = B.shape[1]
else:
raise ValueError("In sgetrs: B must be a vector or matrix")
# Consistency checks for LU and B
lapack.sgetrs("N", &N, &NRHS, <float *>LU.data, &N,
<l_int *>IPIV.data, <float *>b.data, &N,
&info)
if B.descr.type_num != LU.descr.type_num:
raise TypeError('B must have same dtype as LU')
assert info == 0, "Argument error in sgetrs"
return b
def dgetrs(np.ndarray[np.float64_t, ndim=2] LU,
np.ndarray[l_int, ndim=1] IPIV, B):
cdef l_int N, NRHS, info
cdef np.ndarray b
assert_fortran_mat(LU)
# again: workaround for 1x1-Fortran bug in NumPy < v2.0
if (not isinstance(B, np.ndarray) or
(B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
# Workaround for 1x1-Fortran bug in NumPy < v2.0
if ((B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
not B.flags["F_CONTIGUOUS"])):
raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array")
b = B
N = LU.shape[0]
if b.ndim == 1:
NRHS = 1
elif b.ndim == 2:
NRHS = b.shape[1]
else:
raise ValueError("In dgetrs: B must be a vector or matrix")
raise ValueError("B must be Fortran ordered")
lapack.dgetrs("N", &N, &NRHS, <double *>LU.data, &N,
<l_int *>IPIV.data, <double *>b.data, &N,
&info)
if B.ndim > 2:
raise ValueError("B must be a vector or matrix")
assert info == 0, "Argument error in dgetrs"
return b
def cgetrs(np.ndarray[np.complex64_t, ndim=2] LU,
np.ndarray[l_int, ndim=1] IPIV, B):
cdef l_int N, NRHS, info
cdef np.ndarray b
assert_fortran_mat(LU)
if LU.shape[0] != B.shape[0]:
raise ValueError('LU and B have incompatible shapes')
# again: workaround for 1x1-Fortran bug in NumPy < v2.0
if (not isinstance(B, np.ndarray) or
(B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
not B.flags["F_CONTIGUOUS"])):
raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array")
b = B
N = LU.shape[0]
if b.ndim == 1:
NRHS = 1
elif b.ndim == 2:
NRHS = b.shape[1]
else:
raise ValueError("In cgetrs: B must be a vector or matrix")
lapack.cgetrs("N", &N, &NRHS, <float complex *>LU.data, &N,
<l_int *>IPIV.data, <float complex *>b.data, &N,
&info)
assert info == 0, "Argument error in cgetrs"
return b
def zgetrs(np.ndarray[np.complex128_t, ndim=2] LU,
np.ndarray[l_int, ndim=1] IPIV, B):
cdef l_int N, NRHS, info
cdef np.ndarray b
assert_fortran_mat(LU)
# again: workaround for 1x1-Fortran bug in NumPy < v2.0
if (not isinstance(B, np.ndarray) or
(B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
not B.flags["F_CONTIGUOUS"])):
raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array")
b = B
N = LU.shape[0]
if b.ndim == 1:
if B.ndim == 1:
NRHS = 1
elif b.ndim == 2:
NRHS = b.shape[1]
else:
raise ValueError("In zgetrs: B must be a vector or matrix")
elif B.ndim == 2:
NRHS = B.shape[1]
lapack.zgetrs("N", &N, &NRHS, <double complex *>LU.data, &N,
<l_int *>IPIV.data, <double complex *>b.data, &N,
&info)
if scalar is float:
lapack.sgetrs("N", &N, &NRHS, <float *>LU.data, &N,
<l_int *>IPIV.data, <float *>B.data, &N,
&info)
elif scalar is double:
lapack.dgetrs("N", &N, &NRHS, <double *>LU.data, &N,
<l_int *>IPIV.data, <double *>B.data, &N,
&info)
elif scalar is float_complex:
lapack.cgetrs("N", &N, &NRHS, <float complex *>LU.data, &N,
<l_int *>IPIV.data, <float complex *>B.data, &N,
&info)
elif scalar is double_complex:
lapack.zgetrs("N", &N, &NRHS, <double complex *>LU.data, &N,
<l_int *>IPIV.data, <double complex *>B.data, &N,
&info)
assert info == 0, "Argument error in zgetrs"
assert info == 0, "Argument error in getrs"
return b
return B
# Wrappers for xGECON
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment