Skip to content
Snippets Groups Projects
Commit 8a19196e authored by Joseph Weston's avatar Joseph Weston
Browse files

switch Xggev to fused types

parent 658c4230
No related branches found
No related tags found
1 merge request!149Remove lapack wrappers
......@@ -50,18 +50,5 @@ def gen_eig(a, b, left=False, right=True, overwrite_ab=False):
The right eigenvector corresponding to the eigenvalue
``alpha[i]/beta[i]`` is the column ``vr[:,i]``.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("gen_eig requires both a and be to be matrices")
if a.shape[0] != a.shape[1]:
raise ValueError("gen_eig requires square matrix input")
if b.shape[0] != a.shape[0] or b.shape[1] != a.shape[1]:
raise ValueError("gen_eig requires a and be to have the same shape")
ggev = getattr(lapack, ltype + "ggev")
return ggev(a, b, left, right)
return lapack.ggev(a, b, left, right)
......@@ -11,7 +11,7 @@
__all__ = ['getrf',
'getrs',
'gecon',
'sggev', 'dggev', 'cggev', 'zggev',
'ggev',
'sgees', 'dgees', 'cgees', 'zgees',
'strsen', 'dtrsen', 'ctrsen', 'ztrsen',
'strevc', 'dtrevc', 'ctrevc', 'ztrevc',
......@@ -213,7 +213,6 @@ def gecon(np.ndarray[scalar, ndim=2] LU, double normA, char *norm = b"1"):
else:
return drcond
# Wrappers for xGGEV
# Helper function for xGGEV
def ggev_postprocess(dtype, alphar, alphai, vl_r=None, vr_r=None):
......@@ -248,281 +247,154 @@ def ggev_postprocess(dtype, alphar, alphai, vl_r=None, vr_r=None):
return (alpha, vl, vr)
def sggev(np.ndarray[np.float32_t, ndim=2] A,
np.ndarray[np.float32_t, ndim=2] B,
left=False, right=True):
def ggev(np.ndarray[scalar, ndim=2] A, np.ndarray[scalar, ndim=2] B,
left=False, right=True):
cdef l_int N, info, lwork
cdef char *jobvl
cdef char *jobvr
cdef np.ndarray[np.float32_t, ndim=2] vl_r, vr_r
cdef float *vl_ptr
cdef float *vr_ptr
cdef float qwork
cdef np.ndarray[np.float32_t, ndim=1] work, alphar, alphai, beta
assert_fortran_mat(A, B)
N = A.shape[0]
alphar = np.empty(N, dtype = np.float32)
alphai = np.empty(N, dtype = np.float32)
beta = np.empty(N, dtype = np.float32)
if left:
vl_r = np.empty((N,N), dtype = np.float32, order='F')
vl_ptr = <float *>vl_r.data
jobvl = "V"
else:
vl_r = None
vl_ptr = NULL
jobvl = "N"
if right:
vr_r = np.empty((N,N), dtype = np.float32, order='F')
vr_ptr = <float *>vr_r.data
jobvr = "V"
else:
vr_r = None
vr_ptr = NULL
jobvr = "N"
# workspace query
lwork = -1
lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
<float *>B.data, &N,
<float *>alphar.data, <float *> alphai.data,
<float *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork, &info)
assert info == 0, "Argument error in sggev"
lwork = <l_int>qwork
work = np.empty(lwork, dtype = np.float32)
# Now the real calculation
lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
<float *>B.data, &N,
<float *>alphar.data, <float *> alphai.data,
<float *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<float *>work.data, &lwork, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in sggev")
# Parameter checks
assert info == 0, "Argument error in sggev"
assert_fortran_mat(A, B)
alpha, vl, vr = ggev_postprocess(np.complex64, alphar, alphai, vl_r, vr_r)
if A.ndim != 2 or A.ndim != 2:
raise ValueError("gen_eig requires both a and be to be matrices")
return filter_args((True, True, left, right), (alpha, beta, vl, vr))
if A.shape[0] != A.shape[1]:
raise ValueError("gen_eig requires square matrix input")
if A.shape[0] != B.shape[0] or A.shape[1] != B.shape[1]:
raise ValueError("A and B do not have the same shape")
def dggev(np.ndarray[np.float64_t, ndim=2] A,
np.ndarray[np.float64_t, ndim=2] B,
left=False, right=True):
cdef l_int N, info, lwork
cdef char *jobvl
cdef char *jobvr
cdef np.ndarray[np.float64_t, ndim=2] vl_r, vr_r
cdef double *vl_ptr
cdef double *vr_ptr
cdef double qwork
cdef np.ndarray[np.float64_t, ndim=1] work, alphar, alphai, beta
assert_fortran_mat(A, B)
# Allocate workspaces
N = A.shape[0]
alphar = np.empty(N, dtype = np.float64)
alphai = np.empty(N, dtype = np.float64)
beta = np.empty(N, dtype = np.float64)
if left:
vl_r = np.empty((N,N), dtype = np.float64, order='F')
vl_ptr = <double *>vl_r.data
jobvl = "V"
else:
vl_r = None
vl_ptr = NULL
jobvl = "N"
if right:
vr_r = np.empty((N,N), dtype = np.float64, order='F')
vr_ptr = <double *>vr_r.data
jobvr = "V"
cdef np.ndarray[scalar] alphar, alphai
if scalar in cmplx:
alphar = np.empty(N, dtype=A.dtype)
alphai = None
else:
vr_r = None
vr_ptr = NULL
jobvr = "N"
alphar = np.empty(N, dtype=A.dtype)
alphai = np.empty(N, dtype=A.dtype)
# workspace query
lwork = -1
lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
<double *>B.data, &N,
<double *>alphar.data, <double *> alphai.data,
<double *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork, &info)
assert info == 0, "Argument error in dggev"
lwork = <l_int>qwork
work = np.empty(lwork, dtype = np.float64)
# Now the real calculation
lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
<double *>B.data, &N,
<double *>alphar.data, <double *> alphai.data,
<double *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<double *>work.data, &lwork, &info)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in dggev")
assert info == 0, "Argument error in dggev"
alpha, vl, vr = ggev_postprocess(np.complex128, alphar, alphai, vl_r, vr_r)
return filter_args((True, True, left, right), (alpha, beta, vl, vr))
cdef np.ndarray[scalar] beta = np.empty(N, dtype=A.dtype)
cdef np.ndarray rwork = None
if scalar is float_complex:
rwork = np.empty(8 * N, dtype=np.float32)
elif scalar is double_complex:
rwork = np.empty(8 * N, dtype=np.float64)
def cggev(np.ndarray[np.complex64_t, ndim=2] A,
np.ndarray[np.complex64_t, ndim=2] B,
left=False, right=True):
cdef l_int N, info, lwork
cdef np.ndarray vl
cdef scalar *vl_ptr
cdef char *jobvl
cdef char *jobvr
cdef np.ndarray[np.complex64_t, ndim=2] vl, vr
cdef float complex *vl_ptr
cdef float complex *vr_ptr
cdef float complex qwork
cdef np.ndarray[np.complex64_t, ndim=1] work, alpha, beta
cdef np.ndarray[np.float32_t, ndim=1] rwork
assert_fortran_mat(A, B)
N = A.shape[0]
alpha = np.empty(N, dtype = np.complex64)
beta = np.empty(N, dtype = np.complex64)
if left:
vl = np.empty((N,N), dtype = np.complex64, order='F')
vl_ptr = <float complex *>vl.data
vl = np.empty((N,N), dtype=A.dtype, order='F')
vl_ptr = <scalar *>vl.data
jobvl = "V"
else:
vl = None
vl_ptr = NULL
jobvl = "N"
cdef np.ndarray vr
cdef scalar *vr_ptr
cdef char *jobvr
if right:
vr = np.empty((N,N), dtype = np.complex64, order='F')
vr_ptr = <float complex *>vr.data
vr = np.empty((N,N), dtype=A.dtype, order='F')
vr_ptr = <scalar *>vr.data
jobvr = "V"
else:
vr = None
vr_ptr = NULL
jobvr = "N"
rwork = np.empty(8 * N, dtype = np.float32)
# workspace query
# Workspace query
# Xggev expects &qwork as a <scalar *> (even though it's an integer)
lwork = -1
work = np.empty(1, dtype = np.complex64)
lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
<float complex *>B.data, &N,
<float complex *>alpha.data, <float complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork,
<float *>rwork.data, &info)
assert info == 0, "Argument error in cggev"
lwork = <l_int>qwork.real
work = np.empty(lwork, dtype = np.complex64)
# Now the real calculation
lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
<float complex *>B.data, &N,
<float complex *>alpha.data, <float complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<float complex *>work.data, &lwork,
<float *>rwork.data, &info)
cdef scalar qwork
if info > 0:
raise LinAlgError("QZ iteration failed to converge in cggev")
assert info == 0, "Argument error in cggev"
return filter_args((True, True, left, right), (alpha, beta, vl, vr))
def zggev(np.ndarray[np.complex128_t, ndim=2] A,
np.ndarray[np.complex128_t, ndim=2] B,
left=False, right=True):
cdef l_int N, info, lwork
cdef char *jobvl
cdef char *jobvr
cdef np.ndarray[np.complex128_t, ndim=2] vl, vr
cdef double complex *vl_ptr
cdef double complex *vr_ptr
cdef double complex qwork
cdef np.ndarray[np.complex128_t, ndim=1] work, alpha, beta
cdef np.ndarray[np.float64_t, ndim=1] rwork
assert_fortran_mat(A, B)
N = A.shape[0]
alpha = np.empty(N, dtype = np.complex128)
beta = np.empty(N, dtype = np.complex128)
if scalar is float:
lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
<float *>B.data, &N,
<float *>alphar.data, <float *> alphai.data,
<float *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork, &info)
elif scalar is double:
lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
<double *>B.data, &N,
<double *>alphar.data, <double *> alphai.data,
<double *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork, &info)
elif scalar is float_complex:
lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
<float complex *>B.data, &N,
<float complex *>alphar.data, <float complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork,
<float *>rwork.data, &info)
elif scalar is double_complex:
lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
<double complex *>B.data, &N,
<double complex *>alphar.data, <double complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork,
<double *>rwork.data, &info)
if left:
vl = np.empty((N,N), dtype = np.complex128, order='F')
vl_ptr = <double complex *>vl.data
jobvl = "V"
else:
vl_ptr = NULL
jobvl = "N"
assert info == 0, "Argument error in ggev"
if right:
vr = np.empty((N,N), dtype = np.complex128, order='F')
vr_ptr = <double complex *>vr.data
jobvr = "V"
if scalar in floating:
lwork = <l_int>qwork
else:
vr_ptr = NULL
jobvr = "N"
rwork = np.empty(8 * N, dtype = np.float64)
lwork = <l_int>qwork.real
cdef np.ndarray[scalar] work = np.empty(lwork, dtype=A.dtype)
# workspace query
lwork = -1
work = np.empty(1, dtype = np.complex128)
lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
<double complex *>B.data, &N,
<double complex *>alpha.data, <double complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
&qwork, &lwork,
<double *>rwork.data, &info)
# The actual calculation
assert info == 0, "Argument error in zggev"
if scalar is float:
lapack.sggev(jobvl, jobvr, &N, <float *>A.data, &N,
<float *>B.data, &N,
<float *>alphar.data, <float *> alphai.data,
<float *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<float *>work.data, &lwork, &info)
elif scalar is double:
lapack.dggev(jobvl, jobvr, &N, <double *>A.data, &N,
<double *>B.data, &N,
<double *>alphar.data, <double *> alphai.data,
<double *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<double *>work.data, &lwork, &info)
elif scalar is float_complex:
lapack.cggev(jobvl, jobvr, &N, <float complex *>A.data, &N,
<float complex *>B.data, &N,
<float complex *>alphar.data, <float complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<float complex *>work.data, &lwork,
<float *>rwork.data, &info)
elif scalar is double_complex:
lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
<double complex *>B.data, &N,
<double complex *>alphar.data, <double complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<double complex *>work.data, &lwork,
<double *>rwork.data, &info)
lwork = <l_int>qwork.real
work = np.empty(lwork, dtype = np.complex128)
if info > 0:
raise LinAlgError("QZ iteration failed to converge in sggev")
# Now the real calculation
lapack.zggev(jobvl, jobvr, &N, <double complex *>A.data, &N,
<double complex *>B.data, &N,
<double complex *>alpha.data, <double complex *>beta.data,
vl_ptr, &N, vr_ptr, &N,
<double complex *>work.data, &lwork,
<double *>rwork.data, &info)
assert info == 0, "Argument error in ggev"
if info > 0:
raise LinAlgError("QZ iteration failed to converge in zggev")
if scalar is float:
post_dtype = np.complex64
elif scalar is double:
post_dtype = np.complex128
assert info == 0, "Argument error in zggev"
cdef np.ndarray alpha
alpha = alphar
if scalar in floating:
alpha, vl, vr = ggev_postprocess(post_dtype, alphar, alphai, vl, vr)
return filter_args((True, True, left, right), (alpha, beta, vl, vr))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment