From b9c05bea7d6971c26f3d7190476c3319288cbb0f Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Fri, 23 Jun 2017 23:09:36 +0200
Subject: [PATCH] switch Xtgevc to fused types

---
 kwant/linalg/decomp_schur.py |  17 +-
 kwant/linalg/lapack.pyx      | 404 +++++++----------------------------
 2 files changed, 81 insertions(+), 340 deletions(-)

diff --git a/kwant/linalg/decomp_schur.py b/kwant/linalg/decomp_schur.py
index 0735a4c0..af50f1e8 100644
--- a/kwant/linalg/decomp_schur.py
+++ b/kwant/linalg/decomp_schur.py
@@ -613,27 +613,12 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
 
     ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
 
-    if (s.ndim != 2 or t.ndim != 2 or
-        (q is not None and q.ndim != 2) or
-        (z is not None and z.ndim != 2)):
-        raise ValueError("Expect matrices as input")
-
-    if ((s.shape[0] != s.shape[1] or t.shape[0] != t.shape[1] or
-         s.shape[0] != t.shape[0]) or
-        (q is not None and (q.shape[0] != q.shape[1] or
-                            s.shape[0] != q.shape[0])) or
-        (z is not None and (z.shape[0] != z.shape[1] or
-                            s.shape[0] != z.shape[0]))):
-        raise ValueError("Invalid Schur decomposition as input")
-
     if left and q is None:
         raise ValueError("Matrix q must be provided for left eigenvectors")
 
     if right and z is None:
         raise ValueError("Matrix z must be provided for right eigenvectors")
 
-    tgevc = getattr(lapack, ltype + "tgevc")
-
     # Check if select is a function or an array.
     if select is not None:
         isfun = isarray = True
@@ -659,4 +644,4 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
     else:
         selectarr = None
 
-    return tgevc(s, t, q, z, selectarr, left, right)
+    return lapack.tgevc(s, t, q, z, selectarr, left, right)
diff --git a/kwant/linalg/lapack.pyx b/kwant/linalg/lapack.pyx
index 7049fc3b..e8ddba2e 100644
--- a/kwant/linalg/lapack.pyx
+++ b/kwant/linalg/lapack.pyx
@@ -17,7 +17,7 @@ __all__ = ['getrf',
            'trevc',
            'gges',
            'tgsen',
-           'stgevc', 'dtgevc', 'ctgevc', 'ztgevc',
+           'tgevc',
            'prepare_for_lapack']
 
 import numpy as np
@@ -1108,139 +1108,43 @@ def tgsen(np.ndarray[l_logical] select,
                        (S, T, Q, Z, alpha, beta))
 
 
-# xTGEVC
-def stgevc(np.ndarray[np.float32_t, ndim=2] S,
-           np.ndarray[np.float32_t, ndim=2] T,
-           np.ndarray[np.float32_t, ndim=2] Q=None,
-           np.ndarray[np.float32_t, ndim=2] Z=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
+def tgevc(np.ndarray[scalar, ndim=2] S,
+          np.ndarray[scalar, ndim=2] T,
+          np.ndarray[scalar, ndim=2] Q,
+          np.ndarray[scalar, ndim=2] Z,
+          np.ndarray[l_logical] select,
+          left=False, right=True):
     cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.float32_t, ndim=2] vl_r, vr_r
-    cdef float *vl_r_ptr
-    cdef float *vr_r_ptr
-    cdef np.ndarray[l_logical] select_cpy
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.float32_t] work
-
-    assert_fortran_mat(S, T, Q, Z)
-
-    N = S.shape[0]
-    work = np.empty(6*N, dtype = np.float32)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    backtr = False
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        # Correct for possible additional storage if a single complex
-        # eigenvalue is selected.
-        # For that: Figure out the positions of the 2x2 blocks.
-        cmplxindx = np.diagonal(S, -1).nonzero()[0]
-        for i in cmplxindx:
-            if bool(select[i]) != bool(select[i+1]):
-                MM += 1
 
-        # select is overwritten in stgevc
-        select_cpy = np.array(select, dtype = logical_dtype,
-                              order = 'F')
-        select_ptr = <l_logical *>select_cpy.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if ((left and right and Q is not None and Z is not None) or
-            (left and not right and Q is not None) or
-            (right and not left and Z is not None)):
-            howmny = "B"
-            backtr = True
-        else:
-            howmny = "A"
-
-    if left:
-        if backtr:
-            vl_r = Q
-        else:
-            vl_r = np.empty((N, MM), dtype = np.float32, order='F')
-        vl_r_ptr = <float *>vl_r.data
-    else:
-        vl_r_ptr = NULL
-
-    if right:
-        if backtr:
-            vr_r = Z
-        else:
-            vr_r = np.empty((N, MM), dtype = np.float32, order='F')
-        vr_r_ptr = <float *>vr_r.data
-    else:
-        vr_r_ptr = NULL
+    # Check parameters
 
-    lapack.stgevc(side, howmny, select_ptr,
-                     &N, <float *>S.data, &N,
-                     <float *>T.data, &N,
-                     vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
-                     <float *>work.data, &info)
+    if ((S.shape[0] != S.shape[1] or T.shape[0] != T.shape[1] or
+         S.shape[0] != T.shape[0]) or
+        (Q is not None and (Q.shape[0] != Q.shape[1] or
+                            S.shape[0] != Q.shape[0])) or
+        (Z is not None and (Z.shape[0] != Z.shape[1] or
+                            S.shape[0] != Z.shape[0]))):
+        raise ValueError("Invalid Schur decomposition as input")
 
-    assert info == 0, "Argument error in stgevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in stgevc"
+    assert_fortran_mat(S, T, Q, Z)
 
-    if not backtr:
-        if left:
-            vl_r = np.asfortranarray(np.dot(Q, vl_r))
-        if right:
-            vr_r = np.asfortranarray(np.dot(Z, vr_r))
+    # Allocate workspaces
 
-    # If there are complex eigenvalues, we need to postprocess the eigenvectors
-    if np.diagonal(S, -1).nonzero()[0].size:
-        if left:
-            vl = txevc_postprocess(np.complex64, S, vl_r, select)
-        if right:
-            vr = txevc_postprocess(np.complex64, S, vr_r, select)
-    else:
-        if left:
-            vl = vl_r
-        if right:
-            vr = vr_r
+    N = S.shape[0]
 
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
+    cdef np.ndarray[scalar] work
+    if scalar in floating:
+        work = np.empty(6 * N, dtype=S.dtype)
     else:
-        return vr
+        work = np.empty(2 * N, dtype=S.dtype)
 
+    cdef np.ndarray rwork = None
+    if scalar is float_complex:
+        rwork = np.empty(2 * N, dtype=np.float32)
+    elif scalar is double_complex:
+        rwork = np.empty(2 * N, dtype=np.float64)
 
-def dtgevc(np.ndarray[np.float64_t, ndim=2] S,
-           np.ndarray[np.float64_t, ndim=2] T,
-           np.ndarray[np.float64_t, ndim=2] Q=None,
-           np.ndarray[np.float64_t, ndim=2] Z=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
     cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.float64_t, ndim=2] vl_r, vr_r
-    cdef double *vl_r_ptr
-    cdef double *vr_r_ptr
-    cdef np.ndarray[l_logical] select_cpy
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.float64_t] work
-
-    assert_fortran_mat(S, T, Q, Z)
-
-    N = S.shape[0]
-    work = np.empty(6*N, dtype = np.float64)
-
     if left and right:
         side = "B"
     elif left:
