Skip to content
Snippets Groups Projects
Commit e5d30c69 authored by Joseph Weston's avatar Joseph Weston
Browse files

add SiteArray and update SiteFamily to support them

We remove the ABC metaclass as now 'normalize_tag' and 'normalize_tags'
are cyclically defined, and subclasses must redefine at least 1.
parent 049d1501
No related branches found
No related tags found
No related merge requests found
...@@ -25,6 +25,7 @@ Sites ...@@ -25,6 +25,7 @@ Sites
:toctree: generated/ :toctree: generated/
Site Site
SiteArray
SiteFamily SiteFamily
Systems Systems
......
...@@ -93,8 +93,62 @@ class Site(tuple): ...@@ -93,8 +93,62 @@ class Site(tuple):
return self.family.pos(self.tag) return self.family.pos(self.tag)
class SiteArray:
"""An array of sites, members of a `SiteFamily`.
Parameters
----------
family : an instance of `SiteFamily`
The 'type' of the sites.
tags : a sequence of python objects
Sequence of unique identifiers of the sites within the
site array family, typically vectors of integers.
Raises
------
ValueError
If `tags` are not proper tags for `family`.
See Also
--------
kwant.system.Site
"""
def __init__(self, family, tags):
self.family = family
try:
tags = family.normalize_tags(tags)
except (TypeError, ValueError) as e:
msg = 'Tags {0} are not allowed for site family {1}: {2}'
raise type(e)(msg.format(repr(tags), repr(family), e.args[0]))
self.tags = tags
def __repr__(self):
return 'SiteArray({0}, {1})'.format(repr(self.family), repr(self.tags))
def __str__(self):
sf = self.family
return ('<SiteArray {0} of {1}>'
.format(self.tags, sf.name if sf.name else sf))
def __len__(self):
return len(self.tags)
def __eq__(self, other):
if not isinstance(other, SiteArray):
raise NotImplementedError()
return self.family == other.family and np.all(self.tags == other.tags)
def positions(self):
"""Real space position of the site.
This relies on ``family`` having a ``pos`` method (see `SiteFamily`).
"""
return self.family.positions(self.tags)
@total_ordering @total_ordering
class SiteFamily(metaclass=abc.ABCMeta): class SiteFamily:
"""Abstract base class for site families. """Abstract base class for site families.
Site families are the 'type' of `Site` objects. Within a family, individual Site families are the 'type' of `Site` objects. Within a family, individual
...@@ -112,12 +166,16 @@ class SiteFamily(metaclass=abc.ABCMeta): ...@@ -112,12 +166,16 @@ class SiteFamily(metaclass=abc.ABCMeta):
the number of orbitals is not specified. the number of orbitals is not specified.
All site families must define the method `normalize_tag` which brings a tag All site families must define either 'normalize_tag' or 'normalize_tags',
to the standard format for this site family. which brings a tag (or, in the latter case, a sequence of tags) to the
standard format for this site family.
Site families that are intended for use with plotting should also provide a Site families may also implement methods ``pos(tag)`` and
method `pos(tag)`, which returns a vector with real-space coordinates of the ``positions(tags)``, which return a vector of realspace coordinates or an
site belonging to this family with a given tag. array of vectors of realspace coordinates of the site(s) belonging to this
family with the given tag(s). These methods are used in plotting routines.
``positions(tags)`` should return an array with shape ``(N, M)`` where
``N`` is the length of ``tags``, and ``M`` is the realspace dimension.
If the ``norbs`` of a site family are provided, and sites of this family If the ``norbs`` of a site family are provided, and sites of this family
are used to populate a `~kwant.builder.Builder`, then the associated are used to populate a `~kwant.builder.Builder`, then the associated
...@@ -144,6 +202,13 @@ class SiteFamily(metaclass=abc.ABCMeta): ...@@ -144,6 +202,13 @@ class SiteFamily(metaclass=abc.ABCMeta):
norbs = int(norbs) norbs = int(norbs)
self.norbs = norbs self.norbs = norbs
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if (cls.normalize_tag is SiteFamily.normalize_tag
and cls.normalize_tags is SiteFamily.normalize_tags):
raise TypeError("Must redefine either 'normalize_tag' or "
"'normalize_tags'")
def __repr__(self): def __repr__(self):
return self.canonical_repr return self.canonical_repr
...@@ -175,13 +240,20 @@ class SiteFamily(metaclass=abc.ABCMeta): ...@@ -175,13 +240,20 @@ class SiteFamily(metaclass=abc.ABCMeta):
# to compare it to something non-comparable anyway. # to compare it to something non-comparable anyway.
return self.canonical_repr < other.canonical_repr return self.canonical_repr < other.canonical_repr
@abc.abstractmethod
def normalize_tag(self, tag): def normalize_tag(self, tag):
"""Return a normalized version of the tag. """Return a normalized version of the tag.
Raises TypeError or ValueError if the tag is not acceptable. Raises TypeError or ValueError if the tag is not acceptable.
""" """
pass tag, = self.normalize_tags([tag])
return tag
def normalize_tags(self, tags):
"""Return a normalized version of the tags.
Raises TypeError or ValueError if the tags are not acceptable.
"""
return np.array([self.normalize_tag(tag) for tag in tags])
def __call__(self, *tag): def __call__(self, *tag):
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment