From e662097ac22ded7b58dcf80b8267a02232ebe5aa Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Wed, 22 Feb 2017 13:03:26 +0100
Subject: [PATCH] migrate 'wraparound' module to Kwant

---
 kwant/wraparound.py | 242 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 242 insertions(+)
 create mode 100644 kwant/wraparound.py

diff --git a/kwant/wraparound.py b/kwant/wraparound.py
new file mode 100644
index 00000000..21a1e981
--- /dev/null
+++ b/kwant/wraparound.py
@@ -0,0 +1,242 @@
+# Copyright 2016 Christoph Groth (INAC / CEA Grenoble).
+#
+# This file is subject to the 2-clause BSD license as found at
+# http://kwant-project.org/license.
+
+"""Replace symmetries of Kwant builders with momentum parameters to the
+system."""
+
+import sys
+import collections
+import cmath
+import numpy as np
+import tinyarray as ta
+
+import kwant
+from kwant.builder import herm_conj
+
+
+if sys.version_info >= (3, 0):
+    def _hashable(obj):
+        return isinstance(obj, collections.Hashable)
+else:
+    def _hashable(obj):
+        return (isinstance(obj, collections.Hashable)
+                and not isinstance(obj, np.ndarray))
+
+
+def _memoize(f):
+    """Decorator to memoize a function that works even with unhashable args.
+
+    This decorator will even work with functions whose args are not hashable.
+    The cache key is made up by the hashable arguments and the ids of the
+    non-hashable args.  It is up to the user to make sure that non-hashable
+    args do not change during the lifetime of the decorator.
+
+    This decorator will keep reevaluating functions that return None.
+    """
+    def lookup(*args):
+        key = tuple(arg if _hashable(arg) else id(arg) for arg in args)
+        result = cache.get(key)
+        if result is None:
+            cache[key] = result = f(*args)
+        return result
+    cache = {}
+    return lookup
+
+
+def wraparound(builder, keep=None):
+    """Replace translational symmetries by momentum parameters.
+
+    A new Builder instance is returned.  By default, each symmetry is replaced
+    by one scalar momentum parameter that is appended to the already existing
+    arguments of the system.  Optionally, one symmetry may be kept by using the
+    `keep` argument.
+    """
+
+    @_memoize
+    def bind_site(val):
+        assert callable(val)
+        return lambda a, *args: val(a, *args[:mnp])
+
+    @_memoize
+    def bind_hopping_as_site(elem, val):
+        def f(a, *args):
+            phase = cmath.exp(1j * ta.dot(elem, args[mnp:]))
+            v = val(a, sym.act(elem, a), *args[:mnp]) if callable(val) else val
+            pv = phase * v
+            return pv + herm_conj(pv)
+        return f
+
+    @_memoize
+    def bind_hopping(elem, val):
+        def f(a, b, *args):
+            phase = cmath.exp(1j * ta.dot(elem, args[mnp:]))
+            v = val(a, sym.act(elem, b), *args[:mnp]) if callable(val) else val
+            return phase * v
+        return f
+
+    @_memoize
+    def bind_sum(*vals):
+        return lambda *args: sum((val(*args) if callable(val) else val)
+                                 for val in vals)
+
+    if keep is None:
+        ret = kwant.Builder()
+        sym = builder.symmetry
+    else:
+        periods = list(builder.symmetry.periods)
+        ret = kwant.Builder(kwant.TranslationalSymmetry(periods.pop(keep)))
+        sym = kwant.TranslationalSymmetry(*periods)
+
+    sites = {}
+    hops = collections.defaultdict(list)
+
+    mnp = -len(sym.periods)      # Used by the bound functions above.
+
+    # Store lists of values, so that multiple values can be assigned to the
+    # same site or hopping.
+    for site, val in builder.site_value_pairs():
+        sites[site] = [bind_site(val) if callable(val) else val]
+
+    for hop, val in builder.hopping_value_pairs():
+        a, b = hop
+        b_dom = sym.which(b)
+        b_wa = sym.act(-b_dom, b)
+
+        if a == b_wa:
+            # The hopping gets wrapped-around into an onsite Hamiltonian.
+            # Since site `a` already exists in the system, we can simply append.
+            sites[a].append(bind_hopping_as_site(b_dom, val))
+        else:
+            # The hopping remains a hopping.
+            if b != b_wa or callable(val):
+                # The hopping got wrapped-around or is a function.
+                val = bind_hopping(b_dom, val)
+
+            # Make sure that there is only one entry for each hopping
+            # (pointing in one direction).
+            if (b_wa, a) in hops:
+                assert (a, b_wa) not in hops
+                if callable(val):
+                    assert not isinstance(val, kwant.builder.HermConjOfFunc)
+                    val = kwant.builder.HermConjOfFunc(val)
+                else:
+                    val = kwant.builder.herm_conj(val)
+
+                hops[b_wa, a].append(val)
+            else:
+                hops[a, b_wa].append(val)
+
+    # Copy stuff into result builder, converting lists of more than one element
+    # into summing functions.
+    for site, vals in sites.items():
+        ret[site] = vals[0] if len(vals) == 1 else bind_sum(*vals)
+
+    for hop, vals in hops.items():
+        ret[hop] = vals[0] if len(vals) == 1 else bind_sum(*vals)
+
+    return ret
+
+
+def plot_bands_2d(syst, args=(), momenta=(31, 31)):
+
+    """Plot the bands of a system with two wrapped-around symmetries."""
+    from mpl_toolkits.mplot3d import Axes3D
+    from matplotlib import pyplot
+
+    if not isinstance(syst, kwant.system.FiniteSystem):
+        raise TypeError("Need a system without symmetries.")
+
+    fig = pyplot.figure()
+    ax = fig.gca(projection='3d')
+    kxs = np.linspace(-np.pi, np.pi, momenta[0])
+    kys = np.linspace(-np.pi, np.pi, momenta[1])
+
+    energies = [[np.sort(np.linalg.eigvalsh(syst.hamiltonian_submatrix(
+        args + (kx, ky), sparse=False)).real)
+                 for ky in kys] for kx in kxs]
+    energies = np.array(energies)
+
+    mesh_x, mesh_y = np.meshgrid(kxs, kys)
+    for i in range(energies.shape[-1]):
+        ax.plot_wireframe(mesh_x, mesh_y, energies[:, :, i],
+                          rstride=1, cstride=1)
+
+    pyplot.show()
+
+
+def _simple_syst(lat, E=0, t=1+1j):
+    """Create a builder for a simple infinite system."""
+    sym = kwant.TranslationalSymmetry(lat.vec((1, 0)), lat.vec((0, 1)))
+    # Build system with 2d periodic BCs. This system cannot be finalized in
+    # Kwant <= 1.2.
+    syst = kwant.Builder(sym)
+    syst[lat.shape(lambda p: True, (0, 0))] = E
+    syst[lat.neighbors(1)] = t
+    return syst
+
+
+def test_consistence_with_bands(kx=1.9, nkys=31):
+    kys = np.linspace(-np.pi, np.pi, nkys)
+    for lat in [kwant.lattice.honeycomb(), kwant.lattice.square()]:
+        syst = _simple_syst(lat)
+        wa_keep_1 = wraparound(syst, keep=1).finalized()
+        wa_keep_none = wraparound(syst).finalized()
+
+        bands = kwant.physics.Bands(wa_keep_1, (kx,))
+        energies_a = [bands(ky) for ky in
+                      (kys if kwant.__version__ > "1.0" else reversed(kys))]
+
+        energies_b = []
+        for ky in kys:
+            H = wa_keep_none.hamiltonian_submatrix((kx, ky), sparse=False)
+            evs = np.sort(np.linalg.eigvalsh(H).real)
+            energies_b.append(evs)
+
+        np.testing.assert_almost_equal(energies_a, energies_b)
+
+
+def test_opposite_hoppings():
+    lat = kwant.lattice.square()
+
+    for val in [1j, lambda a, b: 1j]:
+        syst = kwant.Builder(kwant.TranslationalSymmetry((1, 1)))
+        syst[ (lat(x, 0) for x in [-1, 0]) ] = 0
+        syst[lat(0, 0), lat(-1, 0)] = val
+        syst[lat(-1, 0), lat(-1, -1)] = val
+
+        fsyst = wraparound(syst).finalized()
+        np.testing.assert_almost_equal(fsyst.hamiltonian_submatrix([0]), 0)
+
+
+def test_value_types(k=(-1.1, 0.5), E=0, t=1):
+    for lat in [kwant.lattice.honeycomb(), kwant.lattice.square()]:
+        syst = wraparound(_simple_syst(lat, E, t)).finalized()
+        H = syst.hamiltonian_submatrix(k, sparse=False)
+        for E1, t1 in [(float(E), float(t)),
+                       (np.array([[E]], float), np.array([[1]], float)),
+                       (ta.array([[E]], float), ta.array([[1]], float))]:
+            for E2 in [E1, lambda a: E1]:
+                for t2 in [t1, lambda a, b: t1]:
+                    syst = wraparound(_simple_syst(lat, E2, t2)).finalized()
+                    H_alt = syst.hamiltonian_submatrix(k, sparse=False)
+                    np.testing.assert_equal(H_alt, H)
+
+
+def test():
+    test_consistence_with_bands()
+    test_opposite_hoppings()
+    test_value_types()
+
+
+def demo():
+    """Calculate and plot the band structure of graphene."""
+    lat = kwant.lattice.honeycomb()
+    syst = wraparound(_simple_syst(lat)).finalized()
+    plot_bands_2d(syst)
+
+
+if __name__ == '__main__':
+    test()
+    demo()
-- 
GitLab