@@ -1250,8 +1154,11 @@ def dtgevc(np.ndarray[np.float64_t, ndim=2] S,
     else:
         return
 
-    backtr = False
+    cdef l_logical backtr = False
 
+    cdef char *howmny
+    cdef np.ndarray[l_logical] select_cpy = None
+    cdef l_logical *select_ptr
     if select is not None:
         howmny = "S"
         MM = select.nonzero()[0].size
@@ -1263,8 +1170,8 @@ def dtgevc(np.ndarray[np.float64_t, ndim=2] S,
             if bool(select[i]) != bool(select[i+1]):
                 MM += 1
 
-        # select is overwritten in dtgevc
-        select_cpy = np.array(select, dtype = logical_dtype,
+        # select is overwritten in tgevc
+        select_cpy = np.array(select, dtype=logical_dtype,
                               order = 'F')
         select_ptr = <l_logical *>select_cpy.data
     else:
@@ -1278,32 +1185,55 @@ def dtgevc(np.ndarray[np.float64_t, ndim=2] S,
         else:
             howmny = "A"
 
+    cdef np.ndarray[scalar, ndim=2] vl_r
+    cdef scalar *vl_r_ptr
     if left:
         if backtr:
             vl_r = Q
         else:
-            vl_r = np.empty((N, MM), dtype = np.float64, order='F')
-        vl_r_ptr = <double *>vl_r.data
+            vl_r = np.empty((N, MM), dtype=S.dtype, order='F')
+        vl_r_ptr = <scalar *>vl_r.data
     else:
         vl_r_ptr = NULL
 
+    cdef np.ndarray[scalar, ndim=2] vr_r
+    cdef scalar *vr_r_ptr
     if right:
         if backtr:
             vr_r = Z
         else:
-            vr_r = np.empty((N, MM), dtype = np.float64, order='F')
-        vr_r_ptr = <double *>vr_r.data
+            vr_r = np.empty((N, MM), dtype=S.dtype, order='F')
+        vr_r_ptr = <scalar *>vr_r.data
     else:
         vr_r_ptr = NULL
 
-    lapack.dtgevc(side, howmny, select_ptr,
-                     &N, <double *>S.data, &N,
-                     <double *>T.data, &N,
-                     vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
-                     <double *>work.data, &info)
+    if scalar is float:
+        lapack.stgevc(side, howmny, select_ptr,
+                      &N, <float *>S.data, &N,
+                      <float *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <float *>work.data, &info)
+    elif scalar is double:
+        lapack.dtgevc(side, howmny, select_ptr,
+                      &N, <double *>S.data, &N,
+                      <double *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <double *>work.data, &info)
+    elif scalar is float_complex:
+        lapack.ctgevc(side, howmny, select_ptr,
+                      &N, <float complex *>S.data, &N,
+                      <float complex *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <float complex *>work.data, <float *>rwork.data, &info)
+    elif scalar is double_complex:
+        lapack.ztgevc(side, howmny, select_ptr,
+                      &N, <double complex *>S.data, &N,
+                      <double complex *>T.data, &N,
+                      vl_r_ptr, &N, vr_r_ptr, &N, &MM, &M,
+                      <double complex *>work.data, <double *>rwork.data, &info)
 
-    assert info == 0, "Argument error in dtgevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in dtgevc"
+    assert info == 0, "Argument error in tgevc"
+    assert MM == M, "Unexpected number of eigenvectors returned in tgevc"
 
     if not backtr:
         if left:
@@ -1311,196 +1241,22 @@ def dtgevc(np.ndarray[np.float64_t, ndim=2] S,
         if right:
             vr_r = np.asfortranarray(np.dot(Z, vr_r))
 
-    # If there are complex eigenvalues, we need to postprocess the
-    # eigenvectors.
-    if np.diagonal(S, -1).nonzero()[0].size:
-        if left:
-            vl = txevc_postprocess(np.complex128, S, vl_r, select)
-        if right:
-            vr = txevc_postprocess(np.complex128, S, vr_r, select)
-    else:
-        if left:
-            vl = vl_r
-        if right:
-            vr = vr_r
-
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
-    else:
-        return vr
-
-
-def ctgevc(np.ndarray[np.complex64_t, ndim=2] S,
-           np.ndarray[np.complex64_t, ndim=2] T,
-           np.ndarray[np.complex64_t, ndim=2] Q=None,
-           np.ndarray[np.complex64_t, ndim=2] Z=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.complex64_t, ndim=2] vl, vr
-    cdef float complex *vl_ptr
-    cdef float complex *vr_ptr
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.complex64_t] work
-    cdef np.ndarray[np.float32_t] rwork
-
-    assert_fortran_mat(S, T, Q, Z)
-
-    N = S.shape[0]
-    work = np.empty(2*N, dtype = np.complex64)
-    rwork = np.empty(2*N, dtype = np.float32)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    backtr = False
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        select_ptr = <l_logical *>select.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if ((left and right and Q is not None and Z is not None) or
-            (left and not right and Q is not None) or
-            (right and not left and Z is not None)):
-            howmny = "B"
-            backtr = True
-        else:
-            howmny = "A"
-
-    if left:
-        if backtr:
-            vl = Q
-        else:
-            vl = np.empty((N, MM), dtype = np.complex64, order='F')
-        vl_ptr = <float complex *>vl.data
-    else:
-        vl_ptr = NULL
-
-    if right:
-        if backtr:
-            vr = Z
-        else:
-            vr = np.empty((N, MM), dtype = np.complex64, order='F')
-        vr_ptr = <float complex *>vr.data
-    else:
-        vr_ptr = NULL
-
-    lapack.ctgevc(side, howmny, select_ptr,
-                     &N, <float complex *>S.data, &N,
-                     <float complex *>T.data, &N,
-                     vl_ptr, &N, vr_ptr, &N, &MM, &M,
-                     <float complex *>work.data, <float *>rwork.data, &info)
-
-    assert info == 0, "Argument error in ctgevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in ctgevc"
-
-    if not backtr:
-        if left:
-            vl = np.asfortranarray(np.dot(Q, vl))
-        if right:
-            vr = np.asfortranarray(np.dot(Z, vr))
-
-    if left and right:
-        return (vl, vr)
-    elif left:
-        return vl
-    else:
-        return vr
-
-
-def ztgevc(np.ndarray[np.complex128_t, ndim=2] S,
-           np.ndarray[np.complex128_t, ndim=2] T,
-           np.ndarray[np.complex128_t, ndim=2] Q=None,
-           np.ndarray[np.complex128_t, ndim=2] Z=None,
-           np.ndarray[l_logical] select=None,
-           left=False, right=True):
-    cdef l_int N, info, M, MM
-    cdef char *side
-    cdef char *howmny
-    cdef np.ndarray[np.complex128_t, ndim=2] vl, vr
-    cdef double complex *vl_ptr
-    cdef double complex *vr_ptr
-    cdef l_logical *select_ptr
-    cdef np.ndarray[np.complex128_t] work
-    cdef np.ndarray[np.float64_t] rwork
-
-    assert_fortran_mat(S, T, Q, Z)
-
-    N = S.shape[0]
-    work = np.empty(2*N, dtype = np.complex128)
-    rwork = np.empty(2*N, dtype = np.float64)
-
-    if left and right:
-        side = "B"
-    elif left:
-        side = "L"
-    elif right:
-        side = "R"
-    else:
-        return
-
-    backtr = False
-
-    if select is not None:
-        howmny = "S"
-        MM = select.nonzero()[0].size
-        select_ptr = <l_logical *>select.data
-    else:
-        MM = N
-        select_ptr = NULL
-        if ((left and right and Q is not None and Z is not None) or
-            (left and not right and Q is not None) or
-            (right and not left and Z is not None)):
-            howmny = "B"
-            backtr = True
-        else:
-            howmny = "A"
-
+    # If there are complex eigenvalues, we need to postprocess the eigenvectors
+    cdef np.ndarray vl, vr
     if left:
-        if backtr:
-            vl = Q
-        else:
-            vl = np.empty((N, MM), dtype = np.complex128, order='F')
-        vl_ptr = <double complex *>vl.data
-    else:
-        vl_ptr = NULL
-
+        vl = vl_r
     if right:
-        if backtr:
-            vr = Z
+        vr = vr_r
+    if scalar in floating:
+        if scalar is float:
+            dtype = np.complex64
         else:
-            vr = np.empty((N, MM), dtype = np.complex128, order='F')
-        vr_ptr = <double complex *>vr.data
-    else:
-        vr_ptr = NULL
-
-    lapack.ztgevc(side, howmny, select_ptr,
-                     &N, <double complex *>S.data, &N,
-                     <double complex *>T.data, &N,
-                     vl_ptr, &N, vr_ptr, &N, &MM, &M,
-                     <double complex *>work.data, <double *>rwork.data, &info)
-
-    assert info == 0, "Argument error in ztgevc"
-    assert MM == M, "Unexpected number of eigenvectors returned in ztgevc"
-
-    if not backtr:
-        if left:
-            vl = np.asfortranarray(np.dot(Q, vl))
-        if right:
-            vr = np.asfortranarray(np.dot(Z, vr))
+            dtype = np.complex128
+        if np.diagonal(S, -1).nonzero()[0].size:
+            if left:
+                vl = txevc_postprocess(dtype, S, vl_r, select)
+            if right:
+                vr = txevc_postprocess(dtype, S, vr_r, select)
 
     if left and right:
         return (vl, vr)
-- 
GitLab