Skip to content
Snippets Groups Projects
Commit 1e6c0860 authored by Joseph Weston's avatar Joseph Weston
Browse files

Merge branch 'cleanup/remove-lapack' into 'stable'

We can now use Scipy's Cython interface to Lapack, which removes the need for us to explicitly link against blas or lapack ourselves. Also cut down boilerplate by re-arrange the existing Lapack wrappers to use Cython fused types.

Closes #28

See merge request !149
parents 2c5da944 bd885029
No related branches found
No related tags found
No related merge requests found
......@@ -50,18 +50,5 @@ def gen_eig(a, b, left=False, right=True, overwrite_ab=False):
The right eigenvector corresponding to the eigenvalue
``alpha[i]/beta[i]`` is the column ``vr[:,i]``.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("gen_eig requires both a and be to be matrices")
if a.shape[0] != a.shape[1]:
raise ValueError("gen_eig requires square matrix input")
if b.shape[0] != a.shape[0] or b.shape[1] != a.shape[1]:
raise ValueError("gen_eig requires a and be to have the same shape")
ggev = getattr(lapack, ltype + "ggev")
return ggev(a, b, left, right)
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.ggev(a, b, left, right)
......@@ -45,20 +45,8 @@ def lu_factor(a, overwrite_a=False):
singular : boolean
Whether the matrix a is singular (up to machine precision)
"""
ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
if a.ndim != 2:
raise ValueError("lu_factor expects a matrix")
if ltype == 'd':
return lapack.dgetrf(a)
elif ltype == 'z':
return lapack.zgetrf(a)
elif ltype == 's':
return lapack.sgetrf(a)
else:
return lapack.cgetrf(a)
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.getrf(a)
def lu_solve(matrix_factorization, b):
......@@ -83,23 +71,9 @@ def lu_solve(matrix_factorization, b):
"a singular matrix. Result of solve step "
"are probably unreliable")
ltype, lu, b = lapack.prepare_for_lapack(False, lu, b)
lu, b = lapack.prepare_for_lapack(False, lu, b)
ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype)
if b.ndim > 2:
raise ValueError("lu_solve: b must be a vector or matrix")
if lu.shape[0] != b.shape[0]:
raise ValueError("lu_solve: incompatible dimensions of b")
if ltype == 'd':
return lapack.dgetrs(lu, ipiv, b)
elif ltype == 'z':
return lapack.zgetrs(lu, ipiv, b)
elif ltype == 's':
return lapack.sgetrs(lu, ipiv, b)
else:
return lapack.cgetrs(lu, ipiv, b)
return lapack.getrs(lu, ipiv, b)
def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
......@@ -127,17 +101,6 @@ def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
norm specified in norm
"""
(lu, ipiv, singular) = matrix_factorization
if not norm in ("1", "I"):
raise ValueError("norm in rcond_from_lu must be either '1' or 'I'")
norm = norm.encode('utf8') # lapack expects bytes
ltype, lu = lapack.prepare_for_lapack(False, lu)
if ltype == 'd':
return lapack.dgecon(lu, norm_a, norm)
elif ltype == 'z':
return lapack.zgecon(lu, norm_a, norm)
elif ltype == 's':
return lapack.sgecon(lu, norm_a, norm)
else:
return lapack.cgecon(lu, norm_a, norm)
lu = lapack.prepare_for_lapack(False, lu)
return lapack.gecon(lu, norm_a, norm)
......@@ -62,18 +62,8 @@ def schur(a, calc_q=True, calc_ev=True, overwrite_a=False):
LinAlgError
If the underlying QR iteration fails to converge.
"""
ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
if a.ndim != 2:
raise ValueError("Expect matrix as input")
if a.shape[0] != a.shape[1]:
raise ValueError("Expect square matrix")
gees = getattr(lapack, ltype + "gees")
return gees(a, calc_q, calc_ev)
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.gees(a, calc_q, calc_ev)
def convert_r2c_schur(t, q):
......@@ -192,9 +182,7 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False):
``calc_ev == True``
"""
ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
trsen = getattr(lapack, ltype + "trsen")
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
# Figure out if select is a function or array.
isfun = isarray = True
......@@ -223,7 +211,7 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False):
t, q = convert_r2c_schur(t, q)
return order_schur(select, t, q, calc_ev, True)
return trsen(select, t, q, calc_ev)
return lapack.trsen(select, t, q, calc_ev)
def evecs_from_schur(t, q, select=None, left=False, right=True,
......@@ -267,13 +255,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
``right == True``.
"""
ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
if (t.shape[0] != t.shape[1] or q.shape[0] != q.shape[1]
or t.shape[0] != q.shape[0]):
raise ValueError("Invalid Schur decomposition as input")
trevc = getattr(lapack, ltype + "trevc")
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
# check if select is a function or an array
if select is not None:
......@@ -300,7 +282,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
else:
selectarr = None
return trevc(t, q, selectarr, left, right)
return lapack.trevc(t, q, selectarr, left, right)
def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
......@@ -364,21 +346,8 @@ def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
LinAlError
If the underlying QZ iteration fails to converge.
"""
ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("Expect matrices as input")
if a.shape[0] != a.shape[1]:
raise ValueError("Expect square matrix a")
if a.shape[0] != b.shape[0] or a.shape[0] != b.shape[1]:
raise ValueError("Shape of b is incompatible to matrix a")
gges = getattr(lapack, ltype + "gges")
return gges(a, b, calc_q, calc_z, calc_ev)
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.gges(a, b, calc_q, calc_z, calc_ev)
def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
......@@ -439,22 +408,8 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
LinAlError
If the problem is too ill-conditioned.
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z)
s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z)
if (s.ndim != 2 or t.ndim != 2 or
(q is not None and q.ndim != 2) or
(z is not None and z.ndim != 2)):
raise ValueError("Expect matrices as input")
if ((s.shape[0] != s.shape[1] or t.shape[0] != t.shape[1] or
s.shape[0] != t.shape[0]) or
(q is not None and (q.shape[0] != q.shape[1] or
s.shape[0] != q.shape[0])) or
(z is not None and (z.shape[0] != z.shape[1] or
s.shape[0] != z.shape[0]))):
raise ValueError("Invalid Schur decomposition as input")
tgsen = getattr(lapack, ltype + "tgsen")
# Figure out if select is a function or array.
isfun = isarray = True
......@@ -492,7 +447,7 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
return order_gen_schur(select, s, t, q, z, calc_ev, True)
return tgsen(select, s, t, q, z, calc_ev)
return lapack.tgsen(select, s, t, q, z, calc_ev)
def convert_r2c_gen_schur(s, t, q=None, z=None):
......@@ -536,7 +491,7 @@ def convert_r2c_gen_schur(s, t, q=None, z=None):
If it fails to convert a 2x2 block into complex form (unlikely).
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z)
s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z)
# Note: overwrite=True does not mean much here, the arrays are all copied
if (s.ndim != 2 or t.ndim != 2 or
......@@ -656,20 +611,7 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
"""
ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
if (s.ndim != 2 or t.ndim != 2 or
(q is not None and q.ndim != 2) or
(z is not None and z.ndim != 2)):
raise ValueError("Expect matrices as input")
if ((s.shape[0] != s.shape[1] or t.shape[0] != t.shape[1] or
s.shape[0] != t.shape[0]) or
(q is not None and (q.shape[0] != q.shape[1] or
s.shape[0] != q.shape[0])) or
(z is not None and (z.shape[0] != z.shape[1] or
s.shape[0] != z.shape[0]))):
raise ValueError("Invalid Schur decomposition as input")
s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
if left and q is None:
raise ValueError("Matrix q must be provided for left eigenvectors")
......@@ -677,8 +619,6 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
if right and z is None:
raise ValueError("Matrix z must be provided for right eigenvectors")
tgevc = getattr(lapack, ltype + "tgevc")
# Check if select is a function or an array.
if select is not None:
isfun = isarray = True
......@@ -704,4 +644,4 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
else:
selectarr = None
return tgevc(s, t, q, z, selectarr, left, right)
return lapack.tgevc(s, t, q, z, selectarr, left, right)
# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license. A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
ctypedef int l_int
ctypedef int l_logical
cdef extern:
void sgetrf_(l_int *, l_int *, float *, l_int *, l_int *, l_int *)
void dgetrf_(l_int *, l_int *, double *, l_int *, l_int *, l_int *)
void cgetrf_(l_int *, l_int *, float complex *, l_int *, l_int *,
l_int *)
void zgetrf_(l_int *, l_int *, double complex *, l_int *, l_int *,
l_int *)
void sgetrs_(char *, l_int *, l_int *, float *, l_int *, l_int *,
float *, l_int *, l_int *)
void dgetrs_(char *, l_int *, l_int *, double *, l_int *, l_int *,
double *, l_int *, l_int *)
void cgetrs_(char *, l_int *, l_int *, float complex *, l_int *,
l_int *, float complex *, l_int *, l_int *)
void zgetrs_(char *, l_int *, l_int *, double complex *, l_int *,
l_int *, double complex *, l_int *, l_int *)
void sgecon_(char *, l_int *, float *, l_int *, float *, float *,
float *, l_int *, l_int *)
void dgecon_(char *, l_int *, double *, l_int *, double *, double *,
double *, l_int *, l_int *)
void cgecon_(char *, l_int *, float complex *, l_int *, float *,
float *, float complex *, float *, l_int *)
void zgecon_(char *, l_int *, double complex *, l_int *, double *,
double *, double complex *, double *, l_int *)
void sggev_(char *, char *, l_int *, float *, l_int *, float *, l_int *,
float *, float *, float *, float *, l_int *, float *, l_int *,
float *, l_int *, l_int *)
void dggev_(char *, char *, l_int *, double *, l_int *, double *, l_int *,
double *, double *, double *, double *, l_int *,
double *, l_int *, double *, l_int *, l_int *)
void cggev_(char *, char *, l_int *, float complex *, l_int *,
float complex *, l_int *, float complex *, float complex *,
float complex *, l_int *, float complex *, l_int *,
float complex *, l_int *, float *, l_int *)
void zggev_(char *, char *, l_int *, double complex *, l_int *,
double complex *, l_int *, double complex *,
double complex *, double complex *, l_int *,
double complex *, l_int *, double complex *, l_int *,
double *, l_int *)
void sgees_(char *, char *, l_logical (*)(float *, float *),
l_int *, float *, l_int *, l_int *,
float *, float *, float *, l_int *,
float *, l_int *, l_logical *, l_int *)
void dgees_(char *, char *, l_logical (*)(double *, double *),
l_int *, double *, l_int *, l_int *,
double *, double *, double *, l_int *,
double *, l_int *, l_logical *, l_int *)
void cgees_(char *, char *,
l_logical (*)(float complex *),
l_int *, float complex *,
l_int *, l_int *, float complex *,
float complex *, l_int *,
float complex *, l_int *, float *,
l_logical *, l_int *)
void zgees_(char *, char *,
l_logical (*)(double complex *),
l_int *, double complex *,
l_int *, l_int *, double complex *,
double complex *, l_int *,
double complex *, l_int *,
double *, l_logical *, l_int *)
void strsen_(char *, char *, l_logical *, l_int *,
float *, l_int *, float *,
l_int *, float *, float *, l_int *,
float *, float *, float *, l_int *,
l_int *, l_int *, l_int *)
void dtrsen_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *, double *,
l_int *, double *, double *, double *,
l_int *, l_int *, l_int *, l_int *)
void ctrsen_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *, l_int *,
float *, float *, float complex *,
l_int *, l_int *)
void ztrsen_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *, l_int *,
double *, double *, double complex *,
l_int *, l_int *)
void strevc_(char *, char *, l_logical *,
l_int *, float *, l_int *,
float *, l_int *, float *, l_int *,
l_int *, l_int *, float *, l_int *)
void dtrevc_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *,
l_int *, l_int *, l_int *, double *,
l_int *)
void ctrevc_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, l_int *,
float complex *, float *, l_int *)
void ztrevc_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, l_int *,
double complex *, double *, l_int *)
void sgges_(char *, char *, char *,
l_logical (*)(float *, float *, float *),
l_int *, float *, l_int *, float *,
l_int *, l_int *, float *, float *,
float *, float *, l_int *, float *,
l_int *, float *, l_int *, l_logical *,
l_int *)
void dgges_(char *, char *, char *,
l_logical (*)(double *, double *, double *),
l_int *, double *, l_int *, double *,
l_int *, l_int *, double *, double *,
double *, double *, l_int *, double *,
l_int *, double *, l_int *,
l_logical *, l_int *)
void cgges_(char *, char *, char *,
l_logical (*)(float complex *, float complex *),
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, float complex *,
float complex *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float *, l_logical *, l_int *)
void zgges_(char *, char *, char *,
l_logical (*)(double complex *, double complex *),
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, double complex *,
double complex *,
double complex *, l_int *,
double complex *, l_int *,
double complex *, l_int *,
double *, l_logical *, l_int *)
void stgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, float *, l_int *, float *,
l_int *, float *, float *, float *,
float *, l_int *, float *, l_int *,
l_int *, float *, float *, float *, float *,
l_int *, l_int *, l_int *, l_int *)
void dtgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *, double *,
double *, double *, l_int *, double *,
l_int *, l_int *, double *, double *,
double *, double *, l_int *, l_int *,
l_int *, l_int *)
void ctgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
float complex *,
float complex *, l_int *,
float complex *, l_int *, l_int *,
float *, float *, float *,
float complex *, l_int *, l_int *,
l_int *, l_int *)
void ztgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
double complex *,
double complex *, l_int *,
double complex *, l_int *, l_int *,
double *, double *, double *,
double complex *, l_int *, l_int *,
l_int *, l_int *)
void stgevc_(char *, char *, l_logical *,
l_int *, float *, l_int *,
float *, l_int *, float *,
l_int *, float *, l_int *,
l_int *, l_int *, float *, l_int *)
void dtgevc_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *,
l_int *, double *, l_int *,
l_int *, l_int *, double *, l_int *)
void ctgevc_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, l_int *,
float complex *, float *, l_int *)
void ztgevc_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, l_int *,
double complex *, double *, l_int *)
# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license. A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
__all__ = ['l_int_dtype', 'l_logical_dtype']
import numpy as np
l_int_dtype = np.int32
l_logical_dtype = np.int32
This diff is collapsed.
......@@ -372,8 +372,9 @@ def search_mumps():
# Conda (via conda-forge).
# TODO: remove dependency libs (scotch, metis...) when conda-forge
# packaged mumps/scotch are built as properly linked shared libs
# 'openblas' provides Lapack and BLAS symbols
['zmumps', 'mumps_common', 'metis', 'esmumps', 'scotch',
'scotcherr', 'mpiseq'],
'scotcherr', 'mpiseq', 'openblas'],
]
common_libs = ['pord', 'gfortran']
......@@ -384,34 +385,7 @@ def search_mumps():
return []
def search_lapack():
"""Return the BLAS variant that is installed."""
lib_sets = [
# Debian
['blas', 'lapack'],
# Conda (via conda-forge). Openblas contains lapack symbols
['openblas', 'gfortran'],
]
for libs in lib_sets:
found_libs = search_libs(libs)
if found_libs:
return found_libs
print('Error: BLAS/LAPACK are required but were not found.',
file=sys.stderr)
sys.exit(1)
def configure_special_extensions(exts, build_summary):
#### Special config for LAPACK.
lapack = exts['kwant.linalg.lapack']
if 'libraries' in lapack:
build_summary.append('User-configured LAPACK and BLAS')
else:
lapack['libraries'] = search_lapack()
build_summary.append('Default LAPACK and BLAS')
#### Special config for MUMPS.
mumps = exts['kwant.linalg._mumps']
if 'libraries' in mumps:
......@@ -426,12 +400,6 @@ def configure_special_extensions(exts, build_summary):
del exts['kwant.linalg._mumps']
build_summary.append('No MUMPS support')
if mumps:
# Copy config from LAPACK.
for key, value in lapack.items():
if key not in ['sources', 'depends']:
mumps.setdefault(key, []).extend(value)
return exts
......@@ -550,8 +518,7 @@ def main():
'kwant/graph/c_slicer/partitioner.h',
'kwant/graph/c_slicer/slicer.h'])),
('kwant.linalg.lapack',
dict(sources=['kwant/linalg/lapack.pyx'],
depends=['kwant/linalg/f_lapack.pxd'])),
dict(sources=['kwant/linalg/lapack.pyx'])),
('kwant.linalg._mumps',
dict(sources=['kwant/linalg/_mumps.pyx'],
depends=['kwant/linalg/cmumps.pxd']))])
......@@ -567,8 +534,7 @@ def main():
for ext in exts.values():
ext.setdefault('include_dirs', []).append(numpy_include)
aliases = [('lapack', 'kwant.linalg.lapack'),
('mumps', 'kwant.linalg._mumps')]
aliases = [('mumps', 'kwant.linalg._mumps')]
init_cython()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment