Skip to content
Snippets Groups Projects
Commit c1da0799 authored by Joseph Weston's avatar Joseph Weston
Browse files

switch Xgges to fused types

parent af3b9c58
No related branches found
No related tags found
No related merge requests found
......@@ -346,21 +346,8 @@ def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
LinAlError
If the underlying QZ iteration fails to converge.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("Expect matrices as input")
if a.shape[0] != a.shape[1]:
raise ValueError("Expect square matrix a")
if a.shape[0] != b.shape[0] or a.shape[0] != b.shape[1]:
raise ValueError("Shape of b is incompatible to matrix a")
gges = getattr(lapack, ltype + "gges")
return gges(a, b, calc_q, calc_z, calc_ev)
return lapack.gges(a, b, calc_q, calc_z, calc_ev)
def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
......
......@@ -15,7 +15,7 @@ __all__ = ['getrf',
'gees',
'trsen',
'trevc',
'sgges', 'dgges', 'cgges', 'zgges',
'gges',
'stgsen', 'dtgsen', 'ctgsen', 'ztgsen',
'stgevc', 'dtgevc', 'ctgevc', 'ztgevc',
'prepare_for_lapack']
......@@ -797,292 +797,156 @@ def trevc(np.ndarray[scalar, ndim=2] T,
return vr
# wrappers for xGGES
def sgges(np.ndarray[np.float32_t, ndim=2] A,
np.ndarray[np.float32_t, ndim=2] B,
def gges(np.ndarray[scalar, ndim=2] A,
np.ndarray[scalar, ndim=2] B,
calc_q=True, calc_z=True, calc_ev=True):
cdef l_int N, lwork, sdim, info
cdef char *jobvsl
cdef char *jobvsr
cdef float *vsl_ptr
cdef float *vsr_ptr
cdef float qwork
cdef np.ndarray[np.float32_t, ndim=2] vsl, vsr
cdef np.ndarray[np.float32_t] alphar, alphai, beta, work
assert_fortran_mat(A, B)
N = A.shape[0]
alphar = np.empty(N, dtype = np.float32)
alphai = np.empty(N, dtype = np.float32)
beta = np.empty(N, dtype = np.float32)
if calc_q:
vsl = np.empty((N,N), dtype = np.float32, order='F')
vsl_ptr = <float *>vsl.data
jobvsl = "V"
else:
vsl = None
vsl_ptr = NULL
jobvsl = "N"
if calc_z:
vsr = np.empty((N,N), dtype = np.float32, order='F')
vsr_ptr = <float *>vsr.data
jobvsr = "V"
else:
vsr = None
vsr_ptr = NULL
jobvsr = "N"
cdef l_int N, sdim, info
# workspace query
lwork = -1
lapack.sgges(jobvsl, jobvsr, "N", NULL,
&N, <float *>A.data, &N,
<float *>B.data, &N, &sdim,
<float *>alphar.data, <float *>alphai.data,
<float *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, NULL, &info)
assert info == 0, "Argument error in zgees"
lwork = <int>qwork
work = np.empty(lwork, dtype = np.float32)
# Now the real calculation
lapack.sgges(jobvsl, jobvsr, "N", NULL,
&N, <float *>A.data, &N,
<float *>B.data, &N, &sdim,
<float *>alphar.data, <float *>alphai.data,
<float *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<float *>work.data, &lwork, NULL, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in sgges")
assert info == 0, "Argument error in zgees"
if alphai.nonzero()[0].size:
alpha = alphar + 1j * alphai
else:
alpha = alphar
return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
(A, B, vsl, vsr, alpha, beta))
def dgges(np.ndarray[np.float64_t, ndim=2] A,
np.ndarray[np.float64_t, ndim=2] B,
calc_q=True, calc_z=True, calc_ev=True):
cdef l_int N, lwork, sdim, info
cdef char *jobvsl
cdef char *jobvsr
cdef double *vsl_ptr
cdef double *vsr_ptr
cdef double qwork
cdef np.ndarray[np.float64_t, ndim=2] vsl, vsr
cdef np.ndarray[np.float64_t] alphar, alphai, beta, work
# Check parameters
assert_fortran_mat(A, B)
N = A.shape[0]
alphar = np.empty(N, dtype = np.float64)
alphai = np.empty(N, dtype = np.float64)
beta = np.empty(N, dtype = np.float64)
if calc_q:
vsl = np.empty((N,N), dtype = np.float64, order='F')
vsl_ptr = <double *>vsl.data
jobvsl = "V"
else:
vsl = None
vsl_ptr = NULL
jobvsl = "N"
if calc_z:
vsr = np.empty((N,N), dtype = np.float64, order='F')
vsr_ptr = <double *>vsr.data
jobvsr = "V"
else:
vsr = None
vsr_ptr = NULL
jobvsr = "N"
# workspace query
lwork = -1
lapack.dgges(jobvsl, jobvsr, "N", NULL,
&N, <double *>A.data, &N,
<double *>B.data, &N, &sdim,
<double *>alphar.data, <double *>alphai.data,
<double *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, NULL, &info)
if A.shape[0] != B.shape[1]:
raise ValueError("Expect square matrix A")
assert info == 0, "Argument error in zgees"
if A.shape[0] != B.shape[0] or A.shape[0] != B.shape[1]:
raise ValueError("Shape of B is incompatible with matrix A")
lwork = <int>qwork
work = np.empty(lwork, dtype = np.float64)
# Now the real calculation
lapack.dgges(jobvsl, jobvsr, "N", NULL,
&N, <double *>A.data, &N,
<double *>B.data, &N, &sdim,
<double *>alphar.data, <double *>alphai.data,
<double *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<double *>work.data, &lwork, NULL, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in dgges")
# Allocate workspaces
assert info == 0, "Argument error in zgees"
N = A.shape[0]
if alphai.nonzero()[0].size:
alpha = alphar + 1j * alphai
cdef np.ndarray[scalar] alphar, alphai
if scalar in cmplx:
alphar = np.empty(N, dtype=A.dtype)
alphai = None
else:
alpha = alphar
alphar = np.empty(N, dtype=A.dtype)
alphai = np.empty(N, dtype=A.dtype)
return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
(A, B, vsl, vsr, alpha, beta))
cdef np.ndarray[scalar] beta = np.empty(N, dtype=A.dtype)
cdef np.ndarray rwork = None
if scalar is float_complex:
rwork = np.empty(8 * N, dtype=np.float32)
elif scalar is double_complex:
rwork = np.empty(8 * N, dtype=np.float64)
def cgges(np.ndarray[np.complex64_t, ndim=2] A,
np.ndarray[np.complex64_t, ndim=2] B,
calc_q=True, calc_z=True, calc_ev=True):
cdef l_int N, lwork, sdim, info
cdef char *jobvsl
cdef char *jobvsr
cdef float complex *vsl_ptr
cdef float complex *vsr_ptr
cdef float complex qwork
cdef np.ndarray[np.complex64_t, ndim=2] vsl, vsr
cdef np.ndarray[np.complex64_t] alpha, beta, work
cdef np.ndarray[np.float32_t] rwork
assert_fortran_mat(A, B)
N = A.shape[0]
alpha = np.empty(N, dtype = np.complex64)
beta = np.empty(N, dtype = np.complex64)
rwork = np.empty(8*N, dtype = np.float32)
cdef scalar *vsl_ptr
cdef np.ndarray[scalar, ndim=2] vsl
if calc_q:
vsl = np.empty((N,N), dtype = np.complex64, order='F')
vsl_ptr = <float complex *>vsl.data
vsl = np.empty((N,N), dtype=A.dtype, order='F')
vsl_ptr = <scalar *>vsl.data
jobvsl = "V"
else:
vsl = None
vsl_ptr = NULL
jobvsl = "N"
cdef char *jobvsr
cdef scalar *vsr_ptr
cdef np.ndarray[scalar, ndim=2] vsr
if calc_z:
vsr = np.empty((N,N), dtype = np.complex64, order='F')
vsr_ptr = <float complex *>vsr.data
vsr = np.empty((N,N), dtype=A.dtype, order='F')
vsr_ptr = <scalar *>vsr.data
jobvsr = "V"
else:
vsr = None
vsr_ptr = NULL
jobvsr = "N"
# workspace query
lwork = -1
lapack.cgges(jobvsl, jobvsr, "N", NULL,
&N, <float complex *>A.data, &N,
<float complex *>B.data, &N, &sdim,
<float complex *>alpha.data, <float complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, <float *>rwork.data, NULL, &info)
assert info == 0, "Argument error in zgees"
lwork = <int>qwork.real
work = np.empty(lwork, dtype = np.complex64)
# Now the real calculation
lapack.cgges(jobvsl, jobvsr, "N", NULL,
&N, <float complex *>A.data, &N,
<float complex *>B.data, &N, &sdim,
<float complex *>alpha.data, <float complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<float complex *>work.data, &lwork,
<float *>rwork.data, NULL, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in cgges")
assert info == 0, "Argument error in zgees"
return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
(A, B, vsl, vsr, alpha, beta))
def zgges(np.ndarray[np.complex128_t, ndim=2] A,
np.ndarray[np.complex128_t, ndim=2] B,
calc_q=True, calc_z=True, calc_ev=True):
cdef l_int N, lwork, sdim, info
cdef char *jobvsl
cdef char *jobvsr
cdef double complex *vsl_ptr
cdef double complex *vsr_ptr
cdef double complex qwork
cdef np.ndarray[np.complex128_t, ndim=2] vsl, vsr
cdef np.ndarray[np.complex128_t] alpha, beta, work
cdef np.ndarray[np.float64_t] rwork
assert_fortran_mat(A, B)
# Workspace query
# Xgges expects &qwork as a <scalar *> (even though it's an integer)
cdef l_int lwork = -1
cdef scalar qwork
N = A.shape[0]
alpha = np.empty(N, dtype = np.complex128)
beta = np.empty(N, dtype = np.complex128)
rwork = np.empty(8*N, dtype = np.float64)
if scalar is float:
lapack.sgges(jobvsl, jobvsr, "N", NULL,
&N, <float *>A.data, &N,
<float *>B.data, &N, &sdim,
<float *>alphar.data, <float *>alphai.data,
<float *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, NULL, &info)
elif scalar is double:
lapack.dgges(jobvsl, jobvsr, "N", NULL,
&N, <double *>A.data, &N,
<double *>B.data, &N, &sdim,
<double *>alphar.data, <double *>alphai.data,
<double *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, NULL, &info)
elif scalar is float_complex:
lapack.cgges(jobvsl, jobvsr, "N", NULL,
&N, <float complex *>A.data, &N,
<float complex *>B.data, &N, &sdim,
<float complex *>alphar.data, <float complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, <float *>rwork.data, NULL, &info)
elif scalar is double_complex:
lapack.zgges(jobvsl, jobvsr, "N", NULL,
&N, <double complex *>A.data, &N,
<double complex *>B.data, &N, &sdim,
<double complex *>alphar.data, <double complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, <double *>rwork.data, NULL, &info)
if calc_q:
vsl = np.empty((N,N), dtype = np.complex128, order='F')
vsl_ptr = <double complex *>vsl.data
jobvsl = "V"
else:
vsl = None
vsl_ptr = NULL
jobvsl = "N"
assert info == 0, "Argument error in gges"
if calc_z:
vsr = np.empty((N,N), dtype = np.complex128, order='F')
vsr_ptr = <double complex *>vsr.data
jobvsr = "V"
if scalar in floating:
lwork = <l_int>qwork
else:
vsr = None
vsr_ptr = NULL
jobvsr = "N"
# workspace query
lwork = -1
lapack.zgges(jobvsl, jobvsr, "N", NULL,
&N, <double complex *>A.data, &N,
<double complex *>B.data, &N, &sdim,
<double complex *>alpha.data, <double complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
&qwork, &lwork, <double *>rwork.data, NULL, &info)
assert info == 0, "Argument error in zgees"
lwork = <l_int>qwork.real
cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype)
lwork = <int>qwork.real
work = np.empty(lwork, dtype = np.complex128)
# The actual calculation
# Now the real calculation
lapack.zgges(jobvsl, jobvsr, "N", NULL,
&N, <double complex *>A.data, &N,
<double complex *>B.data, &N, &sdim,
<double complex *>alpha.data, <double complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<double complex *>work.data, &lwork,
<double *>rwork.data, NULL, &info)
if scalar is float:
lapack.sgges(jobvsl, jobvsr, "N", NULL,
&N, <float *>A.data, &N,
<float *>B.data, &N, &sdim,
<float *>alphar.data, <float *>alphai.data,
<float *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<float *>work.data, &lwork, NULL, &info)
elif scalar is double:
lapack.dgges(jobvsl, jobvsr, "N", NULL,
&N, <double *>A.data, &N,
<double *>B.data, &N, &sdim,
<double *>alphar.data, <double *>alphai.data,
<double *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<double *>work.data, &lwork, NULL, &info)
elif scalar is float_complex:
lapack.cgges(jobvsl, jobvsr, "N", NULL,
&N, <float complex *>A.data, &N,
<float complex *>B.data, &N, &sdim,
<float complex *>alphar.data, <float complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<float complex *>work.data, &lwork,
<float *>rwork.data, NULL, &info)
elif scalar is double_complex:
lapack.zgges(jobvsl, jobvsr, "N", NULL,
&N, <double complex *>A.data, &N,
<double complex *>B.data, &N, &sdim,
<double complex *>alphar.data, <double complex *>beta.data,
vsl_ptr, &N, vsr_ptr, &N,
<double complex *>work.data, &lwork,
<double *>rwork.data, NULL, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in zgges")
raise LinAlgError("QZ iteration failed to converge in sgges")
assert info == 0, "Argument error in gges"
assert info == 0, "Argument error in zgees"
cdef np.ndarray alpha
alpha = alphar
if scalar in floating:
if alphai.nonzero()[0].size:
alpha = alphar + 1j * alphai
else:
alpha = alphar
return filter_args((True, True, calc_q, calc_z, calc_ev, calc_ev),
(A, B, vsl, vsr, alpha, beta))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment