Skip to content
Snippets Groups Projects
Commit e5d30c69 authored by Joseph Weston's avatar Joseph Weston
Browse files

add SiteArray and update SiteFamily to support them

We remove the ABC metaclass as now 'normalize_tag' and 'normalize_tags'
are cyclically defined, and subclasses must redefine at least 1.
parent 049d1501
No related branches found
No related tags found
No related merge requests found
......@@ -25,6 +25,7 @@ Sites
:toctree: generated/
Site
SiteArray
SiteFamily
Systems
......
......@@ -93,8 +93,62 @@ class Site(tuple):
return self.family.pos(self.tag)
class SiteArray:
"""An array of sites, members of a `SiteFamily`.
Parameters
----------
family : an instance of `SiteFamily`
The 'type' of the sites.
tags : a sequence of python objects
Sequence of unique identifiers of the sites within the
site array family, typically vectors of integers.
Raises
------
ValueError
If `tags` are not proper tags for `family`.
See Also
--------
kwant.system.Site
"""
def __init__(self, family, tags):
self.family = family
try:
tags = family.normalize_tags(tags)
except (TypeError, ValueError) as e:
msg = 'Tags {0} are not allowed for site family {1}: {2}'
raise type(e)(msg.format(repr(tags), repr(family), e.args[0]))
self.tags = tags
def __repr__(self):
return 'SiteArray({0}, {1})'.format(repr(self.family), repr(self.tags))
def __str__(self):
sf = self.family
return ('<SiteArray {0} of {1}>'
.format(self.tags, sf.name if sf.name else sf))
def __len__(self):
return len(self.tags)
def __eq__(self, other):
if not isinstance(other, SiteArray):
raise NotImplementedError()
return self.family == other.family and np.all(self.tags == other.tags)
def positions(self):
"""Real space position of the site.
This relies on ``family`` having a ``pos`` method (see `SiteFamily`).
"""
return self.family.positions(self.tags)
@total_ordering
class SiteFamily(metaclass=abc.ABCMeta):
class SiteFamily:
"""Abstract base class for site families.
Site families are the 'type' of `Site` objects. Within a family, individual
......@@ -112,12 +166,16 @@ class SiteFamily(metaclass=abc.ABCMeta):
the number of orbitals is not specified.
All site families must define the method `normalize_tag` which brings a tag
to the standard format for this site family.
All site families must define either 'normalize_tag' or 'normalize_tags',
which brings a tag (or, in the latter case, a sequence of tags) to the
standard format for this site family.
Site families that are intended for use with plotting should also provide a
method `pos(tag)`, which returns a vector with real-space coordinates of the
site belonging to this family with a given tag.
Site families may also implement methods ``pos(tag)`` and
``positions(tags)``, which return a vector of realspace coordinates or an
array of vectors of realspace coordinates of the site(s) belonging to this
family with the given tag(s). These methods are used in plotting routines.
``positions(tags)`` should return an array with shape ``(N, M)`` where
``N`` is the length of ``tags``, and ``M`` is the realspace dimension.
If the ``norbs`` of a site family are provided, and sites of this family
are used to populate a `~kwant.builder.Builder`, then the associated
......@@ -144,6 +202,13 @@ class SiteFamily(metaclass=abc.ABCMeta):
norbs = int(norbs)
self.norbs = norbs
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if (cls.normalize_tag is SiteFamily.normalize_tag
and cls.normalize_tags is SiteFamily.normalize_tags):
raise TypeError("Must redefine either 'normalize_tag' or "
"'normalize_tags'")
def __repr__(self):
return self.canonical_repr
......@@ -175,13 +240,20 @@ class SiteFamily(metaclass=abc.ABCMeta):
# to compare it to something non-comparable anyway.
return self.canonical_repr < other.canonical_repr
@abc.abstractmethod
def normalize_tag(self, tag):
"""Return a normalized version of the tag.
Raises TypeError or ValueError if the tag is not acceptable.
"""
pass
tag, = self.normalize_tags([tag])
return tag
def normalize_tags(self, tags):
"""Return a normalized version of the tags.
Raises TypeError or ValueError if the tags are not acceptable.
"""
return np.array([self.normalize_tag(tag) for tag in tags])
def __call__(self, *tag):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment