diff --git a/kwant/linalg/decomp_lu.py b/kwant/linalg/decomp_lu.py index acdbf06061c584ec348c667098a3c6127ccdb798..19ea4e70feab5b0c985f717a8321accb3e7ff029 100644 --- a/kwant/linalg/decomp_lu.py +++ b/kwant/linalg/decomp_lu.py @@ -73,21 +73,7 @@ def lu_solve(matrix_factorization, b): ltype, lu, b = lapack.prepare_for_lapack(False, lu, b) ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype) - - if b.ndim > 2: - raise ValueError("lu_solve: b must be a vector or matrix") - - if lu.shape[0] != b.shape[0]: - raise ValueError("lu_solve: incompatible dimensions of b") - - if ltype == 'd': - return lapack.dgetrs(lu, ipiv, b) - elif ltype == 'z': - return lapack.zgetrs(lu, ipiv, b) - elif ltype == 's': - return lapack.sgetrs(lu, ipiv, b) - else: - return lapack.cgetrs(lu, ipiv, b) + return lapack.getrs(lu, ipiv, b) def rcond_from_lu(matrix_factorization, norm_a, norm="1"): diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx index f9c2c06a0bb8fdee5a3037cfa428081cf25923a0..3f52b18ee715f719ebedfbd0199ec137f81f790f 100644 --- a/kwant/linalg/lapack.pyx +++ b/kwant/linalg/lapack.pyx @@ -9,7 +9,7 @@ """Low-level access to LAPACK functions. """ __all__ = ['getrf', - 'sgetrs', 'dgetrs', 'cgetrs', 'zgetrs', + 'getrs', 'sgecon', 'dgecon', 'cgecon', 'zgecon', 'sggev', 'dggev', 'cggev', 'zggev', 'sgees', 'dgees', 'cgees', 'zgees', @@ -102,127 +102,56 @@ def getrf(np.ndarray[scalar, ndim=2] A): return (A, ipiv, info > 0 or M != N) -# Wrappers for xGETRS -def sgetrs(np.ndarray[np.float32_t, ndim=2] LU, - np.ndarray[l_int, ndim=1] IPIV, B): +def getrs(np.ndarray[scalar, ndim=2] LU, np.ndarray[l_int] IPIV, + np.ndarray B): cdef l_int N, NRHS, info - cdef np.ndarray b assert_fortran_mat(LU) - # again: workaround for 1x1-Fortran bug in NumPy < v2.0 - if (not isinstance(B, np.ndarray) or - (B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and - not B.flags["F_CONTIGUOUS"])): - raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array") - - b = B - N = LU.shape[0] - if b.ndim == 1: - NRHS = 1 - elif b.ndim == 2: - NRHS = B.shape[1] - else: - raise ValueError("In sgetrs: B must be a vector or matrix") + # Consistency checks for LU and B - lapack.sgetrs("N", &N, &NRHS, <float *>LU.data, &N, - <l_int *>IPIV.data, <float *>b.data, &N, - &info) + if B.descr.type_num != LU.descr.type_num: + raise TypeError('B must have same dtype as LU') - assert info == 0, "Argument error in sgetrs" - - return b - -def dgetrs(np.ndarray[np.float64_t, ndim=2] LU, - np.ndarray[l_int, ndim=1] IPIV, B): - cdef l_int N, NRHS, info - cdef np.ndarray b - - assert_fortran_mat(LU) - - # again: workaround for 1x1-Fortran bug in NumPy < v2.0 - if (not isinstance(B, np.ndarray) or - (B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and + # Workaround for 1x1-Fortran bug in NumPy < v2.0 + if ((B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and not B.flags["F_CONTIGUOUS"])): - raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array") - - b = B - N = LU.shape[0] - if b.ndim == 1: - NRHS = 1 - elif b.ndim == 2: - NRHS = b.shape[1] - else: - raise ValueError("In dgetrs: B must be a vector or matrix") + raise ValueError("B must be Fortran ordered") - lapack.dgetrs("N", &N, &NRHS, <double *>LU.data, &N, - <l_int *>IPIV.data, <double *>b.data, &N, - &info) + if B.ndim > 2: + raise ValueError("B must be a vector or matrix") - assert info == 0, "Argument error in dgetrs" - - return b - -def cgetrs(np.ndarray[np.complex64_t, ndim=2] LU, - np.ndarray[l_int, ndim=1] IPIV, B): - cdef l_int N, NRHS, info - cdef np.ndarray b - - assert_fortran_mat(LU) + if LU.shape[0] != B.shape[0]: + raise ValueError('LU and B have incompatible shapes') - # again: workaround for 1x1-Fortran bug in NumPy < v2.0 - if (not isinstance(B, np.ndarray) or - (B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and - not B.flags["F_CONTIGUOUS"])): - raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array") - - b = B N = LU.shape[0] - if b.ndim == 1: - NRHS = 1 - elif b.ndim == 2: - NRHS = b.shape[1] - else: - raise ValueError("In cgetrs: B must be a vector or matrix") - - lapack.cgetrs("N", &N, &NRHS, <float complex *>LU.data, &N, - <l_int *>IPIV.data, <float complex *>b.data, &N, - &info) - - assert info == 0, "Argument error in cgetrs" - - return b - -def zgetrs(np.ndarray[np.complex128_t, ndim=2] LU, - np.ndarray[l_int, ndim=1] IPIV, B): - cdef l_int N, NRHS, info - cdef np.ndarray b - - assert_fortran_mat(LU) - - # again: workaround for 1x1-Fortran bug in NumPy < v2.0 - if (not isinstance(B, np.ndarray) or - (B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and - not B.flags["F_CONTIGUOUS"])): - raise ValueError("In dgetrs: B must be a Fortran ordered NumPy array") - b = B - N = LU.shape[0] - if b.ndim == 1: + if B.ndim == 1: NRHS = 1 - elif b.ndim == 2: - NRHS = b.shape[1] - else: - raise ValueError("In zgetrs: B must be a vector or matrix") + elif B.ndim == 2: + NRHS = B.shape[1] - lapack.zgetrs("N", &N, &NRHS, <double complex *>LU.data, &N, - <l_int *>IPIV.data, <double complex *>b.data, &N, - &info) + if scalar is float: + lapack.sgetrs("N", &N, &NRHS, <float *>LU.data, &N, + <l_int *>IPIV.data, <float *>B.data, &N, + &info) + elif scalar is double: + lapack.dgetrs("N", &N, &NRHS, <double *>LU.data, &N, + <l_int *>IPIV.data, <double *>B.data, &N, + &info) + elif scalar is float_complex: + lapack.cgetrs("N", &N, &NRHS, <float complex *>LU.data, &N, + <l_int *>IPIV.data, <float complex *>B.data, &N, + &info) + elif scalar is double_complex: + lapack.zgetrs("N", &N, &NRHS, <double complex *>LU.data, &N, + <l_int *>IPIV.data, <double complex *>B.data, &N, + &info) - assert info == 0, "Argument error in zgetrs" + assert info == 0, "Argument error in getrs" - return b + return B # Wrappers for xGECON