Skip to content
Snippets Groups Projects
Commit 3ef0b9fc authored by Joseph Weston's avatar Joseph Weston
Browse files

Revert "Merge branch 'cleanup/remove-lapack' into 'stable'"

This reverts commit 1e6c0860, reversing
changes made to 2c5da944.

The changes to Kwant's Lapack wrappers depend on features from
Scipy 0.16, so it cannot be merged into stable branch, which
depends on Scipy 0.14.
parent 75d4e2cd
No related branches found
No related tags found
No related merge requests found
Pipeline #
...@@ -50,5 +50,18 @@ def gen_eig(a, b, left=False, right=True, overwrite_ab=False): ...@@ -50,5 +50,18 @@ def gen_eig(a, b, left=False, right=True, overwrite_ab=False):
The right eigenvector corresponding to the eigenvalue The right eigenvector corresponding to the eigenvalue
``alpha[i]/beta[i]`` is the column ``vr[:,i]``. ``alpha[i]/beta[i]`` is the column ``vr[:,i]``.
""" """
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.ggev(a, b, left, right) ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("gen_eig requires both a and be to be matrices")
if a.shape[0] != a.shape[1]:
raise ValueError("gen_eig requires square matrix input")
if b.shape[0] != a.shape[0] or b.shape[1] != a.shape[1]:
raise ValueError("gen_eig requires a and be to have the same shape")
ggev = getattr(lapack, ltype + "ggev")
return ggev(a, b, left, right)
...@@ -45,8 +45,20 @@ def lu_factor(a, overwrite_a=False): ...@@ -45,8 +45,20 @@ def lu_factor(a, overwrite_a=False):
singular : boolean singular : boolean
Whether the matrix a is singular (up to machine precision) Whether the matrix a is singular (up to machine precision)
""" """
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.getrf(a) ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
if a.ndim != 2:
raise ValueError("lu_factor expects a matrix")
if ltype == 'd':
return lapack.dgetrf(a)
elif ltype == 'z':
return lapack.zgetrf(a)
elif ltype == 's':
return lapack.sgetrf(a)
else:
return lapack.cgetrf(a)
def lu_solve(matrix_factorization, b): def lu_solve(matrix_factorization, b):
...@@ -71,9 +83,23 @@ def lu_solve(matrix_factorization, b): ...@@ -71,9 +83,23 @@ def lu_solve(matrix_factorization, b):
"a singular matrix. Result of solve step " "a singular matrix. Result of solve step "
"are probably unreliable") "are probably unreliable")
lu, b = lapack.prepare_for_lapack(False, lu, b) ltype, lu, b = lapack.prepare_for_lapack(False, lu, b)
ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype) ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype)
return lapack.getrs(lu, ipiv, b)
if b.ndim > 2:
raise ValueError("lu_solve: b must be a vector or matrix")
if lu.shape[0] != b.shape[0]:
raise ValueError("lu_solve: incompatible dimensions of b")
if ltype == 'd':
return lapack.dgetrs(lu, ipiv, b)
elif ltype == 'z':
return lapack.zgetrs(lu, ipiv, b)
elif ltype == 's':
return lapack.sgetrs(lu, ipiv, b)
else:
return lapack.cgetrs(lu, ipiv, b)
def rcond_from_lu(matrix_factorization, norm_a, norm="1"): def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
...@@ -101,6 +127,17 @@ def rcond_from_lu(matrix_factorization, norm_a, norm="1"): ...@@ -101,6 +127,17 @@ def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
norm specified in norm norm specified in norm
""" """
(lu, ipiv, singular) = matrix_factorization (lu, ipiv, singular) = matrix_factorization
if not norm in ("1", "I"):
raise ValueError("norm in rcond_from_lu must be either '1' or 'I'")
norm = norm.encode('utf8') # lapack expects bytes norm = norm.encode('utf8') # lapack expects bytes
lu = lapack.prepare_for_lapack(False, lu)
return lapack.gecon(lu, norm_a, norm) ltype, lu = lapack.prepare_for_lapack(False, lu)
if ltype == 'd':
return lapack.dgecon(lu, norm_a, norm)
elif ltype == 'z':
return lapack.zgecon(lu, norm_a, norm)
elif ltype == 's':
return lapack.sgecon(lu, norm_a, norm)
else:
return lapack.cgecon(lu, norm_a, norm)
...@@ -62,8 +62,18 @@ def schur(a, calc_q=True, calc_ev=True, overwrite_a=False): ...@@ -62,8 +62,18 @@ def schur(a, calc_q=True, calc_ev=True, overwrite_a=False):
LinAlgError LinAlgError
If the underlying QR iteration fails to converge. If the underlying QR iteration fails to converge.
""" """
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.gees(a, calc_q, calc_ev) ltype, a = lapack.prepare_for_lapack(overwrite_a, a)
if a.ndim != 2:
raise ValueError("Expect matrix as input")
if a.shape[0] != a.shape[1]:
raise ValueError("Expect square matrix")
gees = getattr(lapack, ltype + "gees")
return gees(a, calc_q, calc_ev)
def convert_r2c_schur(t, q): def convert_r2c_schur(t, q):
...@@ -182,7 +192,9 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False): ...@@ -182,7 +192,9 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False):
``calc_ev == True`` ``calc_ev == True``
""" """
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q) ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
trsen = getattr(lapack, ltype + "trsen")
# Figure out if select is a function or array. # Figure out if select is a function or array.
isfun = isarray = True isfun = isarray = True
...@@ -211,7 +223,7 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False): ...@@ -211,7 +223,7 @@ def order_schur(select, t, q, calc_ev=True, overwrite_tq=False):
t, q = convert_r2c_schur(t, q) t, q = convert_r2c_schur(t, q)
return order_schur(select, t, q, calc_ev, True) return order_schur(select, t, q, calc_ev, True)
return lapack.trsen(select, t, q, calc_ev) return trsen(select, t, q, calc_ev)
def evecs_from_schur(t, q, select=None, left=False, right=True, def evecs_from_schur(t, q, select=None, left=False, right=True,
...@@ -255,7 +267,13 @@ def evecs_from_schur(t, q, select=None, left=False, right=True, ...@@ -255,7 +267,13 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
``right == True``. ``right == True``.
""" """
t, q = lapack.prepare_for_lapack(overwrite_tq, t, q) ltype, t, q = lapack.prepare_for_lapack(overwrite_tq, t, q)
if (t.shape[0] != t.shape[1] or q.shape[0] != q.shape[1]
or t.shape[0] != q.shape[0]):
raise ValueError("Invalid Schur decomposition as input")
trevc = getattr(lapack, ltype + "trevc")
# check if select is a function or an array # check if select is a function or an array
if select is not None: if select is not None:
...@@ -282,7 +300,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True, ...@@ -282,7 +300,7 @@ def evecs_from_schur(t, q, select=None, left=False, right=True,
else: else:
selectarr = None selectarr = None
return lapack.trevc(t, q, selectarr, left, right) return trevc(t, q, selectarr, left, right)
def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True, def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
...@@ -346,8 +364,21 @@ def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True, ...@@ -346,8 +364,21 @@ def gen_schur(a, b, calc_q=True, calc_z=True, calc_ev=True,
LinAlError LinAlError
If the underlying QZ iteration fails to converge. If the underlying QZ iteration fails to converge.
""" """
a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
return lapack.gges(a, b, calc_q, calc_z, calc_ev) ltype, a, b = lapack.prepare_for_lapack(overwrite_ab, a, b)
if a.ndim != 2 or b.ndim != 2:
raise ValueError("Expect matrices as input")
if a.shape[0] != a.shape[1]:
raise ValueError("Expect square matrix a")
if a.shape[0] != b.shape[0] or a.shape[0] != b.shape[1]:
raise ValueError("Shape of b is incompatible to matrix a")
gges = getattr(lapack, ltype + "gges")
return gges(a, b, calc_q, calc_z, calc_ev)
def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True, def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
...@@ -408,8 +439,22 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True, ...@@ -408,8 +439,22 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
LinAlError LinAlError
If the problem is too ill-conditioned. If the problem is too ill-conditioned.
""" """
s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z) ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_stqz, s, t, q, z)
if (s.ndim != 2 or t.ndim != 2 or
(q is not None and q.ndim != 2) or
(z is not None and z.ndim != 2)):
raise ValueError("Expect matrices as input")
if ((s.shape[0] != s.shape[1] or t.shape[0] != t.shape[1] or
s.shape[0] != t.shape[0]) or
(q is not None and (q.shape[0] != q.shape[1] or
s.shape[0] != q.shape[0])) or
(z is not None and (z.shape[0] != z.shape[1] or
s.shape[0] != z.shape[0]))):
raise ValueError("Invalid Schur decomposition as input")
tgsen = getattr(lapack, ltype + "tgsen")
# Figure out if select is a function or array. # Figure out if select is a function or array.
isfun = isarray = True isfun = isarray = True
...@@ -447,7 +492,7 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True, ...@@ -447,7 +492,7 @@ def order_gen_schur(select, s, t, q=None, z=None, calc_ev=True,
return order_gen_schur(select, s, t, q, z, calc_ev, True) return order_gen_schur(select, s, t, q, z, calc_ev, True)
return lapack.tgsen(select, s, t, q, z, calc_ev) return tgsen(select, s, t, q, z, calc_ev)
def convert_r2c_gen_schur(s, t, q=None, z=None): def convert_r2c_gen_schur(s, t, q=None, z=None):
...@@ -491,7 +536,7 @@ def convert_r2c_gen_schur(s, t, q=None, z=None): ...@@ -491,7 +536,7 @@ def convert_r2c_gen_schur(s, t, q=None, z=None):
If it fails to convert a 2x2 block into complex form (unlikely). If it fails to convert a 2x2 block into complex form (unlikely).
""" """
s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z) ltype, s, t, q, z = lapack.prepare_for_lapack(True, s, t, q, z)
# Note: overwrite=True does not mean much here, the arrays are all copied # Note: overwrite=True does not mean much here, the arrays are all copied
if (s.ndim != 2 or t.ndim != 2 or if (s.ndim != 2 or t.ndim != 2 or
...@@ -611,7 +656,20 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None, ...@@ -611,7 +656,20 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
""" """
s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z) ltype, s, t, q, z = lapack.prepare_for_lapack(overwrite_qz, s, t, q, z)
if (s.ndim != 2 or t.ndim != 2 or
(q is not None and q.ndim != 2) or
(z is not None and z.ndim != 2)):
raise ValueError("Expect matrices as input")
if ((s.shape[0] != s.shape[1] or t.shape[0] != t.shape[1] or
s.shape[0] != t.shape[0]) or
(q is not None and (q.shape[0] != q.shape[1] or
s.shape[0] != q.shape[0])) or
(z is not None and (z.shape[0] != z.shape[1] or
s.shape[0] != z.shape[0]))):
raise ValueError("Invalid Schur decomposition as input")
if left and q is None: if left and q is None:
raise ValueError("Matrix q must be provided for left eigenvectors") raise ValueError("Matrix q must be provided for left eigenvectors")
...@@ -619,6 +677,8 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None, ...@@ -619,6 +677,8 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
if right and z is None: if right and z is None:
raise ValueError("Matrix z must be provided for right eigenvectors") raise ValueError("Matrix z must be provided for right eigenvectors")
tgevc = getattr(lapack, ltype + "tgevc")
# Check if select is a function or an array. # Check if select is a function or an array.
if select is not None: if select is not None:
isfun = isarray = True isfun = isarray = True
...@@ -644,4 +704,4 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None, ...@@ -644,4 +704,4 @@ def evecs_from_gen_schur(s, t, q=None, z=None, select=None,
else: else:
selectarr = None selectarr = None
return lapack.tgevc(s, t, q, z, selectarr, left, right) return tgevc(s, t, q, z, selectarr, left, right)
# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license. A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
ctypedef int l_int
ctypedef int l_logical
cdef extern:
void sgetrf_(l_int *, l_int *, float *, l_int *, l_int *, l_int *)
void dgetrf_(l_int *, l_int *, double *, l_int *, l_int *, l_int *)
void cgetrf_(l_int *, l_int *, float complex *, l_int *, l_int *,
l_int *)
void zgetrf_(l_int *, l_int *, double complex *, l_int *, l_int *,
l_int *)
void sgetrs_(char *, l_int *, l_int *, float *, l_int *, l_int *,
float *, l_int *, l_int *)
void dgetrs_(char *, l_int *, l_int *, double *, l_int *, l_int *,
double *, l_int *, l_int *)
void cgetrs_(char *, l_int *, l_int *, float complex *, l_int *,
l_int *, float complex *, l_int *, l_int *)
void zgetrs_(char *, l_int *, l_int *, double complex *, l_int *,
l_int *, double complex *, l_int *, l_int *)
void sgecon_(char *, l_int *, float *, l_int *, float *, float *,
float *, l_int *, l_int *)
void dgecon_(char *, l_int *, double *, l_int *, double *, double *,
double *, l_int *, l_int *)
void cgecon_(char *, l_int *, float complex *, l_int *, float *,
float *, float complex *, float *, l_int *)
void zgecon_(char *, l_int *, double complex *, l_int *, double *,
double *, double complex *, double *, l_int *)
void sggev_(char *, char *, l_int *, float *, l_int *, float *, l_int *,
float *, float *, float *, float *, l_int *, float *, l_int *,
float *, l_int *, l_int *)
void dggev_(char *, char *, l_int *, double *, l_int *, double *, l_int *,
double *, double *, double *, double *, l_int *,
double *, l_int *, double *, l_int *, l_int *)
void cggev_(char *, char *, l_int *, float complex *, l_int *,
float complex *, l_int *, float complex *, float complex *,
float complex *, l_int *, float complex *, l_int *,
float complex *, l_int *, float *, l_int *)
void zggev_(char *, char *, l_int *, double complex *, l_int *,
double complex *, l_int *, double complex *,
double complex *, double complex *, l_int *,
double complex *, l_int *, double complex *, l_int *,
double *, l_int *)
void sgees_(char *, char *, l_logical (*)(float *, float *),
l_int *, float *, l_int *, l_int *,
float *, float *, float *, l_int *,
float *, l_int *, l_logical *, l_int *)
void dgees_(char *, char *, l_logical (*)(double *, double *),
l_int *, double *, l_int *, l_int *,
double *, double *, double *, l_int *,
double *, l_int *, l_logical *, l_int *)
void cgees_(char *, char *,
l_logical (*)(float complex *),
l_int *, float complex *,
l_int *, l_int *, float complex *,
float complex *, l_int *,
float complex *, l_int *, float *,
l_logical *, l_int *)
void zgees_(char *, char *,
l_logical (*)(double complex *),
l_int *, double complex *,
l_int *, l_int *, double complex *,
double complex *, l_int *,
double complex *, l_int *,
double *, l_logical *, l_int *)
void strsen_(char *, char *, l_logical *, l_int *,
float *, l_int *, float *,
l_int *, float *, float *, l_int *,
float *, float *, float *, l_int *,
l_int *, l_int *, l_int *)
void dtrsen_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *, double *,
l_int *, double *, double *, double *,
l_int *, l_int *, l_int *, l_int *)
void ctrsen_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *, l_int *,
float *, float *, float complex *,
l_int *, l_int *)
void ztrsen_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *, l_int *,
double *, double *, double complex *,
l_int *, l_int *)
void strevc_(char *, char *, l_logical *,
l_int *, float *, l_int *,
float *, l_int *, float *, l_int *,
l_int *, l_int *, float *, l_int *)
void dtrevc_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *,
l_int *, l_int *, l_int *, double *,
l_int *)
void ctrevc_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, l_int *,
float complex *, float *, l_int *)
void ztrevc_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, l_int *,
double complex *, double *, l_int *)
void sgges_(char *, char *, char *,
l_logical (*)(float *, float *, float *),
l_int *, float *, l_int *, float *,
l_int *, l_int *, float *, float *,
float *, float *, l_int *, float *,
l_int *, float *, l_int *, l_logical *,
l_int *)
void dgges_(char *, char *, char *,
l_logical (*)(double *, double *, double *),
l_int *, double *, l_int *, double *,
l_int *, l_int *, double *, double *,
double *, double *, l_int *, double *,
l_int *, double *, l_int *,
l_logical *, l_int *)
void cgges_(char *, char *, char *,
l_logical (*)(float complex *, float complex *),
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, float complex *,
float complex *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float *, l_logical *, l_int *)
void zgges_(char *, char *, char *,
l_logical (*)(double complex *, double complex *),
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, double complex *,
double complex *,
double complex *, l_int *,
double complex *, l_int *,
double complex *, l_int *,
double *, l_logical *, l_int *)
void stgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, float *, l_int *, float *,
l_int *, float *, float *, float *,
float *, l_int *, float *, l_int *,
l_int *, float *, float *, float *, float *,
l_int *, l_int *, l_int *, l_int *)
void dtgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *, double *,
double *, double *, l_int *, double *,
l_int *, l_int *, double *, double *,
double *, double *, l_int *, l_int *,
l_int *, l_int *)
void ctgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
float complex *,
float complex *, l_int *,
float complex *, l_int *, l_int *,
float *, float *, float *,
float complex *, l_int *, l_int *,
l_int *, l_int *)
void ztgsen_(l_int *, l_logical *,
l_logical *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
double complex *,
double complex *, l_int *,
double complex *, l_int *, l_int *,
double *, double *, double *,
double complex *, l_int *, l_int *,
l_int *, l_int *)
void stgevc_(char *, char *, l_logical *,
l_int *, float *, l_int *,
float *, l_int *, float *,
l_int *, float *, l_int *,
l_int *, l_int *, float *, l_int *)
void dtgevc_(char *, char *, l_logical *,
l_int *, double *, l_int *,
double *, l_int *, double *,
l_int *, double *, l_int *,
l_int *, l_int *, double *, l_int *)
void ctgevc_(char *, char *, l_logical *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, float complex *,
l_int *, l_int *, l_int *,
float complex *, float *, l_int *)
void ztgevc_(char *, char *, l_logical *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, double complex *,
l_int *, l_int *, l_int *,
double complex *, double *, l_int *)
# Copyright 2011-2013 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# http://kwant-project.org/license. A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
__all__ = ['l_int_dtype', 'l_logical_dtype']
import numpy as np
l_int_dtype = np.int32
l_logical_dtype = np.int32
This diff is collapsed.
...@@ -372,9 +372,8 @@ def search_mumps(): ...@@ -372,9 +372,8 @@ def search_mumps():
# Conda (via conda-forge). # Conda (via conda-forge).
# TODO: remove dependency libs (scotch, metis...) when conda-forge # TODO: remove dependency libs (scotch, metis...) when conda-forge
# packaged mumps/scotch are built as properly linked shared libs # packaged mumps/scotch are built as properly linked shared libs
# 'openblas' provides Lapack and BLAS symbols
['zmumps', 'mumps_common', 'metis', 'esmumps', 'scotch', ['zmumps', 'mumps_common', 'metis', 'esmumps', 'scotch',
'scotcherr', 'mpiseq', 'openblas'], 'scotcherr', 'mpiseq'],
] ]
common_libs = ['pord', 'gfortran'] common_libs = ['pord', 'gfortran']
...@@ -385,7 +384,34 @@ def search_mumps(): ...@@ -385,7 +384,34 @@ def search_mumps():
return [] return []
def search_lapack():
"""Return the BLAS variant that is installed."""
lib_sets = [
# Debian
['blas', 'lapack'],
# Conda (via conda-forge). Openblas contains lapack symbols
['openblas', 'gfortran'],
]
for libs in lib_sets:
found_libs = search_libs(libs)
if found_libs:
return found_libs
print('Error: BLAS/LAPACK are required but were not found.',
file=sys.stderr)
sys.exit(1)
def configure_special_extensions(exts, build_summary): def configure_special_extensions(exts, build_summary):
#### Special config for LAPACK.
lapack = exts['kwant.linalg.lapack']
if 'libraries' in lapack:
build_summary.append('User-configured LAPACK and BLAS')
else:
lapack['libraries'] = search_lapack()
build_summary.append('Default LAPACK and BLAS')
#### Special config for MUMPS. #### Special config for MUMPS.
mumps = exts['kwant.linalg._mumps'] mumps = exts['kwant.linalg._mumps']
if 'libraries' in mumps: if 'libraries' in mumps:
...@@ -400,6 +426,12 @@ def configure_special_extensions(exts, build_summary): ...@@ -400,6 +426,12 @@ def configure_special_extensions(exts, build_summary):
del exts['kwant.linalg._mumps'] del exts['kwant.linalg._mumps']
build_summary.append('No MUMPS support') build_summary.append('No MUMPS support')
if mumps:
# Copy config from LAPACK.
for key, value in lapack.items():
if key not in ['sources', 'depends']:
mumps.setdefault(key, []).extend(value)
return exts return exts
...@@ -518,7 +550,8 @@ def main(): ...@@ -518,7 +550,8 @@ def main():
'kwant/graph/c_slicer/partitioner.h', 'kwant/graph/c_slicer/partitioner.h',
'kwant/graph/c_slicer/slicer.h'])), 'kwant/graph/c_slicer/slicer.h'])),
('kwant.linalg.lapack', ('kwant.linalg.lapack',
dict(sources=['kwant/linalg/lapack.pyx'])), dict(sources=['kwant/linalg/lapack.pyx'],
depends=['kwant/linalg/f_lapack.pxd'])),
('kwant.linalg._mumps', ('kwant.linalg._mumps',
dict(sources=['kwant/linalg/_mumps.pyx'], dict(sources=['kwant/linalg/_mumps.pyx'],
depends=['kwant/linalg/cmumps.pxd']))]) depends=['kwant/linalg/cmumps.pxd']))])
...@@ -534,7 +567,8 @@ def main(): ...@@ -534,7 +567,8 @@ def main():
for ext in exts.values(): for ext in exts.values():
ext.setdefault('include_dirs', []).append(numpy_include) ext.setdefault('include_dirs', []).append(numpy_include)
aliases = [('mumps', 'kwant.linalg._mumps')] aliases = [('lapack', 'kwant.linalg.lapack'),
('mumps', 'kwant.linalg._mumps')]
init_cython() init_cython()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment