From 2f3bcdc175a2bca3bb0b1c7898b75974529134e7 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Mon, 26 Jun 2017 12:33:10 +0200
Subject: [PATCH] factor out complex array construction

---
 kwant/linalg/lapack.pyx | 34 ++++++++++++++--------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx
index 9f437c00..53fdaab2 100644
--- a/kwant/linalg/lapack.pyx
+++ b/kwant/linalg/lapack.pyx
@@ -75,6 +75,16 @@ def assert_fortran_mat(*mats):
             raise ValueError("Input matrix must be Fortran contiguous")
 
 
+cdef np.ndarray maybe_complex(scalar selector,
+                              np.ndarray real, np.ndarray imag):
+    cdef np.ndarray r
+    r = real
+    if scalar in floating:
+        if imag.nonzero()[0].size:
+            r = real + 1j * imag
+    return r
+
+
 cdef l_int lwork_from_qwork(scalar qwork):
     if scalar in floating:
         return <l_int>qwork
@@ -498,11 +508,7 @@ def gees(np.ndarray[scalar, ndim=2] A, calc_q=True, calc_ev=True):
     assert info == 0, "Argument error in gees"
 
     # Real inputs possibly produce complex output
-    cdef np.ndarray w
-    w = wr
-    if scalar in floating:
-        if wi.nonzero()[0].size:
-            w = wr + 1j * wi
+    cdef np.ndarray w = maybe_complex[scalar](0, wr, wi)
 
     return filter_args((True, calc_q, calc_ev), (A, vs, w))
 
@@ -603,11 +609,7 @@ def trsen(np.ndarray[l_logical] select,
     assert info == 0, "Argument error in trsen"
 
     # Real inputs possibly produce complex output
-    cdef np.ndarray w
-    w = wr
-    if scalar in floating:
-        if wi.nonzero()[0].size:
-            w = wr + 1j * wi
+    cdef np.ndarray w = maybe_complex[scalar](0, wr, wi)
 
     return filter_args((True, Q is not None, calc_ev), (T, Q, w))
 
@@ -938,11 +940,7 @@ def gges(np.ndarray[scalar, ndim=2] A,
     assert info == 0, "Argument error in gges"
 
     # Real inputs possibly produce complex output
-    cdef np.ndarray alpha
-    alpha = alphar
-    if scalar in floating:
-        if alphai.nonzero()[0].size:
-            alpha = alphar + 1j * alphai
+    cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai)
 
     return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
                        (A, B, vsl, vsr, alpha, beta))
@@ -1091,11 +1089,7 @@ def tgsen(np.ndarray[l_logical] select,
     assert info == 0, "Argument error in tgsen"
 
     # Real inputs possibly produce complex output
-    cdef np.ndarray alpha
-    alpha = alphar
-    if scalar in floating:
-        if alphai.nonzero()[0].size:
-            alpha = alphar + 1j * alphai
+    cdef np.ndarray alpha = maybe_complex[scalar](0, alphar, alphai)
 
     return filter_args((True, True, Q is not None, Z is not None,
                         calc_ev, calc_ev),
-- 
GitLab