diff --git a/kwant/linalg/decomp_schur.py b/kwant/linalg/decomp_schur.py index 4accc17ff0f1643c9532ada85a29dfa08a2b67fa..fbc091218c0cc8cd5c9052af20e3c5c8459ecfa3 100644 --- a/kwant/linalg/decomp_schur.py +++ b/kwant/linalg/decomp_schur.py @@ -62,18 +62,8 @@ def schur(a, calc_q=True, calc_ev=True, overwrite_a=False): LinAlgError If the underlying QR iteration fails to converge. """ - ltype, a = lapack.prepare_for_lapack(overwrite_a, a) - - if a.ndim != 2: - raise ValueError("Expect matrix as input") - - if a.shape[0] != a.shape[1]: - raise ValueError("Expect square matrix") - - gees = getattr(lapack, ltype + "gees") - - return gees(a, calc_q, calc_ev) + return lapack.gees(a, calc_q, calc_ev) def convert_r2c_schur(t, q): diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx index 19066e9e9074b797b9830e72021282ae9b101a96..ab7f2123a2040dc77e0deb844ffb9a8d2763a93b 100644 --- a/kwant/linalg/lapack.pyx +++ b/kwant/linalg/lapack.pyx @@ -12,7 +12,7 @@ __all__ = ['getrf', 'getrs', 'gecon', 'ggev', - 'sgees', 'dgees', 'cgees', 'zgees', + 'gees', 'strsen', 'dtrsen', 'ctrsen', 'ztrsen', 'strevc', 'dtrevc', 'ctrevc', 'ztrevc', 'sgges', 'dgges', 'cgges', 'zgges', @@ -399,207 +399,113 @@ def ggev(np.ndarray[scalar, ndim=2] A, np.ndarray[scalar, ndim=2] B, return filter_args((True, True, left, right), (alpha, beta, vl, vr)) -# Wrapper for xGEES -def sgees(np.ndarray[np.float32_t, ndim=2] A, - calc_q=True, calc_ev=True): +def gees(np.ndarray[scalar, ndim=2] A, calc_q=True, calc_ev=True): cdef l_int N, lwork, sdim, info - cdef char *jobvs - cdef float *vs_ptr - cdef float qwork - cdef np.ndarray[np.float32_t, ndim=2] vs - cdef np.ndarray[np.float32_t] wr, wi, work assert_fortran_mat(A) - N = A.shape[0] - wr = np.empty(N, dtype = np.float32) - wi = np.empty(N, dtype = np.float32) + if A.ndim != 2: + raise ValueError("Expect matrix as input") - if calc_q: - vs = np.empty((N,N), dtype = np.float32, order='F') - vs_ptr = <float *>vs.data - jobvs = "V" - else: - vs_ptr = NULL - jobvs = "N" - - # workspace query - lwork = -1 - lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N, - &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N, - &qwork, &lwork, NULL, &info) - - assert info == 0, "Argument error in sgees" - - lwork = <int>qwork - work = np.empty(lwork, dtype = np.float32) - - # Now the real calculation - lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N, - &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N, - <float *>work.data, &lwork, NULL, &info) + if A.shape[0] != A.shape[1]: + raise ValueError("Expect square matrix") - if info > 0: - raise LinAlgError("QR iteration failed to converge in sgees") + # Allocate workspaces - assert info == 0, "Argument error in sgees" + N = A.shape[0] - if wi.nonzero()[0].size: - w = wr + 1j * wi + cdef np.ndarray[scalar] wr, wi + if scalar in cmplx: + wr = np.empty(N, dtype=A.dtype) + wi = None else: - w = wr - - return filter_args((True, calc_q, calc_ev), (A, vs, w)) + wr = np.empty(N, dtype=A.dtype) + wi = np.empty(N, dtype=A.dtype) + cdef np.ndarray rwork + if scalar is float_complex: + rwork = np.empty(N, dtype=np.float32) + elif scalar is double_complex: + rwork = np.empty(N, dtype=np.float64) -def dgees(np.ndarray[np.float64_t, ndim=2] A, - calc_q=True, calc_ev=True): - cdef l_int N, lwork, sdim, info cdef char *jobvs - cdef double *vs_ptr - cdef double qwork - cdef np.ndarray[np.float64_t, ndim=2] vs - cdef np.ndarray[np.float64_t] wr, wi, work - - assert_fortran_mat(A) - - N = A.shape[0] - wr = np.empty(N, dtype = np.float64) - wi = np.empty(N, dtype = np.float64) - + cdef scalar *vs_ptr + cdef np.ndarray[scalar, ndim=2] vs if calc_q: - vs = np.empty((N,N), dtype = np.float64, order='F') - vs_ptr = <double *>vs.data + vs = np.empty((N,N), dtype=A.dtype, order='F') + vs_ptr = <scalar *>vs.data jobvs = "V" else: + vs = None vs_ptr = NULL jobvs = "N" - # workspace query + # Workspace query + # Xgees expects &qwork as a <scalar *> (even though it's an integer) lwork = -1 - lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N, - &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N, - &qwork, &lwork, NULL, &info) - - assert info == 0, "Argument error in dgees" - - lwork = <int>qwork - work = np.empty(lwork, dtype = np.float64) - - # Now the real calculation - lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N, - &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N, - <double *>work.data, &lwork, NULL, &info) - - if info > 0: - raise LinAlgError("QR iteration failed to converge in dgees") - - assert info == 0, "Argument error in dgees" - - if wi.nonzero()[0].size: - w = wr + 1j * wi - else: - w = wr - - return filter_args((True, calc_q, calc_ev), (A, vs, w)) - - -def cgees(np.ndarray[np.complex64_t, ndim=2] A, - calc_q=True, calc_ev=True): - cdef l_int N, lwork, sdim, info - cdef char *jobvs - cdef float complex *vs_ptr - cdef float complex qwork - cdef np.ndarray[np.complex64_t, ndim=2] vs - cdef np.ndarray[np.complex64_t] w, work - cdef np.ndarray[np.float32_t] rwork + cdef scalar qwork - assert_fortran_mat(A) + if scalar is float: + lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N, + &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N, + &qwork, &lwork, NULL, &info) + elif scalar is double: + lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N, + &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N, + &qwork, &lwork, NULL, &info) + elif scalar is float_complex: + lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N, + &sdim, <float complex *>wr.data, vs_ptr, &N, + &qwork, &lwork, <float *>rwork.data, NULL, &info) + elif scalar is double_complex: + lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N, + &sdim, <double complex *>wr.data, vs_ptr, &N, + &qwork, &lwork, <double *>rwork.data, NULL, &info) - N = A.shape[0] - w = np.empty(N, dtype = np.complex64) - rwork = np.empty(N, dtype = np.float32) + assert info == 0, "Argument error in sgees" - if calc_q: - vs = np.empty((N,N), dtype = np.complex64, order='F') - vs_ptr = <float complex *>vs.data - jobvs = "V" + if scalar in floating: + lwork = <l_int>qwork else: - vs_ptr = NULL - jobvs = "N" - - # workspace query - lwork = -1 - lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N, - &sdim, <float complex *>w.data, vs_ptr, &N, - &qwork, &lwork, <float *>rwork.data, NULL, &info) - - assert info == 0, "Argument error in cgees" + lwork = <l_int>qwork.real + cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype) - lwork = <int>qwork.real - work = np.empty(lwork, dtype = np.complex64) + # The actual calculation - # Now the real calculation - lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N, - &sdim, <float complex *>w.data, vs_ptr, &N, - <float complex *>work.data, &lwork, - <float *>rwork.data, NULL, &info) + if scalar is float: + lapack.sgees(jobvs, "N", NULL, &N, <float *>A.data, &N, + &sdim, <float *>wr.data, <float *>wi.data, vs_ptr, &N, + <float *>work.data, &lwork, NULL, &info) + elif scalar is double: + lapack.dgees(jobvs, "N", NULL, &N, <double *>A.data, &N, + &sdim, <double *>wr.data, <double *>wi.data, vs_ptr, &N, + <double *>work.data, &lwork, NULL, &info) + elif scalar is float_complex: + lapack.cgees(jobvs, "N", NULL, &N, <float complex *>A.data, &N, + &sdim, <float complex *>wr.data, vs_ptr, &N, + <float complex *>work.data, &lwork, + <float *>rwork.data, NULL, &info) + elif scalar is double_complex: + lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N, + &sdim, <double complex *>wr.data, vs_ptr, &N, + <double complex *>work.data, &lwork, + <double *>rwork.data, NULL, &info) if info > 0: - raise LinAlgError("QR iteration failed to converge in cgees") + raise LinAlgError("QR iteration failed to converge in gees") - assert info == 0, "Argument error in cgees" - - return filter_args((True, calc_q, calc_ev), (A, vs, w)) + assert info == 0, "Argument error in gees" + # Real inputs possibly produce complex output + cdef np.ndarray w + w = wr + if scalar in floating: + if wi.nonzero()[0].size: + w = wr + 1j * wi -def zgees(np.ndarray[np.complex128_t, ndim=2] A, - calc_q=True, calc_ev=True): - cdef l_int N, lwork, sdim, info - cdef char *jobvs - cdef double complex *vs_ptr - cdef double complex qwork - cdef np.ndarray[np.complex128_t, ndim=2] vs - cdef np.ndarray[np.complex128_t] w, work - cdef np.ndarray[np.float64_t] rwork - - assert_fortran_mat(A) - - N = A.shape[0] - w = np.empty(N, dtype = np.complex128) - rwork = np.empty(N, dtype = np.float64) - - if calc_q: - vs = np.empty((N,N), dtype = np.complex128, order='F') - vs_ptr = <double complex *>vs.data - jobvs = "V" - else: - vs_ptr = NULL - jobvs = "N" - - # workspace query - lwork = -1 - lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N, - &sdim, <double complex *>w.data, vs_ptr, &N, - &qwork, &lwork, <double *>rwork.data, NULL, &info) - - assert info == 0, "Argument error in zgees" - - lwork = <int>qwork.real - work = np.empty(lwork, dtype = np.complex128) - - # Now the real calculation - lapack.zgees(jobvs, "N", NULL, &N, <double complex *>A.data, &N, - &sdim, <double complex *>w.data, vs_ptr, &N, - <double complex *>work.data, &lwork, - <double *>rwork.data, NULL, &info) - - if info > 0: - raise LinAlgError("QR iteration failed to converge in zgees") + return filter_args((True, calc_q, calc_ev), (A, vs, w)) - assert info == 0, "Argument error in zgees" - return filter_args((True, calc_q, calc_ev), (A, vs, w)) # Wrapper for xTRSEN def strsen(np.ndarray[l_logical] select,