Verified Commit 4a0a7962 authored by Anton Akhmerov's avatar Anton Akhmerov
Browse files

remove our custom LAPACK LU wrapper

parent c32e0371
......@@ -10,10 +10,8 @@ __all__ = ['lapack']
from . import lapack
# Merge the public interface of the other submodules.
from .decomp_lu import *
from .decomp_schur import *
from .decomp_ev import *
# Copyright 2011-2013 Kwant authors.
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
# A list of Kwant authors can be found in
# the file AUTHORS.rst at the top-level directory of this distribution and at
__all__ = ['lu_factor', 'lu_solve', 'rcond_from_lu']
import numpy as np
from . import lapack
def lu_factor(a, overwrite_a=False):
"""Compute the LU factorization of a matrix A = P * L * U. The function
returns a tuple (lu, p, singular), where lu contains the LU factorization
storing the unit lower triangular matrix L in the strictly lower triangle
(the unit diagonal is not stored) and the upper triangular matrix U in the
upper triangle. p is a vector of pivot indices, and singular a Boolean
value indicating whether the matrix A is singular up to machine precision.
NOTE: This function mimics the behavior of scipy.linalg.lu_factor (except
that it has in addition the flag singular). The main reason is that
lu_factor in SciPy has a bug that depending on the type of NumPy matrix
passed to it, it would not return what was descirbed in the
documentation. This bug will be (probably) fixed in 0.10.0 but until this
is standard, this version is better to use.
a : array, shape (M, M)
Matrix to factorize
overwrite_a : boolean
Whether to overwrite data in a (may increase performance)
lu : array, shape (N, N)
Matrix containing U in its upper triangle, and L in its lower triangle.
The unit diagonal elements of L are not stored.
piv : array, shape (N,)
Pivot indices representing the permutation matrix P:
row i of matrix was interchanged with row piv[i].
singular : boolean
Whether the matrix a is singular (up to machine precision)
a = lapack.prepare_for_lapack(overwrite_a, a)
return lapack.getrf(a)
def lu_solve(matrix_factorization, b):
"""Solve a linear system of equations, a x = b, given the LU
factorization of a
Factorization of the coefficient matrix a, as given by lu_factor
b : array (vector or matrix)
Right-hand side
x : array (vector or matrix)
Solution to the system
(lu, ipiv, singular) = matrix_factorization
if singular:
raise RuntimeWarning("In lu_solve: the flag singular indicates "
"a singular matrix. Result of solve step "
"are probably unreliable")
lu, b = lapack.prepare_for_lapack(False, lu, b)
ipiv = np.ascontiguousarray(np.asanyarray(ipiv), dtype=lapack.int_dtype)
return lapack.getrs(lu, ipiv, b)
def rcond_from_lu(matrix_factorization, norm_a, norm="1"):
"""Compute the reciprocal condition number from the LU decomposition as
returned from lu_factor(), given additionally the norm of the matrix a in
The reciprocal condition number is given as 1/(||A||*||A^-1||), where
||...|| is a matrix norm.
Factorization of the matrix a, as given by lu_factor
norm_a : float or complex
norm of the original matrix a (type of norm is specified in norm)
norm : {'1', 'I'}, optional
type of matrix norm which should be used to compute the condition
number ("1": 1-norm, "I": infinity norm). Default: '1'.
rcond : float or complex
reciprocal condition number of a with respect to the type of matrix
norm specified in norm
(lu, ipiv, singular) = matrix_factorization
norm = norm.encode('utf8') # lapack expects bytes
lu = lapack.prepare_for_lapack(False, lu)
return lapack.gecon(lu, norm_a, norm)
......@@ -8,9 +8,7 @@
"""Low-level access to LAPACK functions. """
__all__ = ['getrf',
__all__ = ['gecon',
......@@ -92,85 +90,6 @@ cdef l_int lwork_from_qwork(scalar qwork):
return <l_int>qwork.real
def getrf(np.ndarray[scalar, ndim=2] A):
cdef l_int M, N, info
cdef np.ndarray[l_int] ipiv
M = A.shape[0]
N = A.shape[1]
ipiv = np.empty(min(M,N), dtype = int_dtype)
if scalar is float:
lapack.sgetrf(&M, &N, <float *>, &M,
<l_int *>, &info)
elif scalar is double:
lapack.dgetrf(&M, &N, <double *>, &M,
<l_int *>, &info)
elif scalar is float_complex:
lapack.cgetrf(&M, &N, <float complex *>, &M,
<l_int *>, &info)
elif scalar is double_complex:
lapack.zgetrf(&M, &N, <double complex *>, &M,
<l_int *>, &info)
assert info >= 0, "Argument error in getrf"
return (A, ipiv, info > 0 or M != N)
def getrs(np.ndarray[scalar, ndim=2] LU, np.ndarray[l_int] IPIV,
np.ndarray B):
cdef l_int N, NRHS, info
# Consistency checks for LU and B
if B.descr.type_num != LU.descr.type_num:
raise TypeError('B must have same dtype as LU')
# Workaround for 1x1-Fortran bug in NumPy < v2.0
if ((B.ndim == 2 and (B.shape[0] > 1 or B.shape[1] > 1) and
not B.flags["F_CONTIGUOUS"])):
raise ValueError("B must be Fortran ordered")
if B.ndim > 2:
raise ValueError("B must be a vector or matrix")
if LU.shape[0] != B.shape[0]:
raise ValueError('LU and B have incompatible shapes')
N = LU.shape[0]
if B.ndim == 1:
NRHS = 1
elif B.ndim == 2:
NRHS = B.shape[1]
if scalar is float:
lapack.sgetrs("N", &N, &NRHS, <float *>, &N,
<l_int *>, <float *>, &N,
elif scalar is double:
lapack.dgetrs("N", &N, &NRHS, <double *>, &N,
<l_int *>, <double *>, &N,
elif scalar is float_complex:
lapack.cgetrs("N", &N, &NRHS, <float complex *>, &N,
<l_int *>, <float complex *>, &N,
elif scalar is double_complex:
lapack.zgetrs("N", &N, &NRHS, <double complex *>, &N,
<l_int *>, <double complex *>, &N,
assert info == 0, "Argument error in getrs"
return B
def gecon(np.ndarray[scalar, ndim=2] LU, double normA, char *norm = b"1"):
cdef l_int N, info
cdef float srcond, snormA
......@@ -8,9 +8,10 @@
import pytest
import numpy as np
from scipy import linalg
from kwant.linalg import (
lu_factor, lu_solve, rcond_from_lu, gen_eig, schur,
gen_eig, schur,
convert_r2c_schur, order_schur, evecs_from_schur, gen_schur,
convert_r2c_gen_schur, order_gen_schur, evecs_from_gen_schur)
from ._test_utils import _Random, assert_array_almost_equal
......@@ -36,46 +37,6 @@ def test_gen_eig(dtype):
alpha @ vl.T.conj() @ b)
def test_lu(dtype):
rand = _Random()
a = rand.randmat(4, 4, dtype)
bmat = rand.randmat(4, 4, dtype)
bvec = rand.randvec(4, dtype)
lu = lu_factor(a)
xmat = lu_solve(lu, bmat)
xvec = lu_solve(lu, bvec)
assert_array_almost_equal(dtype, a @ xmat, bmat)
assert_array_almost_equal(dtype, a @ xvec, bvec)
def test_rcond_from_lu(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
norm1_a = np.linalg.norm(a, 1)
normI_a = np.linalg.norm(a, np.inf)
lu = lu_factor(a)
rcond1 = rcond_from_lu(lu, norm1_a, '1')
rcondI = rcond_from_lu(lu, normI_a, 'I')
err1 = abs(rcond1 -
1/(norm1_a * np.linalg.norm(np.linalg.inv(a), 1)))
errI = abs(rcondI -
1/(normI_a * np.linalg.norm(np.linalg.inv(a), np.inf)))
#rcond_from_lu returns an estimate for the reciprocal
#condition number only; hence we shouldn't be too strict about
#the assertions here
#Note: in my experience the estimate is excellent for somewhat
#larger matrices
assert err1/rcond1 < 0.1
assert errI/rcondI < 0.1
def test_schur(dtype):
rand = _Random()
a = rand.randmat(5, 5, dtype)
......@@ -9,6 +9,7 @@
from math import sin, cos, sqrt, pi, copysign
from collections import namedtuple
import warnings
from itertools import combinations_with_replacement
import numpy as np
......@@ -23,6 +24,30 @@ from scipy.sparse import (identity as sp_identity, hstack as sp_hstack,
__all__ = ['selfenergy', 'modes', 'PropagatingModes', 'StabilizedModes']
def lu_factor_rcond(mat: np.ndarray):
"""Perform LU factorization and check condition number.
mat : numpy array
Matrix to be factorized
sol :
LU factorization
rcond : float
Condition number
with warnings.catch_warnings():
warnings.simplefilter("ignore", la.LinAlgWarning)
sol = la.lu_factor(mat)
lu = np.asfortranarray(sol[0])
gecon = la.lapack.get_lapack_funcs("gecon", [lu])
return sol, gecon(lu, npl.norm(mat, 1))[0]
def nonzero_symm_projection(matrix):
"""Check whether a discrete symmetry relation between two blocks of the
Hamiltonian vanishes or not.
......@@ -283,8 +308,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
# Check if there is a chance we will not need to add an imaginary term.
if not add_imaginary:
h = h_cell
sol = kla.lu_factor(h)
rcond = kla.rcond_from_lu(sol, npl.norm(h, 1))
sol, rcond = lu_factor_rcond(h)
if rcond < eps:
need_to_stabilize = True
......@@ -298,8 +322,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
temp = u @ u.T.conj() + v @ v.T.conj()
h = h_cell + 1j * temp
sol = kla.lu_factor(h)
rcond = kla.rcond_from_lu(sol, npl.norm(h, 1))
sol, rcond = lu_factor_rcond(h)
# If the condition number of the stabilized h is
# still bad, there is nothing we can do.
......@@ -315,7 +338,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
if need_to_stabilize:
wf += 1j * (v @ psi[: n_nonsing] +
u @ (psi[n_nonsing:] * lmbdainv))
return kla.lu_solve(sol, wf)
return la.lu_solve(sol, wf)
# Setup the generalized eigenvalue problem.
......@@ -325,7 +348,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
begin, end = slice(n_nonsing), slice(n_nonsing, None)
A[end, begin] = np.identity(n_nonsing)
temp = kla.lu_solve(sol, v)
temp = la.lu_solve(sol, v)
temp2 = u.T.conj() @ temp
if need_to_stabilize:
A[begin, begin] = -1j * temp2
......@@ -336,7 +359,7 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
A[end, end] = temp2
B[begin, end] = -np.identity(n_nonsing)
temp = kla.lu_solve(sol, u)
temp = la.lu_solve(sol, u)
temp2 = u.T.conj() @ temp
B[begin, begin] = -temp2
if need_to_stabilize:
......@@ -354,15 +377,15 @@ def setup_linsys(h_cell, h_hop, tol=1e6, stabilization=None):
# the generalized eigenvalue problem to a regular one, provided
# the matrix B can be safely inverted.
lu_b = kla.lu_factor(B)
lu_b, rcond = lu_factor_rcond(B)
if not stabilization[1]:
rcond = kla.rcond_from_lu(lu_b, npl.norm(B, 1))
# A more stringent condition is used here since errors can
# accumulate from here to the eigenvalue calculation later.
stabilization[1] = rcond > eps * tol
if stabilization[1]:
matrices = (kla.lu_solve(lu_b, A), None)
matrices = (la.lu_solve(lu_b, A), None)
matrices = (A, B)
return Linsys(matrices, v, extract_wf)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment