Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • kwant/kwant
  • jbweston/kwant
  • anton-akhmerov/kwant
  • cwg/kwant
  • Mathieu/kwant
  • slavoutich/kwant
  • pacome/kwant
  • behrmann/kwant
  • michaelwimmer/kwant
  • albeercik/kwant
  • eunjongkim/kwant
  • basnijholt/kwant
  • r-j-skolasinski/kwant
  • sahmed95/kwant
  • pablopiskunow/kwant
  • mare/kwant
  • dvarjas/kwant
  • Paul/kwant
  • bbuijtendorp/kwant
  • tkloss/kwant
  • torosdahl/kwant
  • kel85uk/kwant
  • kpoyhonen/kwant
  • Fromeworld/kwant
  • quaeritis/kwant
  • marwahaha/kwant
  • fernandodfufrpe/kwant
  • oly/kwant
  • jiamingh/kwant
  • mehdi2369/kwant
  • ValFadeev/kwant
  • Kostas/kwant
  • chelseabaptiste03/kwant
33 results
Show changes
# Copyright 2011-2013 Kwant authors.
# Copyright 2011-2019 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -8,13 +8,269 @@
"""Low-level interface of systems"""
__all__ = ['System', 'FiniteSystem', 'InfiniteSystem']
__all__ = [
'Site', 'SiteArray', 'SiteFamily',
'System', 'VectorizedSystem', 'FiniteSystem', 'FiniteVectorizedSystem',
'InfiniteSystem', 'InfiniteVectorizedSystem',
]
import abc
import warnings
import operator
from copy import copy
from collections import namedtuple
from functools import total_ordering, lru_cache
import numpy as np
from . import _system
from ._common import deprecate_args
from ._common import deprecate_args, KwantDeprecationWarning
################ Sites and Site families
class Site(tuple):
"""A site, member of a `SiteFamily`.
Sites are the vertices of the graph which describes the tight binding
system in a `Builder`.
A site is uniquely identified by its family and its tag.
Parameters
----------
family : an instance of `SiteFamily`
The 'type' of the site.
tag : a hashable python object
The unique identifier of the site within the site family, typically a
vector of integers.
Raises
------
ValueError
If `tag` is not a proper tag for `family`.
Notes
-----
For convenience, ``family(*tag)`` can be used instead of ``Site(family,
tag)`` to create a site.
The parameters of the constructor (see above) are stored as instance
variables under the same names. Given a site ``site``, common things to
query are thus ``site.family``, ``site.tag``, and ``site.pos``.
"""
__slots__ = ()
family = property(operator.itemgetter(0),
doc="The site family to which the site belongs.")
tag = property(operator.itemgetter(1), doc="The tag of the site.")
def __new__(cls, family, tag, _i_know_what_i_do=False):
if _i_know_what_i_do:
return tuple.__new__(cls, (family, tag))
try:
tag = family.normalize_tag(tag)
except (TypeError, ValueError) as e:
msg = 'Tag {0} is not allowed for site family {1}: {2}'
raise type(e)(msg.format(repr(tag), repr(family), e.args[0]))
return tuple.__new__(cls, (family, tag))
def __repr__(self):
return 'Site({0}, {1})'.format(repr(self.family), repr(self.tag))
def __str__(self):
sf = self.family
return '<Site {0} of {1}>'.format(self.tag, sf.name if sf.name else sf)
def __getnewargs__(self):
return (self.family, self.tag, True)
@property
def pos(self):
"""Real space position of the site.
This relies on ``family`` having a ``pos`` method (see `SiteFamily`).
"""
return self.family.pos(self.tag)
class SiteArray:
"""An array of sites, members of a `SiteFamily`.
Parameters
----------
family : an instance of `SiteFamily`
The 'type' of the sites.
tags : a sequence of python objects
Sequence of unique identifiers of the sites within the
site array family, typically vectors of integers.
Raises
------
ValueError
If `tags` are not proper tags for `family`.
See Also
--------
kwant.system.Site
"""
def __init__(self, family, tags):
self.family = family
try:
tags = family.normalize_tags(tags)
except (TypeError, ValueError) as e:
msg = 'Tags {0} are not allowed for site family {1}: {2}'
raise type(e)(msg.format(repr(tags), repr(family), e.args[0]))
self.tags = tags
def __repr__(self):
return 'SiteArray({0}, {1})'.format(repr(self.family), repr(self.tags))
def __str__(self):
sf = self.family
return ('<SiteArray {0} of {1}>'
.format(self.tags, sf.name if sf.name else sf))
def __len__(self):
return len(self.tags)
def __eq__(self, other):
if not isinstance(other, SiteArray):
raise NotImplementedError()
return self.family == other.family and np.all(self.tags == other.tags)
def positions(self):
"""Real space position of the site.
This relies on ``family`` having a ``pos`` method (see `SiteFamily`).
"""
return self.family.positions(self.tags)
@total_ordering
class SiteFamily:
"""Abstract base class for site families.
Site families are the 'type' of `Site` objects. Within a family, individual
sites are uniquely identified by tags. Valid tags must be hashable Python
objects, further details are up to the family.
Site families must be immutable and fully defined by their initial
arguments. They must inherit from this abstract base class and call its
__init__ function providing it with two arguments: a canonical
representation and a name. The canonical representation will be returned as
the objects representation and must uniquely identify the site family
instance. The name is a string used to distinguish otherwise identical site
families. It may be empty. ``norbs`` defines the number of orbitals
on sites associated with this site family; it may be `None`, in which case
the number of orbitals is not specified.
All site families must define either 'normalize_tag' or 'normalize_tags',
which brings a tag (or, in the latter case, a sequence of tags) to the
standard format for this site family.
Site families may also implement methods ``pos(tag)`` and
``positions(tags)``, which return a vector of realspace coordinates or an
array of vectors of realspace coordinates of the site(s) belonging to this
family with the given tag(s). These methods are used in plotting routines.
``positions(tags)`` should return an array with shape ``(N, M)`` where
``N`` is the length of ``tags``, and ``M`` is the realspace dimension.
If the ``norbs`` of a site family are provided, and sites of this family
are used to populate a `~kwant.builder.Builder`, then the associated
Hamiltonian values must have the correct shape. That is, if a site family
has ``norbs = 2``, then any on-site terms for sites belonging to this
family should be 2x2 matrices. Similarly, any hoppings to/from sites
belonging to this family must have a matrix structure where there are two
rows/columns. This condition applies equally to Hamiltonian values that
are given by functions. If this condition is not satisfied, an error will
be raised.
"""
def __init__(self, canonical_repr, name, norbs):
self.canonical_repr = canonical_repr
self.hash = hash(canonical_repr)
self.name = name
if norbs is None:
warnings.warn("Not specfying norbs is deprecated. Always specify "
"norbs when creating site families.",
KwantDeprecationWarning, stacklevel=3)
if norbs is not None:
if int(norbs) != norbs or norbs <= 0:
raise ValueError('The norbs parameter must be an integer > 0.')
norbs = int(norbs)
self.norbs = norbs
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if (cls.normalize_tag is SiteFamily.normalize_tag
and cls.normalize_tags is SiteFamily.normalize_tags):
raise TypeError("Must redefine either 'normalize_tag' or "
"'normalize_tags'")
def __repr__(self):
return self.canonical_repr
def __str__(self):
if self.name:
msg = '<{0} site family {1}{2}>'
else:
msg = '<unnamed {0} site family{2}>'
orbs = ' with {0} orbitals'.format(self.norbs) if self.norbs else ''
return msg.format(self.__class__.__name__, self.name, orbs)
def __hash__(self):
return self.hash
def __eq__(self, other):
try:
return self.canonical_repr == other.canonical_repr
except AttributeError:
return False
def __ne__(self, other):
try:
return self.canonical_repr != other.canonical_repr
except AttributeError:
return True
def __lt__(self, other):
# If this raises an AttributeError, we were trying
# to compare it to something non-comparable anyway.
return self.canonical_repr < other.canonical_repr
def normalize_tag(self, tag):
"""Return a normalized version of the tag.
Raises TypeError or ValueError if the tag is not acceptable.
"""
tag, = self.normalize_tags([tag])
return tag
def normalize_tags(self, tags):
"""Return a normalized version of the tags.
Raises TypeError or ValueError if the tags are not acceptable.
"""
return np.array([self.normalize_tag(tag) for tag in tags])
def __call__(self, *tag):
"""
A convenience function.
This function allows to write fam(1, 2) instead of Site(fam, (1, 2)).
"""
# Catch a likely and difficult to find mistake.
if tag and isinstance(tag[0], tuple):
raise ValueError('Use site_family(1, 2) instead of '
'site_family((1, 2))!')
return Site(self, tag)
################ Systems
class System(metaclass=abc.ABCMeta):
......@@ -92,12 +348,100 @@ class System(metaclass=abc.ABCMeta):
details = ', and '.join((', '.join(details[:-1]), details[-1]))
return '<{} with {}>'.format(self.__class__.__name__, details)
hamiltonian_submatrix = _system.hamiltonian_submatrix
Term = namedtuple(
"Term",
["subgraph", "hermitian", "parameters"],
)
class VectorizedSystem(System, metaclass=abc.ABCMeta):
"""Abstract general low-level system with support for vectorization.
Attributes
----------
graph : kwant.graph.CGraph
The system graph.
subgraphs : sequence of tuples
Each subgraph has the form '((idx1, idx2), (offsets1, offsets2))'
where 'offsets1' and 'offsets2' index sites within the site arrays
indexed by 'idx1' and 'idx2'.
terms : sequence of tuples
Each tuple has the following structure:
(subgraph: int, hermitian: bool, parameters: List(str))
'subgraph' indexes 'subgraphs' and supplies the to/from sites of this
term. 'hermitian' is 'True' if the term needs its Hermitian
conjugate to be added when evaluating the Hamiltonian, and 'parameters'
contains a list of parameter names used when evaluating this term.
site_arrays : sequence of SiteArray
The sites of the system.
site_ranges : None or Nx3 integer array
Has 1 row per site array, plus one extra row. Each row consists
of ``(first_site, norbs, orb_offset)``: the index of the first
site in the site array, the number of orbitals on each site in
the site array, and the offset of the first orbital of the first
site in the site array. In addition, the final row has the form
``(len(graph.num_nodes), 0, tot_norbs)`` where ``tot_norbs`` is the
total number of orbitals in the system. ``None`` if any site array
in 'site_arrays' does not have 'norbs' specified. Note 'site_ranges'
is directly computable from 'site_arrays'.
parameters : frozenset of strings
The names of the parameters on which the system depends. This attribute
is provisional and may be changed in a future version of Kwant
Notes
-----
The sites of the system are indexed by integers ranging from 0 to
``self.graph.num_nodes - 1``.
Optionally, a class derived from ``System`` can provide a method ``pos`` which
is assumed to return the real-space position of a site given its index.
"""
@abc.abstractmethod
def hamiltonian_term(self, term_number, selector=slice(None),
args=(), params=None):
"""Return the Hamiltonians for hamiltonian term number k.
# Add a C-implemented function as an unbound method to class System.
System.hamiltonian_submatrix = _system.hamiltonian_submatrix
Parameters
----------
term_number : int
The number of the term to evaluate.
selector : slice or sequence of int, default: slice(None)
The elements of the term to evaluate.
args : tuple
Positional arguments to the term. (Deprecated)
params : dict
Keyword parameters to the term
Returns
-------
hamiltonian : 3d complex array
Has shape ``(N, P, Q)`` where ``N`` is the number of matrix
elements in this term (or the number selected by 'selector'
if provided), ``P`` and ``Q`` are the number of orbitals in the
'to' and 'from' site arrays associated with this term.
class FiniteSystem(System, metaclass=abc.ABCMeta):
Providing positional arguments via 'args' is deprecated,
instead, provide named parameters as a dictionary via 'params'.
"""
@property
@lru_cache(1)
def site_ranges(self):
site_offsets = np.cumsum([0] + [len(arr) for arr in self.site_arrays])
norbs = [arr.family.norbs for arr in self.site_arrays] + [0]
if any(norb is None for norb in norbs):
return None
orb_offsets = np.cumsum(
[0] + [len(arr) * arr.family.norbs for arr in self.site_arrays]
)
return np.array([site_offsets, norbs, orb_offsets]).transpose()
hamiltonian_submatrix = _system.vectorized_hamiltonian_submatrix
class FiniteSystemMixin(metaclass=abc.ABCMeta):
"""Abstract finite low-level system, possibly with leads.
Attributes
......@@ -220,7 +564,19 @@ class FiniteSystem(System, metaclass=abc.ABCMeta):
return symmetries.validate(ham)
class InfiniteSystem(System, metaclass=abc.ABCMeta):
class FiniteSystem(System, FiniteSystemMixin, metaclass=abc.ABCMeta):
pass
class FiniteVectorizedSystem(VectorizedSystem, FiniteSystemMixin, metaclass=abc.ABCMeta):
pass
def is_finite(syst):
return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem))
class InfiniteSystemMixin(metaclass=abc.ABCMeta):
"""Abstract infinite low-level system.
An infinite system consists of an infinite series of identical cells.
......@@ -261,30 +617,10 @@ class InfiniteSystem(System, metaclass=abc.ABCMeta):
infinite system. The other scheme has the numbers of site 0 and 1
exchanged, as well as of site 3 and 4.
Sites in the fundamental domain cell must belong to a different site array
than the sites in the previous cell. In the above example this means that
sites '(0, 1, 2)' and '(3, 4)' must belong to different site arrays.
"""
@deprecate_args
def cell_hamiltonian(self, args=(), sparse=False, *, params=None):
"""Hamiltonian of a single cell of the infinite system.
Providing positional arguments via 'args' is deprecated,
instead, provide named parameters as a dictionary via 'params'.
"""
cell_sites = range(self.cell_size)
return self.hamiltonian_submatrix(args, cell_sites, cell_sites,
sparse=sparse, params=params)
@deprecate_args
def inter_cell_hopping(self, args=(), sparse=False, *, params=None):
"""Hopping Hamiltonian between two cells of the infinite system.
Providing positional arguments via 'args' is deprecated,
instead, provide named parameters as a dictionary via 'params'.
"""
cell_sites = range(self.cell_size)
interface_sites = range(self.cell_size, self.graph.num_nodes)
return self.hamiltonian_submatrix(args, cell_sites, interface_sites,
sparse=sparse, params=params)
@deprecate_args
def modes(self, energy=0, args=(), *, params=None):
"""Return mode decomposition of the lead
......@@ -368,6 +704,41 @@ class InfiniteSystem(System, metaclass=abc.ABCMeta):
return list(broken)
class InfiniteSystem(System, InfiniteSystemMixin, metaclass=abc.ABCMeta):
@deprecate_args
def cell_hamiltonian(self, args=(), sparse=False, *, params=None):
"""Hamiltonian of a single cell of the infinite system.
Providing positional arguments via 'args' is deprecated,
instead, provide named parameters as a dictionary via 'params'.
"""
cell_sites = range(self.cell_size)
return self.hamiltonian_submatrix(args, cell_sites, cell_sites,
sparse=sparse, params=params)
@deprecate_args
def inter_cell_hopping(self, args=(), sparse=False, *, params=None):
"""Hopping Hamiltonian between two cells of the infinite system.
Providing positional arguments via 'args' is deprecated,
instead, provide named parameters as a dictionary via 'params'.
"""
cell_sites = range(self.cell_size)
interface_sites = range(self.cell_size, self.graph.num_nodes)
return self.hamiltonian_submatrix(args, cell_sites, interface_sites,
sparse=sparse, params=params)
class InfiniteVectorizedSystem(VectorizedSystem, InfiniteSystemMixin, metaclass=abc.ABCMeta):
cell_hamiltonian = _system.vectorized_cell_hamiltonian
inter_cell_hopping = _system.vectorized_inter_cell_hopping
def is_infinite(syst):
return isinstance(syst, (InfiniteSystem, InfiniteVectorizedSystem))
class PrecalculatedLead:
def __init__(self, modes=None, selfenergy=None):
"""A general lead defined by its self energy.
......
# Copyright 2011-2018 Kwant authors.
# Copyright 2011-2019 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -19,8 +19,8 @@ from pytest import raises, warns
from numpy.testing import assert_almost_equal
import kwant
from kwant import builder
from kwant._common import ensure_rng
from kwant import builder, system
from kwant._common import ensure_rng, KwantDeprecationWarning
def test_bad_keys():
......@@ -28,7 +28,7 @@ def test_bad_keys():
def setitem(key):
syst[key] = None
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
syst = builder.Builder()
failures = [
......@@ -97,9 +97,9 @@ def test_bad_keys():
def test_site_families():
syst = builder.Builder()
fam = builder.SimpleSiteFamily()
ofam = builder.SimpleSiteFamily()
yafam = builder.SimpleSiteFamily('another_name')
fam = builder.SimpleSiteFamily(norbs=1)
ofam = builder.SimpleSiteFamily(norbs=1)
yafam = builder.SimpleSiteFamily('another_name', norbs=1)
syst[fam(0)] = 7
assert syst[fam(0)] == 7
......@@ -162,7 +162,7 @@ class VerySimpleSymmetry(builder.Symmetry):
# made.
def check_construction_and_indexing(sites, sites_fd, hoppings, hoppings_fd,
unknown_hoppings, sym=None):
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
syst = builder.Builder(sym)
t, V = 1.0j, 0.0
syst[sites] = V
......@@ -212,7 +212,7 @@ def check_construction_and_indexing(sites, sites_fd, hoppings, hoppings_fd,
def test_construction_and_indexing():
# Without symmetry
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
sites = [fam(0, 0), fam(0, 1), fam(1, 0)]
hoppings = [(fam(0, 0), fam(0, 1)),
(fam(0, 1), fam(1, 0)),
......@@ -250,7 +250,7 @@ def test_hermitian_conjugation():
raise ValueError
syst = builder.Builder()
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
syst[fam(0)] = syst[fam(1)] = ta.identity(2)
syst[fam(0), fam(1)] = f
......@@ -266,7 +266,7 @@ def test_hermitian_conjugation():
def test_value_equality_and_identity():
m = ta.array([[1, 2], [3j, 4j]])
syst = builder.Builder()
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
syst[fam(0)] = m
syst[fam(1)] = m
......@@ -290,12 +290,28 @@ def random_hopping_integral(rng):
def check_onsite(fsyst, sites, subset=False, check_values=True):
vectorized = isinstance(fsyst, (system.FiniteVectorizedSystem, system.InfiniteVectorizedSystem))
if vectorized:
site_offsets = np.cumsum([0] + [len(s) for s in fsyst.site_arrays])
freq = {}
for node in range(fsyst.graph.num_nodes):
site = fsyst.sites[node].tag
freq[site] = freq.get(site, 0) + 1
if check_values and site in sites:
assert fsyst.onsites[node][0] is sites[site]
if vectorized:
term = fsyst._onsite_term_by_site_id[node]
value = fsyst._term_values[term]
if callable(value):
assert value is sites[site]
else:
(w, _), (off, _) = fsyst.subgraphs[fsyst.terms[term].subgraph]
node_off = node - site_offsets[w]
selector = np.searchsorted(off, node_off)
assert np.allclose(value[selector], sites[site])
else:
assert fsyst.onsites[node][0] is sites[site]
if not subset:
# Check that all sites of `fsyst` are in `sites`.
for site in freq.keys():
......@@ -306,24 +322,50 @@ def check_onsite(fsyst, sites, subset=False, check_values=True):
def check_hoppings(fsyst, hops):
vectorized = isinstance(fsyst, (system.FiniteVectorizedSystem, system.InfiniteVectorizedSystem))
if vectorized:
site_offsets = np.cumsum([0] + [len(s) for s in fsyst.site_arrays])
assert fsyst.graph.num_edges == 2 * len(hops)
for edge_id, edge in enumerate(fsyst.graph):
tail, head = edge
tail = fsyst.sites[tail].tag
head = fsyst.sites[head].tag
value = fsyst.hoppings[edge_id][0]
if value is builder.Other:
assert (head, tail) in hops
i, j = edge
tail = fsyst.sites[i].tag
head = fsyst.sites[j].tag
if vectorized:
term = fsyst._hopping_term_by_edge_id[edge_id]
if term < 0: # Hermitian conjugate
assert (head, tail) in hops
else:
value = fsyst._term_values[term]
assert (tail, head) in hops
if callable(value):
assert value is hops[tail, head]
else:
dtype = np.dtype([('f0', int), ('f1', int)])
subgraph = fsyst.terms[term].subgraph
(to_w, from_w), hoppings = fsyst.subgraphs[subgraph]
hop = (i - site_offsets[to_w], j - site_offsets[from_w])
hop = np.array(hop, dtype=dtype)
hoppings = hoppings.transpose().view(dtype).reshape(-1)
selector = np.recarray.searchsorted(hoppings, hop)
assert np.allclose(value[selector], hops[tail, head])
else:
assert (tail, head) in hops
assert value is hops[tail, head]
value = fsyst.hoppings[edge_id][0]
if value is builder.Other:
assert (head, tail) in hops
else:
assert (tail, head) in hops
assert value is hops[tail, head]
def check_id_by_site(fsyst):
for i, site in enumerate(fsyst.sites):
assert fsyst.id_by_site[site] == i
def test_finalization():
@pytest.mark.parametrize("vectorize", [False, True])
def test_finalization(vectorize):
"""Test the finalization of finite and infinite systems.
In order to exactly verify the finalization, low-level features of the
......@@ -377,8 +419,8 @@ def test_finalization():
neighbors = sorted(neighbors)
# Build scattering region from blueprint and test it.
syst = builder.Builder()
fam = kwant.lattice.general(ta.identity(2))
syst = builder.Builder(vectorize=vectorize)
fam = kwant.lattice.general(ta.identity(2), norbs=1)
for site, value in sr_sites.items():
syst[fam(*site)] = value
for hop, value in sr_hops.items():
......@@ -388,7 +430,7 @@ def test_finalization():
check_onsite(fsyst, sr_sites)
check_hoppings(fsyst, sr_hops)
# check that sites are sorted
assert fsyst.sites == tuple(sorted(fam(*site) for site in sr_sites))
assert tuple(fsyst.sites) == tuple(sorted(fam(*site) for site in sr_sites))
# Build lead from blueprint and test it.
lead = builder.Builder(kwant.TranslationalSymmetry((size, 0)))
......@@ -421,12 +463,12 @@ def test_finalization():
# Attach lead with improper interface.
syst.leads[-1] = builder.BuilderLead(
lead, 2 * tuple(builder.Site(fam, n) for n in neighbors))
lead, 2 * tuple(system.Site(fam, n) for n in neighbors))
raises(ValueError, syst.finalized)
# Attach lead properly.
syst.leads[-1] = builder.BuilderLead(
lead, (builder.Site(fam, n) for n in neighbors))
lead, (system.Site(fam, n) for n in neighbors))
fsyst = syst.finalized()
assert len(fsyst.lead_interfaces) == 1
assert ([fsyst.sites[i].tag for i in fsyst.lead_interfaces[0]] ==
......@@ -434,15 +476,15 @@ def test_finalization():
# test that we cannot finalize a system with a badly sorted interface order
raises(ValueError, builder.InfiniteSystem, lead,
[builder.Site(fam, n) for n in reversed(neighbors)])
[system.Site(fam, n) for n in reversed(neighbors)])
# site ordering independent of whether interface was specified
flead_order = builder.InfiniteSystem(lead, [builder.Site(fam, n)
flead_order = builder.InfiniteSystem(lead, [system.Site(fam, n)
for n in neighbors])
assert flead.sites == flead_order.sites
assert tuple(flead.sites) == tuple(flead_order.sites)
syst.leads[-1] = builder.BuilderLead(
lead, (builder.Site(fam, n) for n in neighbors))
lead, (system.Site(fam, n) for n in neighbors))
fsyst = syst.finalized()
assert len(fsyst.lead_interfaces) == 1
assert ([fsyst.sites[i].tag for i in fsyst.lead_interfaces[0]] ==
......@@ -457,44 +499,62 @@ def test_finalization():
raises(ValueError, lead.finalized)
def test_site_ranges():
def _make_system_from_sites(sites, vectorize):
syst = builder.Builder(vectorize=vectorize)
for s in sites:
norbs = s.family.norbs
if norbs:
syst[s] = np.eye(s.family.norbs)
else:
syst[s] = None
return syst.finalized()
@pytest.mark.parametrize("vectorize", [False, True])
def test_site_ranges(vectorize):
lat1a = kwant.lattice.chain(norbs=1, name='a')
lat1b = kwant.lattice.chain(norbs=1, name='b')
lat2 = kwant.lattice.chain(norbs=2)
site_ranges = builder._site_ranges
# simple case -- single site family
for lat in (lat1a, lat2):
sites = list(map(lat, range(10)))
ranges = site_ranges(sites)
syst = _make_system_from_sites(sites, vectorize)
ranges = syst.site_ranges
expected = [(0, lat.norbs, 0), (10, 0, 10 * lat.norbs)]
assert ranges == expected
assert np.array_equal(ranges, expected)
# pair of site families
sites = it.chain(map(lat1a, range(4)), map(lat1b, range(6)),
map(lat1a, range(4)))
expected = [(0, 1, 0), (4, 1, 4), (10, 1, 10), (14, 0, 14)]
assert expected == site_ranges(tuple(sites))
sites = it.chain(map(lat1a, range(4)), map(lat1b, range(6)))
syst = _make_system_from_sites(sites, vectorize)
expected = [(0, 1, 0), (4, 1, 4), (10, 0, 10)]
assert np.array_equal(expected, syst.site_ranges)
sites = it.chain(map(lat2, range(4)), map(lat1a, range(6)),
map(lat1b, range(4)))
syst = _make_system_from_sites(sites, vectorize)
expected = [(0, 2, 0), (4, 1, 4*2), (10, 1, 4*2+6), (14, 0, 4*2+10)]
assert expected == site_ranges(tuple(sites))
assert np.array_equal(expected, syst.site_ranges)
# test with an actual builder
for lat in (lat1a, lat2):
sites = list(map(lat, range(10)))
syst = kwant.Builder()
syst = kwant.Builder(vectorize=vectorize)
syst[sites] = np.eye(lat.norbs)
ranges = syst.finalized().site_ranges
expected = [(0, lat.norbs, 0), (10, 0, 10 * lat.norbs)]
assert ranges == expected
# poison system with a single site with no norbs defined
syst[kwant.lattice.chain()(0)] = 1
ranges = syst.finalized().site_ranges
assert ranges == None
def test_hamiltonian_evaluation():
assert np.array_equal(ranges, expected)
if not vectorize: # vectorized systems *must* have all norbs
# poison system with a single site with no norbs defined.
# Also catch the deprecation warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore")
syst[kwant.lattice.chain(norbs=None)(0)] = 1
ranges = syst.finalized().site_ranges
assert ranges is None
@pytest.mark.parametrize("vectorize", [False, True])
def test_hamiltonian_evaluation(vectorize):
def f_onsite(site):
return site.tag[0]
......@@ -502,14 +562,26 @@ def test_hamiltonian_evaluation():
a, b = a.tag, b.tag
return complex(a[0] + b[0], a[1] - b[1])
def f_onsite_vectorized(sites):
return sites.tags[:, 0]
def f_hopping_vectorized(a, b):
a, b = a.tags, b.tags
return a[:, 0] + b[:, 0] + 1j * (a[:, 1] - b[:, 1])
tags = [(0, 0), (1, 1), (2, 2), (3, 3)]
edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
syst = builder.Builder()
fam = builder.SimpleSiteFamily()
syst = builder.Builder(vectorize=vectorize)
fam = builder.SimpleSiteFamily(norbs=1)
sites = [fam(*tag) for tag in tags]
syst[(fam(*tag) for tag in tags)] = f_onsite
syst[((fam(*tags[i]), fam(*tags[j])) for (i, j) in edges)] = f_hopping
hoppings = [(sites[i], sites[j]) for i, j in edges]
if vectorize:
syst[sites] = f_onsite_vectorized
syst[hoppings] = f_hopping_vectorized
else:
syst[sites] = f_onsite
syst[hoppings] = f_hopping
fsyst = syst.finalized()
assert fsyst.graph.num_nodes == len(tags)
......@@ -518,12 +590,16 @@ def test_hamiltonian_evaluation():
for i in range(len(tags)):
site = fsyst.sites[i]
assert site in sites
assert fsyst.hamiltonian(i, i) == syst[site](site)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
assert fsyst.hamiltonian(i, i) == f_onsite(site)
for t, h in fsyst.graph:
tsite = fsyst.sites[t]
hsite = fsyst.sites[h]
assert fsyst.hamiltonian(t, h) == syst[tsite, hsite](tsite, hsite)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
assert fsyst.hamiltonian(t, h) == f_hopping(tsite, hsite)
# test when user-function raises errors
def onsite_raises(site):
......@@ -536,13 +612,18 @@ def test_hamiltonian_evaluation():
a, b = hop
# exceptions are converted to kwant.UserCodeError and we add our message
with raises(kwant.UserCodeError) as exc_info:
fsyst.hamiltonian(a, a)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
fsyst.hamiltonian(a, a)
msg = 'Error occurred in user-supplied value function "onsite_raises"'
assert msg in str(exc_info.value)
for hop in [(a, b), (b, a)]:
with raises(kwant.UserCodeError) as exc_info:
fsyst.hamiltonian(*hop)
with warnings.catch_warnings():
# Ignore deprecation warnings for 'hamiltonian'
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
fsyst.hamiltonian(*hop)
msg = ('Error occurred in user-supplied '
'value function "hopping_raises"')
assert msg in str(exc_info.value)
......@@ -552,7 +633,7 @@ def test_hamiltonian_evaluation():
syst[new_hop[0]] = onsite_raises
syst[new_hop] = hopping_raises
fsyst = syst.finalized()
hop = tuple(map(fsyst.sites.index, new_hop))
hop = tuple(map(fsyst.id_by_site.__getitem__, new_hop))
test_raising(fsyst, hop)
# test with infinite system
......@@ -560,17 +641,125 @@ def test_hamiltonian_evaluation():
for k, v in it.chain(syst.site_value_pairs(), syst.hopping_value_pairs()):
inf_syst[k] = v
inf_fsyst = inf_syst.finalized()
hop = tuple(map(inf_fsyst.sites.index, new_hop))
hop = tuple(map(inf_fsyst.id_by_site.__getitem__, new_hop))
test_raising(inf_fsyst, hop)
def test_vectorized_hamiltonian_evaluation():
def onsite(site):
return site.tag[0]
def vectorized_onsite(sites):
return sites.tags[:, 0]
def hopping(to_site, from_site):
a, b = to_site.tag, from_site.tag
return a[0] + b[0] + 1j * (a[1] - b[1])
def vectorized_hopping(to_sites, from_sites):
a, b = to_sites.tags, from_sites.tags
return a[:, 0] + b[:, 0] + 1j * (a[:, 1] - b[:, 1])
tags = [(0, 0), (1, 1), (2, 2), (3, 3)]
edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
fam = builder.SimpleSiteFamily(norbs=1)
sites = [fam(*tag) for tag in tags]
hops = [(fam(*tags[i]), fam(*tags[j])) for (i, j) in edges]
syst_simple = builder.Builder(vectorize=False)
syst_simple[sites] = onsite
syst_simple[hops] = hopping
fsyst_simple = syst_simple.finalized()
syst_vectorized = builder.Builder(vectorize=True)
syst_vectorized[sites] = vectorized_onsite
syst_vectorized[hops] = vectorized_hopping
fsyst_vectorized = syst_vectorized.finalized()
assert fsyst_vectorized.graph.num_nodes == len(tags)
assert fsyst_vectorized.graph.num_edges == 2 * len(edges)
assert len(fsyst_vectorized.site_arrays) == 1
assert fsyst_vectorized.site_arrays[0] == system.SiteArray(fam, tags)
assert np.allclose(
fsyst_simple.hamiltonian_submatrix(),
fsyst_vectorized.hamiltonian_submatrix(),
)
for i in range(len(tags)):
site = fsyst_vectorized.sites[i]
assert site in sites
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
assert (
fsyst_vectorized.hamiltonian(i, i)
== fsyst_simple.hamiltonian(i, i))
for t, h in fsyst_vectorized.graph:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
assert (
fsyst_vectorized.hamiltonian(t, h)
== fsyst_simple.hamiltonian(t, h))
# Test infinite system, including hoppings that go both ways into
# the next cell
lat = kwant.lattice.square(norbs=1)
syst_vectorized = builder.Builder(kwant.TranslationalSymmetry((-1, 0)),
vectorize=True)
syst_vectorized[lat(0, 0)] = 4
syst_vectorized[lat(0, 1)] = 5
syst_vectorized[lat(0, 2)] = vectorized_onsite
syst_vectorized[(lat(1, 0), lat(0, 0))] = 1j
syst_vectorized[(lat(2, 1), lat(1, 1))] = vectorized_hopping
fsyst_vectorized = syst_vectorized.finalized()
syst_simple = builder.Builder(kwant.TranslationalSymmetry((-1, 0)),
vectorize=False)
syst_simple[lat(0, 0)] = 4
syst_simple[lat(0, 1)] = 5
syst_simple[lat(0, 2)] = onsite
syst_simple[(lat(1, 0), lat(0, 0))] = 1j
syst_simple[(lat(2, 1), lat(1, 1))] = hopping
fsyst_simple = syst_simple.finalized()
assert np.allclose(
fsyst_vectorized.hamiltonian_submatrix(),
fsyst_simple.hamiltonian_submatrix(),
)
assert np.allclose(
fsyst_vectorized.cell_hamiltonian(),
fsyst_simple.cell_hamiltonian(),
)
assert np.allclose(
fsyst_vectorized.inter_cell_hopping(),
fsyst_simple.inter_cell_hopping(),
)
def test_vectorized_requires_norbs():
# Catch deprecation warning for lack of norbs
with warnings.catch_warnings():
warnings.simplefilter("ignore")
fam = builder.SimpleSiteFamily()
syst = builder.Builder(vectorize=True)
syst[fam(0, 0)] = 1
raises(ValueError, syst.finalized)
def test_dangling():
def make_system():
# 1
# / \
# 3-0---2-4-5 6-7 8
syst = builder.Builder()
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
syst[(fam(i) for i in range(9))] = None
syst[[(fam(0), fam(1)), (fam(1), fam(2)), (fam(2), fam(0))]] = None
syst[[(fam(0), fam(3)), (fam(2), fam(4)), (fam(4), fam(5))]] = None
......@@ -596,7 +785,7 @@ def test_dangling():
def test_builder_with_symmetry():
g = kwant.lattice.general(ta.identity(3))
g = kwant.lattice.general(ta.identity(3), norbs=1)
sym = kwant.TranslationalSymmetry((0, 0, 3), (0, 2, 0))
syst = builder.Builder(sym)
......@@ -642,7 +831,7 @@ def test_builder_with_symmetry():
def test_fill():
g = kwant.lattice.square()
g = kwant.lattice.square(norbs=1)
sym_x = kwant.TranslationalSymmetry((-1, 0))
sym_xy = kwant.TranslationalSymmetry((-1, 0), (0, 1))
......@@ -658,7 +847,7 @@ def test_fill():
lambda pos: True),
(builder.NoSymmetry(),
lambda pos: ta.dot(pos, pos) < 17)]:
cubic = kwant.lattice.general(ta.identity(3))
cubic = kwant.lattice.general(ta.identity(3), norbs=1)
# Make a weird system.
orig = kwant.Builder(sym)
......@@ -769,7 +958,7 @@ def test_fill():
assert sorted(target.hoppings()) == sorted(should_be_syst.hoppings())
## test that 'fill' respects the symmetry of the target builder
lat = kwant.lattice.chain(a=1)
lat = kwant.lattice.chain(a=1, norbs=1)
template = builder.Builder(kwant.TranslationalSymmetry((-1,)))
template[lat(0)] = 2
template[lat.neighbors()] = -1
......@@ -796,7 +985,7 @@ def test_fill():
target.fill(template, lambda x: True, lat(0))
# Test for warning when one of the starting sites is outside the template
lat = kwant.lattice.square()
lat = kwant.lattice.square(norbs=1)
template = builder.Builder(kwant.TranslationalSymmetry((-1, 0)))
template[lat(0, 0)] = None
template[lat.neighbors()] = None
......@@ -811,7 +1000,7 @@ def test_fill_sticky():
separately.
"""
# Generate model template.
lat = kwant.lattice.kagome()
lat = kwant.lattice.kagome(norbs=1)
template = kwant.Builder(kwant.TranslationalSymmetry(
lat.vec((1, 0)), lat.vec((0, 1))))
for i, sl in enumerate(lat.sublattices):
......@@ -844,7 +1033,7 @@ def test_fill_sticky():
def test_attach_lead():
fam = builder.SimpleSiteFamily(norbs=1)
fam_noncommensurate = builder.SimpleSiteFamily(name='other')
fam_noncommensurate = builder.SimpleSiteFamily(name='other', norbs=1)
syst = builder.Builder()
syst[fam(1)] = 0
......@@ -901,7 +1090,7 @@ def test_attach_lead():
def test_attach_lead_incomplete_unit_cell():
lat = kwant.lattice.chain()
lat = kwant.lattice.chain(norbs=1)
syst = kwant.Builder()
lead = kwant.Builder(kwant.TranslationalSymmetry((2,)))
syst[lat(1)] = lead[lat(0)] = lead[lat(1)] = 0
......@@ -912,7 +1101,7 @@ def test_attach_lead_incomplete_unit_cell():
def test_neighbors_not_in_single_domain():
sr = builder.Builder()
lead = builder.Builder(VerySimpleSymmetry(-1))
fam = builder.SimpleSiteFamily()
fam = builder.SimpleSiteFamily(norbs=1)
sr[(fam(x, y) for x in range(3) for y in range(3) if x >= y)] = 0
sr[builder.HoppingKind((1, 0), fam)] = 1
sr[builder.HoppingKind((0, 1), fam)] = 1
......@@ -935,7 +1124,7 @@ def test_closest():
rng = ensure_rng(10)
for sym_dim in range(1, 4):
for space_dim in range(sym_dim, 4):
lat = kwant.lattice.general(ta.identity(space_dim))
lat = kwant.lattice.general(ta.identity(space_dim), norbs=1)
# Choose random periods.
while True:
......@@ -967,7 +1156,7 @@ def test_closest():
def test_update():
lat = builder.SimpleSiteFamily()
lat = builder.SimpleSiteFamily(norbs=1)
syst = builder.Builder()
syst[[lat(0,), lat(1,)]] = 1
......@@ -1006,8 +1195,8 @@ def test_update():
# ghgh hhgh
#
def test_HoppingKind():
g = kwant.lattice.general(ta.identity(3), name='some_lattice')
h = kwant.lattice.general(ta.identity(3), name='another_lattice')
g = kwant.lattice.general(ta.identity(3), name='some_lattice', norbs=1)
h = kwant.lattice.general(ta.identity(3), name='another_lattice', norbs=1)
sym = kwant.TranslationalSymmetry((0, 2, 0))
syst = builder.Builder(sym)
syst[((h if max(x, y, z) % 2 else g)(x, y, z)
......@@ -1047,8 +1236,8 @@ def test_HoppingKind():
def test_invalid_HoppingKind():
g = kwant.lattice.general(ta.identity(3))
h = kwant.lattice.general(np.identity(3)[:-1]) # 2D lattice in 3D
g = kwant.lattice.general(ta.identity(3), norbs=1)
h = kwant.lattice.general(np.identity(3)[:-1], norbs=1) # 2D lattice in 3D
delta = (1, 0, 0)
......@@ -1062,7 +1251,7 @@ def test_invalid_HoppingKind():
def test_ModesLead_and_SelfEnergyLead():
lat = builder.SimpleSiteFamily()
lat = builder.SimpleSiteFamily(norbs=1)
hoppings = [builder.HoppingKind((1, 0), lat),
builder.HoppingKind((0, 1), lat)]
rng = Random(123)
......@@ -1139,7 +1328,7 @@ def test_ModesLead_and_SelfEnergyLead():
def test_site_pickle():
site = kwant.lattice.square()(0, 0)
site = kwant.lattice.square(norbs=1)(0, 0)
assert pickle.loads(pickle.dumps(site)) == site
......@@ -1201,15 +1390,22 @@ def test_discrete_symmetries():
# We need to keep testing 'args', but we don't want to see
# all the deprecation warnings in the test logs
@pytest.mark.filterwarnings("ignore:.*'args' parameter")
def test_argument_passing():
chain = kwant.lattice.chain()
@pytest.mark.parametrize("vectorize", [False, True])
def test_argument_passing(vectorize):
chain = kwant.lattice.chain(norbs=1)
# Test for passing parameters to hamiltonian matrix elements
def onsite(site, p1, p2):
return p1 + p2
if vectorize:
return (p1 + p2) * np.ones(len(site))
else:
return p1 + p2
def hopping(site1, site2, p1, p2):
return p1 - p2
if vectorize:
return (p1 - p2) * np.ones(len(site1))
else:
return p1 - p2
def gen_fill_syst(onsite, hopping, syst):
syst[(chain(i) for i in range(3))] = onsite
......@@ -1218,8 +1414,9 @@ def test_argument_passing():
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
syst = fill_syst(kwant.Builder())
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
syst = fill_syst(kwant.Builder(vectorize=vectorize))
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,)),
vectorize=vectorize))
tests = (
syst.hamiltonian_submatrix,
......@@ -1235,28 +1432,34 @@ def test_argument_passing():
# test that mixing 'args' and 'params' raises TypeError
with raises(TypeError):
syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
with raises(TypeError):
inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
inf_syst.hamiltonian(0, 0, *(2, 1), params=dict(p1=2, p2=1))
# Missing parameters raises TypeError
with raises(TypeError):
syst.hamiltonian(0, 0, params=dict(p1=2))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
syst.hamiltonian(0, 0, params=dict(p1=2))
with raises(TypeError):
syst.hamiltonian_submatrix(params=dict(p1=2))
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=KwantDeprecationWarning)
syst.hamiltonian_submatrix(params=dict(p1=2))
# test that passing parameters without default values works, and that
# passing parameters with default values fails
def onsite(site, p1, p2):
return p1 + p2
def hopping(site, site2, p1, p2):
return p1 - p2
fill_syst = ft.partial(gen_fill_syst, onsite, hopping)
syst = fill_syst(kwant.Builder())
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,))))
syst = fill_syst(kwant.Builder(vectorize=vectorize))
inf_syst = fill_syst(kwant.Builder(kwant.TranslationalSymmetry((-3,)),
vectorize=vectorize))
tests = (
syst.hamiltonian_submatrix,
......@@ -1272,12 +1475,18 @@ def test_argument_passing():
# Some common, some different args for value functions
def onsite2(site, a, b):
return site.pos + a + b
if vectorize:
return site.positions()[:, 0] + a + b
else:
return site.pos[0] + a + b
def hopping2(site1, site2, a, c, b):
return a + b + c
if vectorize:
return (a + b + c) * np.ones(len(site1))
else:
return a + b + c
syst = kwant.Builder()
syst = kwant.Builder(vectorize=vectorize)
syst[(chain(i) for i in range(3))] = onsite2
syst[((chain(i), chain(i + 1)) for i in range(2))] = hopping2
fsyst = syst.finalized()
......@@ -1336,7 +1545,7 @@ def test_subs():
salt = str(b) + str(c)
return kwant.digest.uniform(ta.array((sitea.tag, siteb.tag)), salt=salt)
lat = kwant.lattice.chain()
lat = kwant.lattice.chain(norbs=1)
def make_system(sym=kwant.builder.NoSymmetry(), n=3):
syst = kwant.Builder(sym)
......@@ -1388,8 +1597,9 @@ def test_subs():
lead = lead.finalized()
assert lead.parameters == {'lead_a', 'lead_b', 'lead_c'}
def test_attach_stores_padding():
lat = kwant.lattice.chain()
lat = kwant.lattice.chain(norbs=1)
syst = kwant.Builder()
syst[lat(0)] = 0
lead = kwant.Builder(kwant.TranslationalSymmetry(lat.prim_vecs[0]))
......@@ -1400,7 +1610,7 @@ def test_attach_stores_padding():
def test_finalization_preserves_padding():
lat = kwant.lattice.chain()
lat = kwant.lattice.chain(norbs=1)
syst = kwant.Builder()
for i in range(10):
syst[lat(i)] = 0
......@@ -1416,3 +1626,28 @@ def test_finalization_preserves_padding():
syst = syst.finalized()
# The order is guaranteed because the paddings are sorted.
assert [syst.sites[i] for i in syst.lead_paddings[0]] == padding[:-1]
def test_add_peierls_phase():
lat = kwant.lattice.square(norbs=1)
syst = kwant.Builder()
syst[(lat(i, j) for i in range(5) for j in range(5))] = 4
syst[lat.neighbors()] = lambda a, b, t: -t
lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
lead[(lat(0, j) for j in range(5))] = 4
lead[lat.neighbors()] = -1
syst.attach_lead(lead)
syst.attach_lead(lead.reversed())
syst, phase = builder.add_peierls_phase(syst)
assert isinstance(syst, builder.FiniteSystem)
params = phase(1, 0, 0)
assert all(p in params for p in ('phi', 'phi_lead0', 'phi_lead1'))
kwant.smatrix(syst, energy=0.1, params=dict(t=-1, **params))
......@@ -18,7 +18,7 @@ def test_qhe(W=16, L=8):
x, y = pos
return -L < x < L and abs(y) < W - 5.5 * math.exp(-x**2 / 5**2)
lat = kwant.lattice.square()
lat = kwant.lattice.square(norbs=1)
syst = kwant.Builder()
syst[lat.shape(central_region, (0, 0))] = onsite
......
......@@ -6,6 +6,7 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import warnings
from math import sqrt
import numpy as np
import tinyarray as ta
......@@ -17,40 +18,42 @@ import pytest
def test_closest():
rng = ensure_rng(4)
lat = lattice.general(((1, 0), (0.5, sqrt(3)/2)))
lat = lattice.general(((1, 0), (0.5, sqrt(3)/2)), norbs=1)
for i in range(50):
point = 20 * rng.random_sample(2)
closest = lat(*lat.closest(point)).pos
assert np.linalg.norm(point - closest) <= 1 / sqrt(3)
lat = lattice.general(rng.randn(3, 3))
lat = lattice.general(rng.randn(3, 3), norbs=1)
for i in range(50):
tag = rng.randint(10, size=(3,))
assert lat.closest(lat(*tag).pos) == tag
def test_general():
for lat in (lattice.general(((1, 0), (0.5, 0.5))),
for lat in (lattice.general(((1, 0), (0.5, 0.5)), norbs=1),
lattice.general(((1, 0), (0.5, sqrt(3)/2)),
((0, 0), (0, 1/sqrt(3))))):
((0, 0), (0, 1/sqrt(3))),
norbs=1,
)):
for sl in lat.sublattices:
tag = (-5, 33)
site = sl(*tag)
assert tag == sl.closest(site.pos)
# Test 2D lattice with 1 vector.
lat = lattice.general([[1, 0]])
lat = lattice.general([[1, 0]], norbs=1)
site = lat(0)
raises(ValueError, lat, 0, 1)
def test_neighbors():
lat = lattice.honeycomb(1e-10)
lat = lattice.honeycomb(1e-10, norbs=1)
num_nth_nearest = [len(lat.neighbors(n)) for n in range(5)]
assert num_nth_nearest == [2, 3, 6, 3, 6]
lat = lattice.general([(0, 1e8, 0, 0), (0, 0, 1e8, 0)])
lat = lattice.general([(0, 1e8, 0, 0), (0, 0, 1e8, 0)], norbs=1)
num_nth_nearest = [len(lat.neighbors(n)) for n in range(5)]
assert num_nth_nearest == [1, 2, 2, 2, 4]
lat = lattice.chain(1e-10)
lat = lattice.chain(1e-10, norbs=1)
num_nth_nearest = [len(lat.neighbors(n)) for n in range(5)]
assert num_nth_nearest == 5 * [1]
......@@ -59,7 +62,7 @@ def test_shape():
def in_circle(pos):
return pos[0] ** 2 + pos[1] ** 2 < 3
lat = lattice.honeycomb()
lat = lattice.honeycomb(norbs=1)
sites = list(lat.shape(in_circle, (0, 0))())
sites_alt = list()
sl0, sl1 = lat.sublattices
......@@ -88,7 +91,7 @@ def test_wire():
vecs = rng.randn(3, 3)
vecs[0] = [1, 0, 0]
center = rng.randn(3)
lat = lattice.general(vecs, rng.randn(4, 3))
lat = lattice.general(vecs, rng.randn(4, 3), norbs=1)
syst = builder.Builder(lattice.TranslationalSymmetry((2, 0, 0)))
def wire_shape(pos):
pos = np.array(pos)
......@@ -103,8 +106,8 @@ def test_wire():
def test_translational_symmetry():
ts = lattice.TranslationalSymmetry
f2 = lattice.general(np.identity(2))
f3 = lattice.general(np.identity(3))
f2 = lattice.general(np.identity(2), norbs=1)
f3 = lattice.general(np.identity(3), norbs=1)
shifted = lambda site, delta: site.family(*ta.add(site.tag, delta))
raises(ValueError, ts, (0, 0, 4), (0, 5, 0), (0, 0, 2))
......@@ -112,7 +115,7 @@ def test_translational_symmetry():
raises(ValueError, sym.add_site_family, f2)
# Test lattices with dimension smaller than dimension of space.
f2in3 = lattice.general([[4, 4, 0], [4, -4, 0]])
f2in3 = lattice.general([[4, 4, 0], [4, -4, 0]], norbs=1)
sym = ts((8, 0, 0))
sym.add_site_family(f2in3)
sym = ts((8, 0, 1))
......@@ -150,7 +153,7 @@ def test_translational_symmetry():
(site, shifted(site, hop)))
# Test act for hoppings belonging to different lattices.
f2p = lattice.general(2 * np.identity(2))
f2p = lattice.general(2 * np.identity(2), norbs=1)
sym = ts(*(2 * np.identity(2)))
assert sym.act((1, 1), f2(0, 0), f2p(0, 0)) == (f2(2, 2), f2p(1, 1))
assert sym.act((1, 1), f2p(0, 0), f2(0, 0)) == (f2p(1, 1), f2(2, 2))
......@@ -160,7 +163,7 @@ def test_translational_symmetry():
# generated symmetry with proper vectors.
rng = ensure_rng(30)
vec = rng.randn(3, 5)
lat = lattice.general(vec)
lat = lattice.general(vec, norbs=1)
total = 0
for k in range(1, 4):
for i in range(10):
......@@ -176,7 +179,7 @@ def test_translational_symmetry():
def test_translational_symmetry_reversed():
rng = ensure_rng(30)
lat = lattice.general(np.identity(3))
lat = lattice.general(np.identity(3), norbs=1)
sites = [lat(i, j, k) for i in range(-2, 6) for j in range(-2, 6)
for k in range(-2, 6)]
for i in range(4):
......@@ -194,9 +197,9 @@ def test_translational_symmetry_reversed():
def test_monatomic_lattice():
lat = lattice.square()
lat2 = lattice.general(np.identity(2))
lat3 = lattice.square(name='no')
lat = lattice.square(norbs=1)
lat2 = lattice.general(np.identity(2), norbs=1)
lat3 = lattice.square(name='no', norbs=1)
assert len(set([lat, lat2, lat3, lat(0, 0), lat2(0, 0), lat3(0, 0)])) == 4
@pytest.mark.parametrize('prim_vecs, basis', [
......@@ -210,16 +213,22 @@ def test_monatomic_lattice():
])
def test_lattice_constraints(prim_vecs, basis):
with pytest.raises(ValueError):
lattice.general(prim_vecs, basis)
lattice.general(prim_vecs, basis, norbs=1)
def test_norbs():
id_mat = np.identity(2)
# Monatomic lattices
assert lattice.general(id_mat).norbs == None
# Catch deprecation warning
with warnings.catch_warnings():
warnings.simplefilter("ignore")
assert lattice.general(id_mat).norbs == None
assert lattice.general(id_mat, norbs=2).norbs == 2
# Polyatomic lattices
lat = lattice.general(id_mat, basis=id_mat, norbs=None)
# Catch deprecation warning
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lat = lattice.general(id_mat, basis=id_mat, norbs=None)
for l in lat.sublattices:
assert l.norbs == None
lat = lattice.general(id_mat, basis=id_mat, norbs=2)
......@@ -237,8 +246,14 @@ def test_norbs():
raises(ValueError, lattice.general, id_mat, norbs=1.5)
raises(ValueError, lattice.general, id_mat, id_mat, norbs=1.5)
raises(ValueError, lattice.general, id_mat, id_mat, norbs=[1.5, 1.5])
# should raise ValueError if norbs is <= 0
raises(ValueError, lattice.general, id_mat, norbs=0)
raises(ValueError, lattice.general, id_mat, norbs=-1)
# test that lattices with different norbs are compared `not equal`
lat = lattice.general(id_mat, basis=id_mat, norbs=None)
# Catch deprecation warning
with warnings.catch_warnings():
warnings.simplefilter("ignore")
lat = lattice.general(id_mat, basis=id_mat, norbs=None)
lat1 = lattice.general(id_mat, basis=id_mat, norbs=1)
lat2 = lattice.general(id_mat, basis=id_mat, norbs=2)
assert lat != lat1
......@@ -247,7 +262,7 @@ def test_norbs():
def test_symmetry_act():
lat = lattice.square()
lat = lattice.square(norbs=1)
sym = lattice.TranslationalSymmetry((1, 0), (0, 1))
site = lat(0, 0)
hopping = (lat(0, 0), lat(1, 0))
......
......@@ -6,6 +6,7 @@
# the file AUTHORS.rst at the top-level directory of this distribution and at
# http://kwant-project.org/authors.
import warnings
import functools as ft
from collections import deque
import pickle
......@@ -59,7 +60,10 @@ def test_operator_construction():
N = len(fsyst.sites)
# test construction failure if norbs not given
latnone = kwant.lattice.chain()
# Catch deprecation warning
with warnings.catch_warnings():
warnings.simplefilter("ignore")
latnone = kwant.lattice.chain()
syst[latnone(0)] = 1
for A in opservables:
raises(ValueError, A, syst.finalized())
......@@ -128,12 +132,35 @@ def test_operator_construction():
fwhere = tuple(fsyst.id_by_site[s] for s in where)
A = ops.Density(fsyst, where=where)
assert np.all(np.asarray(A.where).reshape(-1) == fwhere)
# Test for passing integers as 'where'
A = ops.Density(fsyst, where=fwhere)
assert np.all(np.asarray(A.where).reshape(-1) == fwhere)
# Test passing invalid sites
with raises(ValueError):
ops.Density(fsyst, where=[lat(100)])
with raises(ValueError):
ops.Density(fsyst, where=[-1])
with raises(ValueError):
ops.Density(fsyst, where=[10000])
where = [(lat(2, 2), lat(1, 2)), (lat(0, 0), lat(0, 1))]
fwhere = np.asarray([(fsyst.id_by_site[a], fsyst.id_by_site[b])
for a, b in where])
A = ops.Current(fsyst, where=where)
assert np.all(np.asarray(A.where) == fwhere)
# Test for passing integers as 'where'
A = ops.Current(fsyst, where=fwhere)
assert np.all(np.asarray(A.where) == fwhere)
# Test passing invalid hoppings
with raises(ValueError):
ops.Current(fsyst, where=[(lat(2, 2), lat(0, 0))])
with raises(ValueError):
ops.Current(fsyst, where=[(-1, 1)])
with raises(ValueError):
ops.Current(fsyst, where=[(len(fsyst.sites), 1)])
with raises(ValueError):
ops.Current(fsyst, where=[(fsyst.id_by_site[lat(2, 2)],
fsyst.id_by_site[lat(0, 0)])])
# test construction with `where` given by a function
tag_list = [(1, 0), (1, 1), (1, 2)]
......@@ -304,7 +331,7 @@ def test_opservables_spin():
down, up = kwant.wave_function(fsyst, energy=1., params=params)(0)
x_hoppings = kwant.builder.HoppingKind((1,), lat)
spin_current_z = ops.Current(fsyst, sigmaz, where=x_hoppings(syst))
spin_current_z = ops.Current(fsyst, sigmaz, where=list(x_hoppings(syst)))
_test(spin_current_z, up, params=params, per_el_val=1)
_test(spin_current_z, down, params=params, per_el_val=-1)
......@@ -366,12 +393,14 @@ def test_opservables_gauged():
(Us[i], sigmaz, Us[i].conjugate().transpose()))
x_hoppings = kwant.builder.HoppingKind((1,), lat)
spin_current_gauge = ops.Current(fsyst, M_a, where=x_hoppings(syst))
spin_current_gauge = ops.Current(fsyst, M_a,
where=list(x_hoppings(syst)))
_test(spin_current_gauge, up, per_el_val=1)
_test(spin_current_gauge, down, per_el_val=-1)
# check the reverse is also true
minus_x_hoppings = kwant.builder.HoppingKind((-1,), lat)
spin_current_gauge = ops.Current(fsyst, M_a, where=minus_x_hoppings(syst))
spin_current_gauge = ops.Current(fsyst, M_a,
where=list(minus_x_hoppings(syst)))
_test(spin_current_gauge, up, per_el_val=-1)
_test(spin_current_gauge, down, per_el_val=1)
......@@ -416,7 +445,7 @@ def test_arg_passing(A):
lat1 = kwant.lattice.chain(norbs=1)
syst = kwant.Builder()
syst[lat1(0)] = syst[lat1(1)] = lambda s0, a, b: s0.pos + a + b
syst[lat1(0)] = syst[lat1(1)] = lambda s0, a, b: s0.pos[0] + a + b
syst[lat1.neighbors()] = lambda s0, s1, a, b: a - b
fsyst = syst.finalized()
......
......@@ -91,7 +91,7 @@ def syst_2d(W=3, r1=3, r2=8):
def syst_3d(W=3, r1=2, r2=4, a=1, t=1.0):
lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)))
lat = kwant.lattice.general(((a, 0, 0), (0, a, 0), (0, 0, a)), norbs=1)
syst = kwant.Builder()
def ring(pos):
......@@ -162,13 +162,29 @@ def test_plot_more_site_families_than_colors():
# https://gitlab.kwant-project.org/kwant/kwant/issues/257
ncolors = len(pyplot.rcParams['axes.prop_cycle'])
syst = kwant.Builder()
lattices = [kwant.lattice.square(name=i) for i in range(ncolors + 1)]
lattices = [kwant.lattice.square(name=i, norbs=1)
for i in range(ncolors + 1)]
for i, lat in enumerate(lattices):
syst[lat(i, 0)] = None
with tempfile.TemporaryFile('w+b') as out:
plotter.plot(syst, file=out)
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_raises_on_bad_site_spec():
syst = kwant.Builder()
lat = kwant.lattice.square(norbs=1)
syst[(lat(i, j) for i in range(5) for j in range(5))] = None
# Cannot provide site_size as an array when syst is a Builder
with pytest.raises(TypeError):
plotter.plot(syst, site_size=[1] * 25)
# Cannot provide site_size as an array when syst is a Builder
with pytest.raises(TypeError):
plotter.plot(syst, site_symbol=['o'] * 25)
def good_transform(pos):
x, y = pos
return y, x
......@@ -237,7 +253,7 @@ def test_spectrum():
def ham_2d(a, b, c):
return np.eye(2) * (a**2 + b**2 + c**2)
lat = kwant.lattice.chain()
lat = kwant.lattice.chain(norbs=1)
syst = kwant.Builder()
syst[(lat(i) for i in range(3))] = lambda site, a, b: a + b
syst[lat.neighbors()] = lambda site1, site2, c: c
......@@ -361,7 +377,7 @@ def _border_is_0(field):
def _test_border_0(interpolator):
## Test that current is always identically zero at box boundaries
syst = kwant.Builder()
lat = kwant.lattice.square()
lat = kwant.lattice.square(norbs=1)
syst[[lat(0, 0), lat(1, 0)]] = None
syst[(lat(0, 0), lat(1, 0))] = None
syst = syst.finalized()
......@@ -488,7 +504,7 @@ def test_current_interpolation():
### Tests on a divergence-free current (closed system)
lat = kwant.lattice.general([(1, 0), (0.5, np.sqrt(3) / 2)])
lat = kwant.lattice.general([(1, 0), (0.5, np.sqrt(3) / 2)], norbs=1)
syst = kwant.Builder()
sites = [lat(0, 0), lat(1, 0), lat(0, 1), lat(2, 2)]
syst[sites] = None
......
# Copyright 2011-2016 Kwant authors.
# Copyright 2011-2019 Kwant authors.
#
# This file is part of Kwant. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution and at
......@@ -8,6 +8,7 @@
import pickle
import copy
import pytest
from pytest import raises
import numpy as np
from scipy import sparse
......@@ -15,9 +16,11 @@ import kwant
from kwant._common import ensure_rng
def test_hamiltonian_submatrix():
syst = kwant.Builder()
chain = kwant.lattice.chain()
@pytest.mark.parametrize("vectorize", [False, True])
def test_hamiltonian_submatrix(vectorize):
syst = kwant.Builder(vectorize=vectorize)
chain = kwant.lattice.chain(norbs=1)
chain2 = kwant.lattice.chain(norbs=2)
for i in range(3):
syst[chain(i)] = 0.5 * i
for i in range(2):
......@@ -27,7 +30,11 @@ def test_hamiltonian_submatrix():
mat = syst2.hamiltonian_submatrix()
assert mat.shape == (3, 3)
# Sorting is required due to unknown compression order of builder.
perm = np.argsort([os[0] for os in syst2.onsites])
if vectorize:
_, (site_offsets, _) = syst2.subgraphs[0]
else:
site_offsets = [os[0] for os in syst2.onsites]
perm = np.argsort(site_offsets)
mat_should_be = np.array([[0, 1j, 0], [-1j, 0.5, 2j], [0, -2j, 1]])
mat = mat[perm, :]
......@@ -41,19 +48,12 @@ def test_hamiltonian_submatrix():
mat = mat[:, perm]
np.testing.assert_array_equal(mat, mat_should_be)
mat = syst2.hamiltonian_submatrix((), perm[[0, 1]], perm[[2]])
np.testing.assert_array_equal(mat, mat_should_be[:2, 2:3])
mat = syst2.hamiltonian_submatrix((), perm[[0, 1]], perm[[2]], sparse=True)
mat = mat.toarray()
np.testing.assert_array_equal(mat, mat_should_be[:2, 2:3])
# Test for correct treatment of matrix input.
syst = kwant.Builder()
syst[chain(0)] = np.array([[0, 1j], [-1j, 0]])
syst = kwant.Builder(vectorize=vectorize)
syst[chain2(0)] = np.array([[0, 1j], [-1j, 0]])
syst[chain(1)] = np.array([[1]])
syst[chain(2)] = np.array([[2]])
syst[chain(1), chain(0)] = np.array([[1, 2j]])
syst[chain(1), chain2(0)] = np.array([[1, 2j]])
syst[chain(2), chain(1)] = np.array([[3j]])
syst2 = syst.finalized()
mat_dense = syst2.hamiltonian_submatrix()
......@@ -62,9 +62,10 @@ def test_hamiltonian_submatrix():
# Test precalculation of modes.
rng = ensure_rng(5)
lead = kwant.Builder(kwant.TranslationalSymmetry((-1,)))
lead[chain(0)] = np.zeros((2, 2))
lead[chain(0), chain(1)] = rng.randn(2, 2)
lead = kwant.Builder(kwant.TranslationalSymmetry((-1,)),
vectorize=vectorize)
lead[chain2(0)] = np.zeros((2, 2))
lead[chain2(0), chain2(1)] = rng.randn(2, 2)
syst.attach_lead(lead)
syst2 = syst.finalized()
smatrix = kwant.smatrix(syst2, .1).data
......@@ -74,20 +75,27 @@ def test_hamiltonian_submatrix():
raises(ValueError, kwant.solvers.default.greens_function, syst3, 0.2)
# Test for shape errors.
syst[chain(0), chain(2)] = np.array([[1, 2]])
syst[chain2(0), chain(2)] = np.array([[1, 2]])
syst2 = syst.finalized()
raises(ValueError, syst2.hamiltonian_submatrix)
raises(ValueError, syst2.hamiltonian_submatrix, sparse=True)
syst[chain(0), chain(2)] = 1
syst[chain2(0), chain(2)] = 1
syst2 = syst.finalized()
raises(ValueError, syst2.hamiltonian_submatrix)
raises(ValueError, syst2.hamiltonian_submatrix, sparse=True)
if vectorize: # non-vectorized systems don't check this at finalization
# Add another hopping of the same type but with a different
# (and still incompatible) shape.
syst[chain2(0), chain(1)] = np.array([[1, 2]])
raises(ValueError, syst.finalized)
def test_pickling():
syst = kwant.Builder()
lead = kwant.Builder(symmetry=kwant.TranslationalSymmetry([1.]))
lat = kwant.lattice.chain()
@pytest.mark.parametrize("vectorize", [False, True])
def test_pickling(vectorize):
syst = kwant.Builder(vectorize=vectorize)
lead = kwant.Builder(symmetry=kwant.TranslationalSymmetry([1.]),
vectorize=vectorize)
lat = kwant.lattice.chain(norbs=1)
syst[lat(0)] = syst[lat(1)] = 0
syst[lat(0), lat(1)] = 1
lead[lat(0)] = syst[lat(1)] = 0
......
......@@ -37,7 +37,8 @@ def _simple_syst(lat, E=0, t=1+1j, sym=None):
def test_consistence_with_bands(kx=1.9, nkys=31):
kys = np.linspace(-np.pi, np.pi, nkys)
for lat in [kwant.lattice.honeycomb(), kwant.lattice.square()]:
for lat in [kwant.lattice.honeycomb(norbs=1),
kwant.lattice.square(norbs=1)]:
syst = _simple_syst(lat)
wa_keep_1 = wraparound(syst, keep=1).finalized()
wa_keep_none = wraparound(syst).finalized()
......@@ -56,7 +57,7 @@ def test_consistence_with_bands(kx=1.9, nkys=31):
def test_opposite_hoppings():
lat = kwant.lattice.square()
lat = kwant.lattice.square(norbs=1)
for val in [1j, lambda a, b: 1j]:
syst = kwant.Builder(kwant.TranslationalSymmetry((1, 1)))
......@@ -74,7 +75,8 @@ def test_opposite_hoppings():
def test_value_types(k=(-1.1, 0.5), E=2, t=1):
k = dict(zip(('k_x', 'k_y', 'k_z'), k))
sym_extents = [1, 2, 3]
lattices = [kwant.lattice.honeycomb(), kwant.lattice.square()]
lattices = [kwant.lattice.honeycomb(norbs=1),
kwant.lattice.square(norbs=1)]
lat_syms = [
(lat, kwant.TranslationalSymmetry(lat.vec((n, 0)), lat.vec((0, n))))
for n, lat in itertools.product(sym_extents, lattices)
......@@ -112,7 +114,7 @@ def test_value_types(k=(-1.1, 0.5), E=2, t=1):
def test_signatures():
lat = kwant.lattice.square()
lat = kwant.lattice.square(norbs=1)
syst = kwant.Builder(kwant.TranslationalSymmetry((-3, 0), (0, 1)))
# onsites and hoppings that will be bound as sites
syst[lat(-2, 0)] = 4
......@@ -162,7 +164,7 @@ def test_signatures():
def test_symmetry():
syst = _simple_syst(kwant.lattice.square())
syst = _simple_syst(kwant.lattice.square(norbs=1))
matrices = [np.random.rand(2, 2) for i in range(4)]
laws = (matrices, [(lambda a: m) for m in matrices])
......@@ -190,10 +192,10 @@ def test_symmetry():
@pytest.mark.skipif(not _plotter.mpl_available, reason="Matplotlib unavailable.")
def test_plot_2d_bands():
chain = kwant.lattice.chain()
square = kwant.lattice.square()
cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)])
hc = kwant.lattice.honeycomb()
chain = kwant.lattice.chain(norbs=1)
square = kwant.lattice.square(norbs=1)
cube = kwant.lattice.general([(1, 0, 0), (0, 1, 0), (0, 0, 1)], norbs=1)
hc = kwant.lattice.honeycomb(norbs=1)
syst_1d = kwant.Builder(kwant.TranslationalSymmetry(*chain._prim_vecs))
syst_1d[chain(0)] = 2
......@@ -245,7 +247,7 @@ def test_fd_mismatch():
# around in all directions, but could not be wrapped around when 'keep' is
# provided.
sqrt3 = np.sqrt(3)
lat = kwant.lattice.general([(sqrt3, 0), (-sqrt3/2, 1.5)])
lat = kwant.lattice.general([(sqrt3, 0), (-sqrt3/2, 1.5)], norbs=1)
T = kwant.TranslationalSymmetry((sqrt3, 0), (0, 3))
syst1 = kwant.Builder(T)
......@@ -269,7 +271,8 @@ def test_fd_mismatch():
## Test that spectrum of non-trivial system (including above cases)
## is the same, regardless of the way in which it is wrapped around
lat = kwant.lattice.general([(sqrt3, 0), (-sqrt3/2, 1.5)],
[(sqrt3 / 2, 0.5), (0, 1)])
[(sqrt3 / 2, 0.5), (0, 1)],
norbs=1)
a, b = lat.sublattices
T = kwant.TranslationalSymmetry((3 * sqrt3, 0), (0, 3))
syst = kwant.Builder(T)
......@@ -302,7 +305,7 @@ def test_fd_mismatch():
assert all(np.allclose(E, E[0]) for E in E_k)
# Test square lattice with oblique unit cell
lat = kwant.lattice.general(np.eye(2))
lat = kwant.lattice.general(np.eye(2), norbs=1)
translations = kwant.lattice.TranslationalSymmetry([2, 2], [0, 2])
syst = kwant.Builder(symmetry=translations)
syst[lat.shape(lambda site: True, [0, 0])] = 1
......@@ -314,7 +317,7 @@ def test_fd_mismatch():
# Test Rocksalt structure
# cubic lattice that contains both sublattices
lat = kwant.lattice.general(np.eye(3))
lat = kwant.lattice.general(np.eye(3), norbs=1)
# Builder with FCC translational symmetries.
translations = kwant.lattice.TranslationalSymmetry([1, 1, 0], [1, 0, 1], [0, 1, 1])
syst = kwant.Builder(symmetry=translations)
......@@ -341,7 +344,7 @@ def test_fd_mismatch():
def shape(site):
return abs(site.tag[2]) < 4
lat = kwant.lattice.general(np.eye(3))
lat = kwant.lattice.general(np.eye(3), norbs=1)
# First choice: primitive UC
translations = kwant.lattice.TranslationalSymmetry([1, 1, 0], [1, -1, 0], [1, 0, 1])
syst = kwant.Builder(symmetry=translations)
......
......@@ -20,36 +20,12 @@ from . import builder, system, plotter
from .linalg import lll
from .builder import herm_conj, HermConjOfFunc
from .lattice import TranslationalSymmetry
from ._common import get_parameters
from ._common import get_parameters, memoize
__all__ = ['wraparound', 'plot_2d_bands']
def _hashable(obj):
return isinstance(obj, collections.abc.Hashable)
def _memoize(f):
"""Decorator to memoize a function that works even with unhashable args.
This decorator will even work with functions whose args are not hashable.
The cache key is made up by the hashable arguments and the ids of the
non-hashable args. It is up to the user to make sure that non-hashable
args do not change during the lifetime of the decorator.
This decorator will keep reevaluating functions that return None.
"""
def lookup(*args):
key = tuple(arg if _hashable(arg) else id(arg) for arg in args)
result = cache.get(key)
if result is None:
cache[key] = result = f(*args)
return result
cache = {}
return lookup
def _set_signature(func, params):
"""Set the signature of 'func'.
......@@ -103,7 +79,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
format. It will be deprecated in the 2.0 release of Kwant.
"""
@_memoize
@memoize
def bind_site(val):
def f(*args):
a, *args = args
......@@ -113,7 +89,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, get_parameters(val) + momenta)
return f
@_memoize
@memoize
def bind_hopping_as_site(elem, val):
def f(*args):
a, *args = args
......@@ -128,7 +104,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, params + momenta)
return f
@_memoize
@memoize
def bind_hopping(elem, val):
def f(*args):
a, b, *args = args
......@@ -142,7 +118,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, params + momenta)
return f
@_memoize
@memoize
def bind_sum(num_sites, *vals):
"""Construct joint signature for all 'vals'."""
......@@ -207,6 +183,9 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
f.__signature__ = inspect.Signature(params.values())
return f
if builder.vectorize:
raise TypeError("'wraparound' does not work with vectorized Builders.")
try:
momenta = ['k_{}'.format(coordinate_names[i])
for i in range(len(builder.symmetry.periods))]
......@@ -386,10 +365,14 @@ def plot_2d_bands(syst, k_x=31, k_y=31, params=None,
if not hasattr(syst, '_wrapped_symmetry'):
raise TypeError("Expecting a system that was produced by "
"'kwant.wraparound.wraparound'.")
if not isinstance(syst, system.FiniteSystem):
if isinstance(syst, system.InfiniteSystem):
msg = ("All symmetry directions must be wrapped around: specify "
"'keep=None' when calling 'kwant.wraparound.wraparound'.")
raise TypeError(msg)
if isinstance(syst, builder.Builder):
msg = ("Expecting a finalized system: remember to finalize your "
"system with 'syst.finalized()'.")
raise TypeError(msg)
params = params or {}
lat_ndim, space_ndim = syst._wrapped_symmetry.periods.shape
......
[pytest]
testpaths = kwant
flakes-ignore =
__init__.py UnusedImport
__init__.py UnusedImport ImportStarUsed ImportStarUsage
kwant/_plotter.py UnusedImport
graph/tests/test_scotch.py UndefinedName
graph/tests/test_dissection.py UndefinedName
......@@ -517,7 +517,7 @@ def maybe_add_numpy_include(exts):
def main():
check_python_version((3, 5))
check_python_version((3, 6))
check_versions()
exts = collections.OrderedDict([
......@@ -581,12 +581,12 @@ def main():
'build_ext': build_ext,
'test': test},
ext_modules=exts,
install_requires=['numpy >= 1.11.0', 'scipy >= 0.17.0',
install_requires=['numpy >= 1.13.3', 'scipy >= 0.19.1',
'tinyarray >= 1.2'],
extras_require={
# The oldest versions between: Debian stable, Ubuntu LTS
'plotting': 'matplotlib >= 1.5.1',
'continuum': 'sympy >= 0.7.6',
'plotting': 'matplotlib >= 2.1.1',
'continuum': 'sympy >= 1.1.1',
# qsymm is only packaged on PyPI
'qsymm': 'qsymm >= 1.2.6',
},
......