diff --git a/kwant/builder.py b/kwant/builder.py index f3b08e6a0c2f9135ca3600400956c8621421cfb1..d5fb4e3fff13d305848ccce96885bfc61c16dfaf 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -12,6 +12,7 @@ __all__ = ['Builder', 'Site', 'SiteFamily', 'SimpleSiteFamily', 'Symmetry', import abc import warnings import operator +from bisect import bisect from functools import total_ordering from itertools import islice, chain import tinyarray as ta @@ -1562,3 +1563,136 @@ class InfiniteSystem(system.InfiniteSystem): def pos(self, i): return self.sites[i].pos + + +################ Site and Hopping indexable arrays + +class Indexable(np.ndarray): + """An array that can be indexed by more general keys. + + Parameters + ---------- + map_key: callable + takes a general key and returns a key that can be + used to index the underlying array directly. If this + raises a KeyError then the general key is used + directly to index the array. + data: array_like + An array, any object exposing the array interface, an + object whose __array__ method returns an array, or any + (nested) sequence. + """ + + def __init__(self, map_key, data): + self._map_key = map_key + super().__init__(data) + + def __getitem__(self, key): + try: + key = self._map_key(key) + except KeyError: + pass + return super().__getitem__(key) + + def __setitem__(self, key, value): + try: + key = self._map_index(key) + except KeyError: + pass + return super().__setitem__(key, value) + + +def indexable_by_site(syst, array): + """Return an array that is indexable by `~kwant.builder.Site`s. + + Parameters + ---------- + syst: `kwant.builder.FiniteSystem` or `kwant.builder.InfiniteSystem` + array: array_like + Must have the same length as there are sites in the system. + + Notes + ----- + This differs from `sliceable_by_site`, in that it deals with arrays + defined over all the *sites* of a system, while the latter deals + with arrays defined over all the *orbitals*. + + See Also + -------- + indexable_by_site + """ + if len(array) != len(syst.sites): + raise ValueError( + 'there are {} sites in the system, but the array has length {}' + .format(len(syst.sites), len(array))) + + site_index = syst.id_by_site.__getitem__ + + return Indexable(site_index, array) + + +def indexable_by_hopping(syst, array): + """Return an array that is indexable by pairs of `~kwant.builder.Site`s. + + Parameters + ---------- + syst: `kwant.builder.FiniteSystem` or `kwant.builder.InfiniteSystem` + array: array_like + Must have the same length as there are hoppings in the system. + + Notes + ----- + Both hoppings ``(i, j)`` and ``(j, i)`` appear in the system, so + ``array`` must accomodate both these hoppings. + """ + if len(array) != syst.graph.num_edges: + raise ValueError( + 'there are {} hoppings in the system, but the array has length {}' + .format(syst.graph.num_edges, len(array))) + + def hopping_index(hop): + a, b = map(syst.id_by_sites.__getitem__, hop) + return syst.graph.first_edge_id(a, b) + + return Indexable(hopping_index, array) + + +def sliceable_by_site(syst, array): + """Return an array that can be sliced by `~kwant.builder.Site`s. + + Parameters + ---------- + syst: `kwant.builder.FiniteSystem` or `kwant.builder.InfiniteSystem` + array: array_like + Must have the same length as there are orbitals in the system. + + Notes + ----- + This differs from `indexable_by_site`, in that it deals with arrays + defined over all the *orbitals* of a system, while the latter deals + with arrays defined over all the *sites*. + + See Also + -------- + indexable_by_site + """ + norbs = syst.site_ranges[-1][-1] + if len(array) != norbs: + raise ValueError( + 'there are {} orbitals in the system, but the array has length {}' + .format(norbs, len(array))) + + inf = float('inf') + site_ranges = syst.site_ranges + id_by_site = syst.id_by_site + + # return a slice into the orbitals associated with `site` + def site_slice(site): + site_id = id_by_site[site] + # `inf` is needed to give correct sorting semantics + range_idx = bisect(site_ranges, (site_id, inf)) - 1 + first_site_id, norbs, orb_offset = site_ranges[range_idx] + start_orb = orb_offset + (site_id - first_site_id) * norbs + return slice(start_orb, start_orb + norbs) + + return Indexable(site_slice, array)