From 44ad9adad064c9935e98f2776430a5772044649c Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Thu, 22 Jun 2017 23:01:23 +0200
Subject: [PATCH] switch Xtrevc to fused types

---
 kwant/linalg/decomp_schur.py |   8 +-
 kwant/linalg/lapack.pyx      | 364 +++++++----------------------------
 2 files changed, 71 insertions(+), 301 deletions(-)

diff --git a/kwant/linalg/decomp_schur.py b/kwant/linalg/decomp_schur.py
index 0acb4332..a59eac45 100644
--- a/kwant/linalg/decomp_schur.py
+++ b/kwant/linalg/decomp_schur.py
@@ -257,12 +257,6 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
 
     ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
 
-    if (t.shape[0] != t.shape[1] or q.shape[0] != q.shape[1]
-        or t.shape[0] != q.shape[0]):
-        raise ValueError("Invalid Schur decomposition as input")
-
-    trevc = getattr(lapack, ltype + "trevc")
-
     # check if select is a function or an array
     if select is not None:
         isfun = isarray = True
@@ -288,7 +282,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
     else:
         selectarr = None
 
-    return trevc(t, q, selectarr, left, right)
+    return lapack.trevc(t, q, selectarr, left, right)
 
 
 def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx
index 87d33dbb..a1745f27 100644
--- a/kwant/linalg/lapack.pyx
+++ b/kwant/linalg/lapack.pyx
@@ -14,7 +14,7 @@ __all__ = ['getrf',
            'ggev',
            'gees',
            'trsen',
-           'strevc', 'dtrevc', 'ctrevc', 'ztrevc',
+           'trevc',
            'sgges', 'dgges', 'cgges', 'zgges',
            'stgsen', 'dtgsen', 'ctgsen', 'ztgsen',
            'stgevc', 'dtgevc', 'ctgevc', 'ztgevc',
@@ -651,129 +651,37 @@ def txevc_postprocess(dtype, T, vreal, np.ndarray[l_logical] select):
     return v
 
 
-# Wrappers for xTREVC
-def strevc(np.ndarray[np.float32_t, ndim=2] T,
-           np.ndarray[np.float32_t, ndim=2] Q=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
+def trevc(np.ndarray[scalar, ndim=2] T,
+          np.ndarray[scalar, ndim=2] Q,
+          np.ndarray[l_logical] select,
+          left=False, right=True):
     cdef l_int N, info, M, MM
     cdef char *side
     cdef char *howmny
-    cdef np.ndarray[np.float32_t, ndim=2] vl_r, vr_r
-    cdef float *vl_r_ptr
-    cdef float *vr_r_ptr
-    cdef np.ndarray[l_logical] select_cpy
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.float32_t] work
-
-    assert_fortran_mat(T, Q)
-
-    N = T.shape[0]
-    work = np.empty(4*N, dtype = np.float32)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        # Correct for possible additional storage if a single complex
-        # eigenvalue is selected.
-        # For that: Figure out the positions of the 2x2 blocks.
-        cmplxindx = np.diagonal(T, -1).nonzero()[0]
-        for i in cmplxindx:
-            if bool(select[i]) != bool(select[i+1]):
-                MM += 1
-
-        # Select is overwritten in strevc.
-        select_cpy = np.array(select, dtype = logical_dtype,
-                              order = 'F')
-        select_ptr = <l_logical *>select_cpy.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if Q is not None:
-            howmny = "B"
-        else:
-            howmny = "A"
 
-    if left:
-        if Q is not None and select is None:
-            vl_r = np.asfortranarray(Q.copy())
-        else:
-            vl_r = np.empty((N, MM), dtype = np.float32, order='F')
-        vl_r_ptr = <float *>vl_r.data
-    else:
-        vl_r_ptr = NULL
-
-    if right:
-        if Q is not None and select is None:
-            vr_r = np.asfortranarray(Q.copy())
-        else:
-            vr_r = np.empty((N, MM), dtype = np.float32, order='F')
-        vr_r_ptr = <float *>vr_r.data
-    else:
-        vr_r_ptr = NULL
+    # Parameter checks
 
-    lapack.strevc(side, howmny, select_ptr,
-                     &N, <float *>T.data, &N,
-                     vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
-                     <float *>work.data, &info)
+    if (T.shape[0] != T.shape[1] or Q.shape[0] != Q.shape[1]
+        or T.shape[0] != Q.shape[0]):
+        raise ValueError("Invalid Schur decomposition as input")
 
-    assert info == 0, "Argument error in strevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in strevc"
+    assert_fortran_mat(T, Q)
 
-    if select is not None and Q is not None:
-        if left:
-            vl_r = np.asfortranarray(np.dot(Q, vl_r))
-        if right:
-            vr_r = np.asfortranarray(np.dot(Q, vr_r))
+    # Workspace allocation
 
-    # If there are complex eigenvalues, we need to postprocess the
-    # eigenvectors.
-    if np.diagonal(T, -1).nonzero()[0].size:
-        if left:
-            vl = txevc_postprocess(np.complex64, T, vl_r, select)
-        if right:
-            vr = txevc_postprocess(np.complex64, T, vr_r, select)
-    else:
-        if left:
-            vl = vl_r
-        if right:
-            vr = vr_r
+    N = T.shape[0]
 
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
+    cdef np.ndarray[scalar] work
+    if scalar in floating:
+        work = np.empty(4 * N, dtype=T.dtype)
     else:
-        return vr
-
-
-def dtrevc(np.ndarray[np.float64_t, ndim=2] T,
-           np.ndarray[np.float64_t, ndim=2] Q=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.float64_t, ndim=2] vl_r, vr_r
-    cdef double *vl_r_ptr
-    cdef double *vr_r_ptr
-    cdef np.ndarray[l_logical] select_cpy
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.float64_t] work
-
-    assert_fortran_mat(T, Q)
+        work = np.empty(2 * N, dtype=T.dtype)
 
-    N = T.shape[0]
-    work = np.empty(4*N, dtype = np.float64)
+    cdef np.ndarray rwork = None
+    if scalar is float_complex:
+        rwork = np.empty(N, dtype=np.float32)
+    elif scalar is double_complex:
+        rwork = np.empty(N, dtype=np.float64)
 
     if left and right:
         side = "B"
@@ -784,6 +692,8 @@ def dtrevc(np.ndarray[np.float64_t, ndim=2] T,
     else:
         return
 
+    cdef np.ndarray[l_logical] select_cpy
+    cdef l_logical *select_ptr
     if select is not None:
         howmny = "S"
         MM = select.nonzero()[0].size
@@ -795,7 +705,7 @@ def dtrevc(np.ndarray[np.float64_t, ndim=2] T,
             if bool(select[i]) != bool(select[i+1]):
                 MM += 1
 
-        # Select is overwritten in dtrevc.
+        # Select may be overwritten by xTREVC (real variants), so pass a copy.
         select_cpy = np.array(select, dtype = logical_dtype,
                               order = 'F')
         select_ptr = <l_logical *>select_cpy.data
@@ -807,31 +717,53 @@ def dtrevc(np.ndarray[np.float64_t, ndim=2] T,
         else:
             howmny = "A"
 
+    cdef np.ndarray[scalar, ndim=2] vl_r = None
+    cdef scalar *vl_r_ptr
     if left:
         if Q is not None and select is None:
             vl_r = np.asfortranarray(Q.copy())
         else:
-            vl_r = np.empty((N, MM), dtype = np.float64, order='F')
-        vl_r_ptr = <double *>vl_r.data
+            vl_r = np.empty((N, MM), dtype=T.dtype, order='F')
+        vl_r_ptr = <scalar *>vl_r.data
     else:
         vl_r_ptr = NULL
 
+    cdef np.ndarray[scalar, ndim=2]  vr_r = None
+    cdef scalar *vr_r_ptr
     if right:
         if Q is not None and select is None:
             vr_r = np.asfortranarray(Q.copy())
         else:
-            vr_r = np.empty((N, MM), dtype = np.float64, order='F')
-        vr_r_ptr = <double *>vr_r.data
+            vr_r = np.empty((N, MM), dtype=T.dtype, order='F')
+        vr_r_ptr = <scalar *>vr_r.data
     else:
         vr_r_ptr = NULL
 
-    lapack.dtrevc(side, howmny, select_ptr,
-                     &N, <double *>T.data, &N,
-                     vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
-                     <double *>work.data, &info)
+    # The actual calculation
 
-    assert info == 0, "Argument error in dtrevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in dtrevc"
+    if scalar is float:
+        lapack.strevc(side, howmny, select_ptr,
+                      &N, <float *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <float *>work.data, &info)
+    elif scalar is double:
+        lapack.dtrevc(side, howmny, select_ptr,
+                      &N, <double *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <double *>work.data, &info)
+    elif scalar is float_complex:
+        lapack.ctrevc(side, howmny, select_ptr,
+                      &N, <float complex *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <float complex *>work.data, <float *>rwork.data, &info)
+    elif scalar is double_complex:
+        lapack.ztrevc(side, howmny, select_ptr,
+                      &N, <double complex *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <double complex *>work.data, <double *>rwork.data, &info)
+
+    assert info == 0, "Argument error in trevc"
+    assert MM == M, "Unexpected number of eigenvectors returned in strevc"
 
     if select is not None and Q is not None:
         if left:
@@ -839,179 +771,23 @@ def dtrevc(np.ndarray[np.float64_t, ndim=2] T,
         if right:
             vr_r = np.asfortranarray(np.dot(Q, vr_r))
 
-    # If there are complex eigenvalues, we need to postprocess the eigenvectors
-    if np.diagonal(T, -1).nonzero()[0].size:
-        if left:
-            vl = txevc_postprocess(np.complex128, T, vl_r, select)
-        if right:
-            vr = txevc_postprocess(np.complex128, T, vr_r, select)
-    else:
-        if left:
-            vl = vl_r
-        if right:
-            vr = vr_r
-
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
-    else:
-        return vr
-
-
-def ctrevc(np.ndarray[np.complex64_t, ndim=2] T,
-           np.ndarray[np.complex64_t, ndim=2] Q=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.complex64_t, ndim=2] vl, vr
-    cdef float complex *vl_ptr
-    cdef float complex *vr_ptr
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.complex64_t] work
-    cdef np.ndarray[np.float32_t] rwork
-
-    assert_fortran_mat(T, Q)
-
-    N = T.shape[0]
-    work = np.empty(2*N, dtype = np.complex64)
-    rwork = np.empty(N, dtype = np.float32)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        select_ptr = <l_logical *>select.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if Q is not None:
-            howmny = "B"
-        else:
-            howmny = "A"
-
-    if left:
-        if Q is not None and select is None:
-            vl = np.asfortranarray(Q.copy())
-        else:
-            vl = np.empty((N, MM), dtype = np.complex64, order='F')
-        vl_ptr = <float complex *>vl.data
-    else:
-        vl_ptr = NULL
-
-    if right:
-        if Q is not None and select is None:
-            vr = np.asfortranarray(Q.copy())
-        else:
-            vr = np.empty((N, MM), dtype = np.complex64, order='F')
-        vr_ptr = <float complex *>vr.data
-    else:
-        vr_ptr = NULL
-
-    lapack.ctrevc(side, howmny, select_ptr,
-                     &N, <float complex *>T.data, &N,
-                     vl_ptr, &N, vr_ptr, &N, &MM, &M,
-                     <float complex *>work.data, <float *>rwork.data, &info)
-
-    assert info == 0, "Argument error in ctrevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in ctrevc"
-
-    if select is not None and Q is not None:
-        if left:
-            vl = np.asfortranarray(np.dot(Q, vl))
-        if right:
-            vr = np.asfortranarray(np.dot(Q, vr))
-
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
-    else:
-        return vr
-
-
-def ztrevc(np.ndarray[np.complex128_t, ndim=2] T,
-           np.ndarray[np.complex128_t, ndim=2] Q=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.complex128_t, ndim=2] vl, vr
-    cdef double complex *vl_ptr
-    cdef double complex *vr_ptr
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.complex128_t] work
-    cdef np.ndarray[np.float64_t] rwork
-
-    assert_fortran_mat(T, Q)
-
-    N = T.shape[0]
-    work = np.empty(2*N, dtype = np.complex128)
-    rwork = np.empty(N, dtype = np.float64)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        select_ptr = <l_logical *>select.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if Q is not None:
-            howmny = "B"
-        else:
-            howmny = "A"
-
+    cdef np.ndarray vl, vr
     if left:
-        if Q is not None and select is None:
-            vl = np.asfortranarray(Q.copy())
-        else:
-            vl = np.empty((N, MM), dtype = np.complex128, order='F')
-        vl_ptr = <double complex *>vl.data
-    else:
-        vl_ptr = NULL
-
+        vl = vl_r
     if right:
-        if Q is not None and select is None:
-            vr = np.asfortranarray(Q.copy())
+        vr = vr_r
+    if scalar in floating:
+        # If there are complex eigenvalues, we need to postprocess the
+        # eigenvectors.
+        if scalar is float:
+            dtype = np.complex64
         else:
-            vr = np.empty((N, MM), dtype = np.complex128, order='F')
-        vr_ptr = <double complex *>vr.data
-    else:
-        vr_ptr = NULL
-
-    lapack.ztrevc(side, howmny, select_ptr,
-                     &N, <double complex *>T.data, &N,
-                     vl_ptr, &N, vr_ptr, &N, &MM, &M,
-                     <double complex *>work.data, <double *>rwork.data, &info)
-
-    assert info == 0, "Argument error in ztrevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in ztrevc"
-
-    if select is not None and Q is not None:
-        if left:
-            vl = np.asfortranarray(np.dot(Q, vl))
-        if right:
-            vr = np.asfortranarray(np.dot(Q, vr))
+            dtype = np.complex128
+        if np.diagonal(T, -1).nonzero()[0].size:
+            if left:
+                vl = txevc_postprocess(dtype, T, vl_r, select)
+            if right:
+                vr = txevc_postprocess(dtype, T, vr_r, select)
 
     if left and right:
         return (vl, vr)
-- 
GitLab