Verified Commit 655f6f55 authored by Anton Akhmerov's avatar Anton Akhmerov
Browse files

switch to a dtype fixture in linalg tests and use @ for dot

parent a46b11e6
......@@ -6,36 +6,37 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# https://kwant-project.org/authors.
import pytest
import numpy as np
from kwant.linalg import (
lu_factor, lu_solve, rcond_from_lu, gen_eig, schur,
convert_r2c_schur, order_schur, evecs_from_schur, gen_schur,
convert_r2c_gen_schur, order_gen_schur, evecs_from_gen_schur)
import numpy as np
from ._test_utils import _Random, assert_array_almost_equal
def test_gen_eig():
def _test_gen_eig(dtype):
# int should always be propagated to float64
@pytest.fixture(scope='module', params=[
np.float32, np.float64, np.complex64, np.complex128, np.int32
])
def dtype(request):
return request.param
def test_gen_eig(dtype):
rand = _Random()
a = rand.randmat(4, 4, dtype)
b = rand.randmat(4, 4, dtype)
(alpha, beta, vl, vr) = gen_eig(a, b, True, True)
assert_array_almost_equal(dtype, np.dot(np.dot(a, vr), beta),
np.dot(np.dot(b, vr), alpha))
assert_array_almost_equal(dtype,
np.dot(beta, np.dot(np.conj(vl.T), a)),
np.dot(alpha, np.dot(np.conj(vl.T), b)))
_test_gen_eig(np.float32)
_test_gen_eig(np.float64)
_test_gen_eig(np.complex64)
_test_gen_eig(np.complex128)
#int should be propagated to float64
_test_gen_eig(np.int32)
def test_lu():
def _test_lu(dtype):
assert_array_almost_equal(dtype, a @ vr @ beta, b @ vr @ alpha)
assert_array_almost_equal(dtype, beta @ vl.T.conj() @ a,
alpha @ vl.T.conj() @ b)
def test_lu(dtype):
rand = _Random()
a = rand.randmat(4, 4, dtype)
bmat = rand.randmat(4, 4, dtype)
......@@ -45,18 +46,11 @@ def test_lu():
xmat = lu_solve(lu, bmat)
xvec = lu_solve(lu, bvec)
assert_array_almost_equal(dtype, np.dot(a, xmat), bmat)
assert_array_almost_equal(dtype, np.dot(a, xvec), bvec)
assert_array_almost_equal(dtype, a @ xmat, bmat)
assert_array_almost_equal(dtype, a @ xvec, bvec)
_test_lu(np.float32)
_test_lu(np.float64)
_test_lu(np.complex64)
_test_lu(np.complex128)
#int should be propagated to float64
_test_lu(np.int32)
def test_rcond_from_lu():
def _test_rcond_from_lu(dtype):
def test_rcond_from_lu(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
......@@ -81,51 +75,29 @@ def test_rcond_from_lu():
assert err1/rcond1 < 0.1
assert errI/rcondI < 0.1
_test_rcond_from_lu(np.float32)
_test_rcond_from_lu(np.float64)
_test_rcond_from_lu(np.complex64)
_test_rcond_from_lu(np.complex128)
#int should be propagated to float64
_test_rcond_from_lu(np.int32)
def test_schur():
def _test_schur(dtype):
def test_schur(dtype):
rand = _Random()
a = rand.randmat(5, 5, dtype)
t, q = schur(a)[:2]
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), np.conj(q.T)), a)
assert_array_almost_equal(dtype, q @ t @ q.T.conj(), a)
_test_schur(np.float32)
_test_schur(np.float64)
_test_schur(np.complex64)
_test_schur(np.complex128)
#int should be propagated to float64
_test_schur(np.int32)
def test_convert_r2c_schur():
def _test_convert_r2c_schur(dtype):
# in the complex case the function should actually just copy
def test_convert_r2c_schur(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
t, q = schur(a)[:2]
t2, q2 = convert_r2c_schur(t, q)
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), np.conj(q.T)), a)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, t2), np.conj(q2.T)),
a)
assert_array_almost_equal(dtype, q @ t @ q.T.conj(), a)
assert_array_almost_equal(dtype, q2 @ t2 @ q2.T.conj(), a)
_test_convert_r2c_schur(np.float32)
_test_convert_r2c_schur(np.float64)
#in the complex case the function should actually just copy
_test_convert_r2c_schur(np.complex64)
_test_convert_r2c_schur(np.complex128)
#int should be propagated to float64
_test_convert_r2c_schur(np.int32)
def test_order_schur():
def _test_order_schur(dtype):
def test_order_schur(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
......@@ -133,29 +105,20 @@ def test_order_schur():
t2, q2, ev2 = order_schur(lambda i: i>2 and i<7, t, q)
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), np.conj(q.T)), a)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, t2), np.conj(q2.T)),
a)
assert_array_almost_equal(dtype, q @ t @ q.T.conj(), a)
assert_array_almost_equal(dtype, q2 @ t2 @ q2.T.conj(), a)
assert_array_almost_equal(dtype, np.sort(ev), np.sort(ev2))
assert_array_almost_equal(dtype, np.sort(ev[3:7]), np.sort(ev2[:4]))
sel = [False, False, 0, True, True, True, 1, False, False, False]
t3, q3 = order_schur(sel, t, q)[:2]
assert_array_almost_equal(dtype, np.dot(np.dot(q3, t3), np.conj(q3.T)),
a)
assert_array_almost_equal(dtype, q3 @ t3 @ q3.T.conj(), a)
assert_array_almost_equal(dtype, t2, t3)
assert_array_almost_equal(dtype, q2, q3)
_test_order_schur(np.float32)
_test_order_schur(np.float64)
_test_order_schur(np.complex64)
_test_order_schur(np.complex128)
#int should be propagated to float64
_test_order_schur(np.int32)
def test_evecs_from_schur():
def _test_evecs_from_schur(dtype):
def test_evecs_from_schur(dtype):
rand = _Random()
a = rand.randmat(5, 5, dtype)
......@@ -163,59 +126,40 @@ def test_evecs_from_schur():
vl, vr = evecs_from_schur(t, q, select=None, left=True, right=True)
assert_array_almost_equal(dtype, np.dot(vr, np.dot(np.diag(ev),
np.linalg.inv(vr))), a)
assert_array_almost_equal(dtype, np.dot(np.linalg.inv(np.conj(vl.T)),
np.dot(np.diag(ev), np.conj(vl.T))),
a)
assert_array_almost_equal(dtype, vr @ np.diag(ev) @ np.linalg.inv(vr), a)
assert_array_almost_equal(dtype, (np.linalg.inv(vl.T.conj())
@ np.diag(ev) @ vl.T.conj()), a)
select = np.array([True, True, False, False, False], dtype=bool)
vl, vr = evecs_from_schur(t, q, select, left=True, right=True)
assert vr.shape[1] == 2
assert vl.shape[1] == 2
assert_array_almost_equal(dtype, np.dot(a, vr),
np.dot(vr, np.diag(ev[select])))
assert_array_almost_equal(dtype, np.dot(vl.T.conj(), a),
np.dot(np.diag(ev[select]), vl.T.conj()))
assert vr.shape[1] == vl.shape[1] == 2
assert_array_almost_equal(dtype, a @ vr, vr @ np.diag(ev[select]))
assert_array_almost_equal(dtype, vl.T.conj() @ a,
np.diag(ev[select]) @ vl.T.conj())
vl, vr = evecs_from_schur(t, q, lambda i: i<2, left=True, right=True)
assert vr.shape[1] == 2
assert vl.shape[1] == 2
assert_array_almost_equal(dtype, np.dot(a, vr),
np.dot(vr, np.diag(ev[select])))
assert_array_almost_equal(dtype, np.dot(vl.T.conj(), a),
np.dot(np.diag(ev[select]), vl.T.conj()))
_test_evecs_from_schur(np.float32)
_test_evecs_from_schur(np.float64)
_test_evecs_from_schur(np.complex64)
_test_evecs_from_schur(np.complex128)
#int should be propagated to float64
_test_evecs_from_schur(np.int32)
def test_gen_schur():
def _test_gen_schur(dtype):
assert vr.shape[1] == vl.shape[1] == 2
assert_array_almost_equal(dtype, a @ vr, vr @ np.diag(ev[select]))
assert_array_almost_equal(dtype, vl.T.conj() @ a,
np.diag(ev[select]) @ vl.T.conj())
def test_gen_schur(dtype):
rand = _Random()
a = rand.randmat(5, 5, dtype)
b = rand.randmat(5, 5, dtype)
s, t, q, z = gen_schur(a, b)[:4]
assert_array_almost_equal(dtype, np.dot(np.dot(q, s), z.T.conj()), a)
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), z.T.conj()), b)
assert_array_almost_equal(dtype, q @ s @ z.T.conj(), a)
assert_array_almost_equal(dtype, q @ t @ z.T.conj(), b)
_test_gen_schur(np.float32)
_test_gen_schur(np.float64)
_test_gen_schur(np.complex64)
_test_gen_schur(np.complex128)
#int should be propagated to float64
_test_gen_schur(np.int32)
def test_convert_r2c_gen_schur():
def _test_convert_r2c_gen_schur(dtype):
# in the complex case the function should actually just copy
def test_convert_r2c_gen_schur(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
b = rand.randmat(10, 10, dtype)
......@@ -223,23 +167,13 @@ def test_convert_r2c_gen_schur():
s, t, q, z = gen_schur(a, b)[:4]
s2, t2, q2, z2 = convert_r2c_gen_schur(s, t, q, z)
assert_array_almost_equal(dtype, np.dot(np.dot(q, s), z.T.conj()), a)
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), z.T.conj()), b)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, s2), z2.T.conj()),
a)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, t2), z2.T.conj()),
b)
_test_convert_r2c_gen_schur(np.float32)
_test_convert_r2c_gen_schur(np.float64)
#in the complex case the function should actually just copy
_test_convert_r2c_gen_schur(np.complex64)
_test_convert_r2c_gen_schur(np.complex128)
#int should be propagated to float64
_test_convert_r2c_gen_schur(np.int32)
def test_order_gen_schur():
def _test_order_gen_schur(dtype):
assert_array_almost_equal(dtype, q @ s @ z.T.conj(), a)
assert_array_almost_equal(dtype, q @ t @ z.T.conj(), b)
assert_array_almost_equal(dtype, q2 @ s2 @ z2.T.conj(), a)
assert_array_almost_equal(dtype, q2 @ t2 @ z2.T.conj(), b)
def test_order_gen_schur(dtype):
rand = _Random()
a = rand.randmat(10, 10, dtype)
b = rand.randmat(10, 10, dtype)
......@@ -249,12 +183,10 @@ def test_order_gen_schur():
s2, t2, q2, z2, alpha2, beta2 = order_gen_schur(lambda i: i>2 and i<7,
s, t, q, z)
assert_array_almost_equal(dtype, np.dot(np.dot(q, s), z.T.conj()), a)
assert_array_almost_equal(dtype, np.dot(np.dot(q, t), z.T.conj()), b)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, s2), z2.T.conj()),
a)
assert_array_almost_equal(dtype, np.dot(np.dot(q2, t2), z2.T.conj()),
b)
assert_array_almost_equal(dtype, q @ s @ z.T.conj(), a)
assert_array_almost_equal(dtype, q @ t @ z.T.conj(), b)
assert_array_almost_equal(dtype, q2 @ s2 @ z2.T.conj(), a)
assert_array_almost_equal(dtype, q2 @ t2 @ z2.T.conj(), b)
#Sorting here is a bit tricky: For real matrices we expect
#for complex conjugated pairs identical real parts - however
......@@ -276,25 +208,15 @@ def test_order_gen_schur():
sel = [False, False, 0, True, True, True, 1, False, False, False]
s3, t3, q3, z3 = order_gen_schur(sel, s, t, q, z)[:4]
assert_array_almost_equal(dtype, np.dot(np.dot(q3, s3), z3.T.conj()),
a)
assert_array_almost_equal(dtype, np.dot(np.dot(q3, t3), z3.T.conj()),
b)
assert_array_almost_equal(dtype, q3 @ s3 @ z3.T.conj(), a)
assert_array_almost_equal(dtype, q3 @ t3 @ z3.T.conj(), b)
assert_array_almost_equal(dtype, s2, s3)
assert_array_almost_equal(dtype, t2, t3)
assert_array_almost_equal(dtype, q2, q3)
assert_array_almost_equal(dtype, z2, z3)
_test_order_gen_schur(np.float32)
_test_order_gen_schur(np.float64)
_test_order_gen_schur(np.complex64)
_test_order_gen_schur(np.complex128)
#int should be propagated to float64
_test_order_gen_schur(np.int32)
def test_evecs_from_gen_schur():
def _test_evecs_from_gen_schur(dtype):
def test_evecs_from_gen_schur(dtype):
rand = _Random()
a = rand.randmat(5, 5, dtype)
b = rand.randmat(5, 5, dtype)
......@@ -304,55 +226,26 @@ def test_evecs_from_gen_schur():
vl, vr = evecs_from_gen_schur(s, t, q, z , select=None,
left=True, right=True)
assert_array_almost_equal(dtype, np.dot(a, np.dot(vr, np.diag(beta))),
np.dot(b, np.dot(vr, np.diag(alpha))))
assert_array_almost_equal(dtype,
np.dot(np.dot(np.diag(beta), vl.T.conj()),
a),
np.dot(np.dot(np.diag(alpha), vl.T.conj()),
b))
assert_array_almost_equal(dtype, a @ vr @ np.diag(beta),
b @ vr @ np.diag(alpha))
assert_array_almost_equal(dtype, np.diag(beta) @ vl.T.conj() @ a,
np.diag(alpha) @ vl.T.conj() @ b)
select = np.array([True, True, False, False, False], dtype=bool)
vl, vr = evecs_from_gen_schur(s, t, q, z, select,
left=True, right=True)
vl, vr = evecs_from_gen_schur(s, t, q, z, select, left=True, right=True)
assert vr.shape[1] == 2
assert vl.shape[1] == 2
assert_array_almost_equal(dtype,
np.dot(a, np.dot(vr,
np.diag(beta[select]))),
np.dot(b, np.dot(vr,
np.diag(alpha[select]))))
assert_array_almost_equal(dtype,
np.dot(np.dot(np.diag(beta[select]),
vl.T.conj()),
a),
np.dot(np.dot(np.diag(alpha[select]),
vl.T.conj()),
b))
assert vr.shape[1] == vl.shape[1] == 2
assert_array_almost_equal(dtype, a @ vr @ np.diag(beta[select]),
b @ vr @ np.diag(alpha[select]))
assert_array_almost_equal(dtype, np.diag(beta[select]) @ vl.T.conj() @ a,
np.diag(alpha[select]) @ vl.T.conj() @ b)
vl, vr = evecs_from_gen_schur(s, t, q, z, lambda i: i<2, left=True,
right=True)
assert vr.shape[1] == 2
assert vl.shape[1] == 2
assert_array_almost_equal(dtype,
np.dot(a, np.dot(vr,
np.diag(beta[select]))),
np.dot(b, np.dot(vr,
np.diag(alpha[select]))))
assert_array_almost_equal(dtype,
np.dot(np.dot(np.diag(beta[select]),
vl.T.conj()),
a),
np.dot(np.dot(np.diag(alpha[select]),
vl.T.conj()),
b))
_test_evecs_from_gen_schur(np.float32)
_test_evecs_from_gen_schur(np.float64)
_test_evecs_from_gen_schur(np.complex64)
_test_evecs_from_gen_schur(np.complex128)
#int should be propagated to float64
_test_evecs_from_gen_schur(np.int32)
assert vr.shape[1] == vl.shape[1] == 2
assert_array_almost_equal(dtype, a @ vr @ np.diag(beta[select]),
b @ vr @ np.diag(alpha[select]))
assert_array_almost_equal(dtype, np.diag(beta[select]) @ vl.T.conj() @ a,
np.diag(alpha[select]) @ vl.T.conj() @ b)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment