From 2f3bcdc175a2bca3bb0b1c7898b75974529134e7 Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph.weston08@gmail.com> Date: Mon, 26 Jun 2017 12:33:10 +0200 Subject: [PATCH] factor out complex array construction --- kwant/linalg/lapack.pyx | 34 ++++++++++++++-------------------- 1 file changed, 14 insertions(+), 20 deletions(-) diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx index 9f437c00..53fdaab2 100644 --- a/kwant/linalg/lapack.pyx +++ b/kwant/linalg/lapack.pyx @@ -75,6 +75,16 @@ def assert_fortran_mat(*mats): raise ValueError("Input matrix must be Fortran contiguous") +cdef np.ndarray maybe_complex(scalar selector, + np.ndarray real, np.ndarray imag): + cdef np.ndarray r + r = real + if scalar in floating: + if imag.nonzero()[0].size: + r = real + 1j * imag + return r + + cdef l_int lwork_from_qwork(scalar qwork): if scalar in floating: return <l_int>qwork @@ -498,11 +508,7 @@ def gees(np.ndarray[scalar, ndim=2] A, calc_q=True, calc_ev=True): assert info == 0, "Argument error in gees" # Real inputs possibly produce complex output - cdef np.ndarray w - w = wr - if scalar in floating: - if wi.nonzero()[0].size: - w = wr + 1j * wi + cdef np.ndarray w = maybe_complex[scalar](0, wr, wi) return filter_args((True, calc_q, calc_ev), (A, vs, w)) @@ -603,11 +609,7 @@ def trsen(np.ndarray[l_logical] select, assert info == 0, "Argument error in trsen" # Real inputs possibly produce complex output - cdef np.ndarray w - w = wr - if scalar in floating: - if wi.nonzero()[0].size: - w = wr + 1j * wi + cdef np.ndarray w = maybe_complex[scalar](0, wr, wi) return filter_args((True, Q is not None, calc_ev), (T, Q, w)) @@ -938,11 +940,7 @@ def gges(np.ndarray[scalar, ndim=2] A, assert info == 0, "Argument error in gges" # Real inputs possibly produce complex output - cdef np.ndarray alpha - alpha = alphar - if scalar in floating: - if alphai.nonzero()[0].size: - alpha = alphar + 1j * alphai + cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai) return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev), (A, B, vsl, vsr, alpha, beta)) @@ -1091,11 +1089,7 @@ def tgsen(np.ndarray[l_logical] select, assert info == 0, "Argument error in tgsen" # Real inputs possibly produce complex output - cdef np.ndarray alpha - alpha = alphar - if scalar in floating: - if alphai.nonzero()[0].size: - alpha = alphar + 1j * alphai + cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai) return filter_args((True, True, Q is not None, Z is not None, calc_ev, calc_ev), -- GitLab