diff --git a/kwant/linalg/decomp_lu.py b/kwant/linalg/decomp_lu.py index 19ea4e70feab5b0c985f717a8321accb3e7ff029..ace2af08c3d8d4b986fb69179944596ec5976360 100644 --- a/kwant/linalg/decomp_lu.py +++ b/kwant/linalg/decomp_lu.py @@ -101,17 +101,6 @@ def rcond_from_lu(matrix_factorization, norm_a, norm="1"): norm specified in norm """ (lu, ipiv, singular) = matrix_factorization - if not norm in ("1", "I"): - raise ValueError("norm in rcond_from_lu must be either '1' or 'I'") norm = norm.encode('utf8') # lapack expects bytes - ltype, lu = lapack.prepare_for_lapack(False, lu) - - if ltype == 'd': - return lapack.dgecon(lu, norm_a, norm) - elif ltype == 'z': - return lapack.zgecon(lu, norm_a, norm) - elif ltype == 's': - return lapack.sgecon(lu, norm_a, norm) - else: - return lapack.cgecon(lu, norm_a, norm) + return lapack.gecon(lu, norm_a, norm) diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx index 3f52b18ee715f719ebedfbd0199ec137f81f790f..266d499775db0e45e5c2e2001317664dbfc2e35d 100644 --- a/kwant/linalg/lapack.pyx +++ b/kwant/linalg/lapack.pyx @@ -10,7 +10,7 @@ __all__ = ['getrf', 'getrs', - 'sgecon', 'dgecon', 'cgecon', 'zgecon', + 'gecon', 'sggev', 'dggev', 'cggev', 'zggev', 'sgees', 'dgees', 'cgees', 'zgees', 'strsen', 'dtrsen', 'ctrsen', 'ztrsen', @@ -153,91 +153,65 @@ def getrs(np.ndarray[scalar, ndim=2] LU, np.ndarray[l_int] IPIV, return B -# Wrappers for xGECON -def sgecon(np.ndarray[np.float32_t, ndim=2] LU, - float normA, char *norm = b"1"): +def gecon(np.ndarray[scalar, ndim=2] LU, double normA, char *norm = b"1"): cdef l_int N, info - cdef float rcond - cdef np.ndarray[np.float32_t, ndim=1] work - cdef np.ndarray[l_int, ndim=1] iwork + cdef float srcond, snormA + cdef double drcond - assert_fortran_mat(LU) - - N = LU.shape[0] - work = np.empty(4*N, dtype = np.float32) - iwork = np.empty(N, dtype = int_dtype) - - lapack.sgecon(norm, &N, <float *>LU.data, &N, &normA, - &rcond, <float *>work.data, - <l_int *>iwork.data, &info) - - assert info == 0, "Argument error in sgecon" - - return rcond - -def dgecon(np.ndarray[np.float64_t, ndim=2] LU, - double normA, char *norm = b"1"): - cdef l_int N, info - cdef double rcond - cdef np.ndarray[np.float64_t, ndim=1] work - cdef np.ndarray[l_int, ndim=1] iwork + # Parameter checks assert_fortran_mat(LU) + if norm[0] != b"1" and norm[0] != b"I": + raise ValueError("'norm' must be either '1' or 'I'") + if scalar in single_precision: + snormA = normA - N = LU.shape[0] - work = np.empty(4*N, dtype = np.float64) - iwork = np.empty(N, dtype = int_dtype) - - lapack.dgecon(norm, &N, <double *>LU.data, &N, &normA, - &rcond, <double *>work.data, - <l_int *>iwork.data, &info) - - assert info == 0, "Argument error in dgecon" - - return rcond - -def cgecon(np.ndarray[np.complex64_t, ndim=2] LU, - float normA, char *norm = b"1"): - cdef l_int N, info - cdef float rcond - cdef np.ndarray[np.complex64_t, ndim=1] work - cdef np.ndarray[np.float32_t, ndim=1] rwork - - assert_fortran_mat(LU) + # Allocate workspaces N = LU.shape[0] - work = np.empty(2*N, dtype = np.complex64) - rwork = np.empty(2*N, dtype = np.float32) - - lapack.cgecon(norm, &N, <float complex *>LU.data, &N, &normA, - &rcond, <float complex *>work.data, - <float *>rwork.data, &info) - assert info == 0, "Argument error in cgecon" - - return rcond + cdef np.ndarray[l_int] iwork + if scalar in floating: + iwork = np.empty(N, dtype=int_dtype) -def zgecon(np.ndarray[np.complex128_t, ndim=2] LU, - double normA, char *norm = b"1"): - cdef l_int N, info - cdef double rcond - cdef np.ndarray[np.complex128_t, ndim=1] work - cdef np.ndarray[np.float64_t, ndim=1] rwork + cdef np.ndarray[scalar] work + if scalar in floating: + work = np.empty(4 * N, dtype=LU.dtype) + else: + work = np.empty(2 * N, dtype=LU.dtype) - assert_fortran_mat(LU) + cdef np.ndarray rwork + if scalar is float_complex: + rwork = np.empty(2 * N, dtype=np.float32) + elif scalar is double_complex: + rwork = np.empty(2 * N, dtype=np.float64) - N = LU.shape[0] - work = np.empty(2*N, dtype = np.complex128) - rwork = np.empty(2*N, dtype = np.float64) + # The actual calculation - lapack.zgecon(norm, &N, <double complex *>LU.data, &N, &normA, - &rcond, <double complex *>work.data, - <double *>rwork.data, &info) + if scalar is float: + lapack.sgecon(norm, &N, <float *>LU.data, &N, &snormA, + &srcond, <float *>work.data, + <l_int *>iwork.data, &info) + elif scalar is double: + lapack.dgecon(norm, &N, <double *>LU.data, &N, &normA, + &drcond, <double *>work.data, + <l_int *>iwork.data, &info) + elif scalar is float_complex: + lapack.cgecon(norm, &N, <float complex *>LU.data, &N, &snormA, + &srcond, <float complex *>work.data, + <float *>rwork.data, &info) + elif scalar is double_complex: + lapack.zgecon(norm, &N, <double complex *>LU.data, &N, &normA, + &drcond, <double complex *>work.data, + <double *>rwork.data, &info) - assert info == 0, "Argument error in zgecon" + assert info == 0, "Argument error in gecon" - return rcond + if scalar in single_precision: + return srcond + else: + return drcond # Wrappers for xGGEV