Skip to content
Snippets Groups Projects
Commit 4cd85a3d authored by Rafal Skolasinski's avatar Rafal Skolasinski Committed by Joseph Weston
Browse files

factor out and extend check of constraints for "prim_vecs"

Currently code that do constraints check on "prim_vecs" during lattice
creation is doubled. In addition to factoring it out into separate
function this commit also adds check if "prim_vecs" are linearly independent.
parent a14a0f72
Branches
No related tags found
No related merge requests found
......@@ -52,6 +52,19 @@ def general(prim_vecs, basis=None, name='', norbs=None):
return Polyatomic(prim_vecs, basis, name=name, norbs=norbs)
def _check_prim_vecs(prim_vecs):
"""Check constraints to ensure that prim_vecs is correct."""
if prim_vecs.ndim != 2:
raise ValueError('``prim_vecs`` must be a 2d array-like object.')
if prim_vecs.shape[0] > prim_vecs.shape[1]:
raise ValueError('Number of primitive vectors exceeds '
'the space dimensionality.')
if np.linalg.matrix_rank(prim_vecs) < len(prim_vecs):
raise ValueError('"prim_vecs" must be linearly independent.')
class Polyatomic:
"""
A Bravais lattice with an arbitrary number of sites in the basis.
......@@ -81,17 +94,14 @@ class Polyatomic:
"""
def __init__(self, prim_vecs, basis, name='', norbs=None):
prim_vecs = ta.array(prim_vecs, float)
if prim_vecs.ndim != 2:
raise ValueError('`prim_vecs` must be a 2d array-like object.')
_check_prim_vecs(prim_vecs)
dim = prim_vecs.shape[1]
if name is None:
name = ''
if isinstance(name, str):
name = [name + str(i) for i in range(len(basis))]
if prim_vecs.shape[0] > dim:
raise ValueError('Number of primitive vectors exceeds '
'the space dimensionality.')
basis = ta.array(basis, float)
if basis.ndim != 2:
raise ValueError('`basis` must be a 2d array-like object.')
......@@ -423,14 +433,12 @@ class Monatomic(builder.SiteFamily, Polyatomic):
def __init__(self, prim_vecs, offset=None, name='', norbs=None):
prim_vecs = ta.array(prim_vecs, float)
if prim_vecs.ndim != 2:
raise ValueError('``prim_vecs`` must be a 2d array-like object.')
_check_prim_vecs(prim_vecs)
dim = prim_vecs.shape[1]
if name is None:
name = ''
if prim_vecs.shape[0] > dim:
raise ValueError('Number of primitive vectors exceeds '
'the space dimensionality.')
if offset is None:
offset = ta.zeros(dim)
else:
......
......@@ -12,6 +12,7 @@ import tinyarray as ta
from pytest import raises
from kwant import lattice, builder
from kwant._common import ensure_rng
import pytest
def test_closest():
......@@ -198,6 +199,19 @@ def test_monatomic_lattice():
lat3 = lattice.square(name='no')
assert len(set([lat, lat2, lat3, lat(0, 0), lat2(0, 0), lat3(0, 0)])) == 4
@pytest.mark.parametrize('prim_vecs, basis', [
(1, None),
([1], None),
([1, 0], [[0, 0]]),
([[1, 0], [2, 0]], None),
([[1, 0], [2, 0]], [[0, 0]]),
([[1, 0], [0, 2], [1, 2]], None),
([[1, 0], [0, 2], [1, 2]], [[0, 0]]),
])
def test_lattice_constraints(prim_vecs, basis):
with pytest.raises(ValueError):
lattice.general(prim_vecs, basis)
def test_norbs():
id_mat = np.identity(2)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment