From d4c1cc6c1fcc925fbcb83f84ce153b9213476b83 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Tue, 24 May 2016 17:49:27 +0200
Subject: [PATCH] add an `n_orbs` attribute to `SiteFamily`

Specifying this is presently optional, in order to remain
backwards compatibility, but will be compulsory in Kwant2.
---
 kwant/builder.py            | 35 +++++++++++++-----
 kwant/lattice.py            | 71 ++++++++++++++++++++++++-------------
 kwant/tests/test_lattice.py | 33 +++++++++++++++++
 3 files changed, 106 insertions(+), 33 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index b6712a2e..24e2c9d6 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -95,7 +95,10 @@ class SiteFamily(metaclass=abc.ABCMeta):
     representation and a name.  The canonical representation will be returned as
     the objects representation and must uniquely identify the site family
     instance.  The name is a string used to distinguish otherwise identical site
-    families.  It may be empty.
+    families.  It may be empty. ``norbs`` defines the number of orbitals
+    on sites associated with this site family; it may be `None`, in which case
+    the number of orbitals is not specified.
+
 
     All site families must define the method `normalize_tag` which brings a tag
     to the standard format for this site family.
@@ -104,22 +107,37 @@ class SiteFamily(metaclass=abc.ABCMeta):
     method `pos(tag)`, which returns a vector with real-space coordinates of the
     site belonging to this family with a given tag.
 
+    If the ``norbs`` of a site family are provided, and sites of this family
+    are used to populate a `~kwant.builder.Builder`, then the associated
+    Hamiltonian values must have the correct shape. That is, if a site family
+    has ``norbs = 2``, then any on-site terms for sites belonging to this
+    family should be 2x2 matrices. Similarly, any hoppings to/from sites
+    belonging to this family must have a matrix structure where there are two
+    rows/columns. This condition applies equally to Hamiltonian values that
+    are given by functions. If this condition is not satisfied, an error will
+    be raised.
     """
 
-    def __init__(self, canonical_repr, name):
+    def __init__(self, canonical_repr, name, norbs):
         self.canonical_repr = canonical_repr
         self.hash = hash(canonical_repr)
         self.name = name
+        if norbs is not None:
+            if int(norbs) != norbs or norbs <= 0:
+                raise ValueError('The norbs parameter must be an integer > 0.')
+            norbs = int(norbs)
+        self.norbs = norbs
 
     def __repr__(self):
         return self.canonical_repr
 
     def __str__(self):
         if self.name:
-            msg = '<{0} site family {1}>'
+            msg = '<{0} site family {1}{2}>'
         else:
-            msg = '<unnamed {0} site family>'
-        return msg.format(self.__class__.__name__, self.name)
+            msg = '<unnamed {0} site family{2}>'
+        orbs = ' with {0} orbitals'.format(self.norbs) if self.norbs else ''
+        return msg.format(self.__class__.__name__, self.name, orbs)
 
     def __hash__(self):
         return self.hash
@@ -171,9 +189,10 @@ class SimpleSiteFamily(SiteFamily):
     `SimpleSiteFamily` when `kwant.lattice.Monatomic` would also work.
     """
 
-    def __init__(self, name=None):
-        canonical_repr = '{0}({1})'.format(self.__class__, repr(name))
-        super().__init__(canonical_repr, name)
+    def __init__(self, name=None, norbs=None):
+        canonical_repr = '{0}({1}, {2})'.format(self.__class__, repr(name),
+                                                repr(norbs))
+        super().__init__(canonical_repr, name, norbs)
 
     def normalize_tag(self, tag):
         tag = tuple(tag)
diff --git a/kwant/lattice.py b/kwant/lattice.py
index 22fc0320..c9de68b4 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -18,7 +18,7 @@ from .linalg import lll
 from ._common import ensure_isinstance
 
 
-def general(prim_vecs, basis=None, name=''):
+def general(prim_vecs, basis=None, name='', norbs=None):
     """
     Create a Bravais lattice of any dimensionality, with any number of sites.
 
@@ -32,6 +32,9 @@ def general(prim_vecs, basis=None, name=''):
         Name of the lattice, or sequence of names of all of the sublattices.
         If the name of the lattice is given, the names of sublattices (if any)
         are obtained by appending their number to the name of the lattice.
+    norbs : int or sequence of ints, optional
+        The number of orbitals per site on the lattice, or a sequence
+        of the number of orbitals of sites on each of the sublattices.
 
     Returns
     -------
@@ -44,9 +47,9 @@ def general(prim_vecs, basis=None, name=''):
     lattices.
     """
     if basis is None:
-        return Monatomic(prim_vecs, name=name)
+        return Monatomic(prim_vecs, name=name, norbs=norbs)
     else:
-        return Polyatomic(prim_vecs, basis, name=name)
+        return Polyatomic(prim_vecs, basis, name=name, norbs=norbs)
 
 
 class Polyatomic:
@@ -62,18 +65,21 @@ class Polyatomic:
         The primitive vectors of the Bravais lattice
     basis : 2d array-like of floats
         The coordinates of the basis sites inside the unit cell.
-    name : string or sequence of strings
+    name : string or sequence of strings, optional
         The name of the lattice, or a sequence of the names of all the
         sublattices.  If the name of the lattice is given, the names of
         sublattices are obtained by appending their number to the name of the
         lattice.
+    norbs : int or sequence of ints, optional
+        The number of orbitals per site on the lattice, or a sequence
+        of the number of orbitals of sites on each of the sublattices.
 
     Raises
     ------
     ValueError
         If dimensionalities do not match.
     """
-    def __init__(self, prim_vecs, basis, name=''):
+    def __init__(self, prim_vecs, basis, name='', norbs=None):
         prim_vecs = ta.array(prim_vecs, float)
         if prim_vecs.ndim != 2:
             raise ValueError('`prim_vecs` must be a 2d array-like object.')
@@ -82,6 +88,7 @@ class Polyatomic:
             name = ''
         if isinstance(name, str):
             name = [name + str(i) for i in range(len(basis))]
+
         if prim_vecs.shape[0] > dim:
             raise ValueError('Number of primitive vectors exceeds '
                              'the space dimensionality.')
@@ -91,8 +98,17 @@ class Polyatomic:
         if basis.shape[1] != dim:
             raise ValueError('Basis dimensionality does not match '
                              'the space dimensionality.')
-        self.sublattices = [Monatomic(prim_vecs, offset, sname)
-                            for offset, sname in zip(basis, name)]
+
+        try:
+            norbs = list(norbs)
+            if len(norbs) != len(basis):
+                raise ValueError('Length of `norbs` is not the same as '
+                                 'the number of basis vectors')
+        except TypeError:
+            norbs = [norbs] * len(basis)
+
+        self.sublattices = [Monatomic(prim_vecs, offset, sname, norb)
+                            for offset, sname, norb in zip(basis, name, norbs)]
         # Sequence of primitive vectors of the lattice.
         self._prim_vecs = prim_vecs
         # Precalculation of auxiliary arrays for real space calculations.
@@ -405,7 +421,7 @@ class Monatomic(builder.SiteFamily, Polyatomic):
         Displacement of the lattice origin from the real space coordinates origin
     """
 
-    def __init__(self, prim_vecs, offset=None, name=''):
+    def __init__(self, prim_vecs, offset=None, name='', norbs=None):
         prim_vecs = ta.array(prim_vecs, float)
         if prim_vecs.ndim != 2:
             raise ValueError('`prim_vecs` must be a 2d array-like object.')
@@ -423,11 +439,12 @@ class Monatomic(builder.SiteFamily, Polyatomic):
                 raise ValueError('Dimensionality of offset does not match '
                                  'that of the space.')
 
-        msg = '{0}({1}, {2}, {3})'
+        msg = '{0}({1}, {2}, {3}, {4})'
         cl = self.__module__ + '.' + self.__class__.__name__
         canonical_repr = msg.format(cl, short_array_repr(prim_vecs),
-                                    short_array_repr(offset), repr(name))
-        super().__init__(canonical_repr, name)
+                                    short_array_repr(offset),
+                                    repr(name), repr(norbs))
+        super().__init__(canonical_repr, name, norbs)
 
         self.sublattices = [self]
         self._prim_vecs = prim_vecs
@@ -442,12 +459,14 @@ class Monatomic(builder.SiteFamily, Polyatomic):
         self.lattice_dim = len(prim_vecs)
 
         if name != '':
-            msg = "<Monatomic lattice {0}>"
-            self.cached_str = msg.format(name)
+            msg = "<Monatomic lattice {0}{1}>"
+            orbs = ' with {0} orbitals'.format(self.norbs) if self.norbs else ''
+            self.cached_str = msg.format(name, orbs)
         else:
-            msg = "<unnamed Monatomic lattice, vectors {0}, origin [{1}]>"
+            msg = "<unnamed Monatomic lattice, vectors {0}, origin [{1}]{2}>"
+            orbs = ', with {0} orbitals'.format(norbs) if norbs else ''
             self.cached_str = msg.format(short_array_str(self._prim_vecs),
-                                         short_array_str(self.offset))
+                                         short_array_str(self.offset), orbs)
 
     def __str__(self):
         return self.cached_str
@@ -684,33 +703,35 @@ class TranslationalSymmetry(builder.Symmetry):
 
 ################ Library of lattices
 
-def chain(a=1, name=''):
+def chain(a=1, name='', norbs=None):
     """Make a one-dimensional lattice."""
-    return Monatomic(((a,),), name=name)
+    return Monatomic(((a,),), name=name, norbs=norbs)
 
 
-def square(a=1, name=''):
+def square(a=1, name='', norbs=None):
     """Make a square lattice."""
-    return Monatomic(((a, 0), (0, a)), name=name)
+    return Monatomic(((a, 0), (0, a)), name=name, norbs=norbs)
 
 
 tri = ta.array(((1, 0), (0.5, 0.5 * sqrt(3))))
 
-def triangular(a=1, name=''):
+def triangular(a=1, name='', norbs=None):
     """Make a triangular lattice."""
-    return Monatomic(a * tri, name=name)
+    return Monatomic(a * tri, name=name, norbs=norbs)
 
 
-def honeycomb(a=1, name=''):
+def honeycomb(a=1, name='', norbs=None):
     """Make a honeycomb lattice."""
-    lat = Polyatomic(a * tri, ((0, 0), (0, a / sqrt(3))), name=name)
+    lat = Polyatomic(a * tri, ((0, 0), (0, a / sqrt(3))),
+                     name=name, norbs=norbs)
     lat.a, lat.b = lat.sublattices
     return lat
 
 
-def kagome(a=1, name=''):
+def kagome(a=1, name='', norbs=None):
     """Make a kagome lattice."""
-    lat = Polyatomic(a * tri, ((0, 0),) + tuple(0.5 * a * tri), name=name)
+    lat = Polyatomic(a * tri, ((0, 0),) + tuple(0.5 * a * tri),
+                     name=name, norbs=norbs)
     lat.a, lat.b, lat.c = lat.sublattices
     return lat
 
diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py
index baa108ed..92a36ca2 100644
--- a/kwant/tests/test_lattice.py
+++ b/kwant/tests/test_lattice.py
@@ -197,3 +197,36 @@ def test_monatomic_lattice():
     lat2 = lattice.general(np.identity(2))
     lat3 = lattice.square(name='no')
     assert len(set([lat, lat2, lat3, lat(0, 0), lat2(0, 0), lat3(0, 0)])) == 4
+
+
+def test_norbs():
+    id_mat = np.identity(2)
+    # Monatomic lattices
+    assert_equal(lattice.general(id_mat).norbs, None)
+    assert_equal(lattice.general(id_mat, norbs=2).norbs, 2)
+    # Polyatomic lattices
+    lat = lattice.general(id_mat, basis=id_mat, norbs=None)
+    for l in lat.sublattices:
+        assert_equal(l.norbs, None)
+    lat = lattice.general(id_mat, basis=id_mat, norbs=2)
+    for l in lat.sublattices:
+        assert_equal(l.norbs, 2)
+    lat = lattice.general(id_mat, basis=id_mat, norbs=[1, 2])
+    for l, n in zip(lat.sublattices, [1, 2]):
+        assert_equal(l.norbs, n)
+    # should raise ValueError for # of norbs different to length of `basis`
+    assert_raises(ValueError, lattice.general, id_mat, id_mat, norbs=[])
+    assert_raises(ValueError, lattice.general, id_mat, id_mat, norbs=[1, 2, 3])
+    # TypeError if Monatomic lattice
+    assert_raises(TypeError, lattice.general, id_mat, norbs=[])
+    # should raise ValueError if norbs not an integer
+    assert_raises(ValueError, lattice.general, id_mat, norbs=1.5)
+    assert_raises(ValueError, lattice.general, id_mat, id_mat, norbs=1.5)
+    assert_raises(ValueError, lattice.general, id_mat, id_mat, norbs=[1.5, 1.5])
+    # test that lattices with different norbs are compared `not equal`
+    lat = lattice.general(id_mat, basis=id_mat, norbs=None)
+    lat1 = lattice.general(id_mat, basis=id_mat, norbs=1)
+    lat2 = lattice.general(id_mat, basis=id_mat, norbs=2)
+    assert_not_equal(lat, lat1)
+    assert_not_equal(lat, lat2)
+    assert_not_equal(lat1, lat2)
-- 
GitLab