diff --git a/kwant/lattice.py b/kwant/lattice.py index e050552c76e1a7503d4058d08c3c0f32cc660c02..cb4e0634e41d7e40b6705577d5bb151b90eae1e4 100644 --- a/kwant/lattice.py +++ b/kwant/lattice.py @@ -52,6 +52,19 @@ def general(prim_vecs, basis=None, name='', norbs=None): return Polyatomic(prim_vecs, basis, name=name, norbs=norbs) +def _check_prim_vecs(prim_vecs): + """Check constraints to ensure that prim_vecs is correct.""" + if prim_vecs.ndim != 2: + raise ValueError('``prim_vecs`` must be a 2d array-like object.') + + if prim_vecs.shape[0] > prim_vecs.shape[1]: + raise ValueError('Number of primitive vectors exceeds ' + 'the space dimensionality.') + + if np.linalg.matrix_rank(prim_vecs) < len(prim_vecs): + raise ValueError('"prim_vecs" must be linearly independent.') + + class Polyatomic: """ A Bravais lattice with an arbitrary number of sites in the basis. @@ -81,17 +94,14 @@ class Polyatomic: """ def __init__(self, prim_vecs, basis, name='', norbs=None): prim_vecs = ta.array(prim_vecs, float) - if prim_vecs.ndim != 2: - raise ValueError('`prim_vecs` must be a 2d array-like object.') + _check_prim_vecs(prim_vecs) + dim = prim_vecs.shape[1] if name is None: name = '' if isinstance(name, str): name = [name + str(i) for i in range(len(basis))] - if prim_vecs.shape[0] > dim: - raise ValueError('Number of primitive vectors exceeds ' - 'the space dimensionality.') basis = ta.array(basis, float) if basis.ndim != 2: raise ValueError('`basis` must be a 2d array-like object.') @@ -423,14 +433,12 @@ class Monatomic(builder.SiteFamily, Polyatomic): def __init__(self, prim_vecs, offset=None, name='', norbs=None): prim_vecs = ta.array(prim_vecs, float) - if prim_vecs.ndim != 2: - raise ValueError('``prim_vecs`` must be a 2d array-like object.') + _check_prim_vecs(prim_vecs) + dim = prim_vecs.shape[1] if name is None: name = '' - if prim_vecs.shape[0] > dim: - raise ValueError('Number of primitive vectors exceeds ' - 'the space dimensionality.') + if offset is None: offset = ta.zeros(dim) else: diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py index ab4cea6096c5b96235dd28619aeceab9e2dea5d6..e803514b9057460b91e1fd9bc76930e903bed74b 100644 --- a/kwant/tests/test_lattice.py +++ b/kwant/tests/test_lattice.py @@ -12,6 +12,7 @@ import tinyarray as ta from pytest import raises from kwant import lattice, builder from kwant._common import ensure_rng +import pytest def test_closest(): @@ -198,6 +199,19 @@ def test_monatomic_lattice(): lat3 = lattice.square(name='no') assert len(set([lat, lat2, lat3, lat(0, 0), lat2(0, 0), lat3(0, 0)])) == 4 +@pytest.mark.parametrize('prim_vecs, basis', [ + (1, None), + ([1], None), + ([1, 0], [[0, 0]]), + ([[1, 0], [2, 0]], None), + ([[1, 0], [2, 0]], [[0, 0]]), + ([[1, 0], [0, 2], [1, 2]], None), + ([[1, 0], [0, 2], [1, 2]], [[0, 0]]), +]) +def test_lattice_constraints(prim_vecs, basis): + with pytest.raises(ValueError): + lattice.general(prim_vecs, basis) + def test_norbs(): id_mat = np.identity(2) diff --git a/kwant/wraparound.py b/kwant/wraparound.py index d43d6cc8f343709d97f57cf417a8be72b80afd8a..9b0e41854919d894a05f16f71a59026acc9e1744 100644 --- a/kwant/wraparound.py +++ b/kwant/wraparound.py @@ -421,7 +421,7 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None, # columns of B are lattice vectors B = np.array(syst._wrapped_symmetry.periods).T # columns of A are reciprocal lattice vectors - A = B.dot(np.linalg.inv(B.T.dot(B))) + A = np.linalg.pinv(B).T ## calculate the bounding box for the 1st Brillouin zone