From d07c170944e8407ac8de5ba9208bf45ce6024551 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Tue, 18 Dec 2018 12:57:12 +0100
Subject: [PATCH] add magnetic gauge fixing for infinite systems and systems
 with leads

Co-authored-by: Pablo Piskunow <pablo.perez.piskunow@gmail.com>
Co-authored-by: Daniel Varjas <dvarjas@gmail.com>
---
 kwant/physics/gauge.py            | 592 +++++++++++++++++++++++++++++-
 kwant/physics/tests/test_gauge.py | 177 ++++++++-
 2 files changed, 754 insertions(+), 15 deletions(-)

diff --git a/kwant/physics/gauge.py b/kwant/physics/gauge.py
index 09711237..7a3860cd 100644
--- a/kwant/physics/gauge.py
+++ b/kwant/physics/gauge.py
@@ -13,11 +13,15 @@ Backwards incompatible changes (up to and including removal of the package)
 may occur if deemed necessary by the core developers.
 """
 
+import bisect
 import functools as ft
+from functools import partial
+from itertools import permutations
 
 import numpy as np
 import scipy
 from scipy.integrate import dblquad
+from scipy.sparse import csgraph
 
 from .. import system, builder
 
@@ -287,6 +291,470 @@ def shortest_distance_forest(graph):
     return tree
 
 
+def loops_in_infinite(syst):
+    """Find the loops in an infinite system.
+
+    Returns
+    -------
+    loops : sequence of sequences of integers
+        The sites in the returned loops belong to two adjacent unit
+        cells. The first 'syst.cell_size' sites are in the first
+        unit cell, and the next 'sys.cell_size' are in the next
+        (in the direction of the translational symmetry).
+    extended_sites : callable : int -> Site
+        Given a site index in the extended system consisting of
+        two unit cells, returns the associated high-level
+        `kwant.builder.Site`.
+    """
+    assert isinstance(syst, system.InfiniteSystem)
+    check_infinite_syst(syst)
+
+    cell_size = syst.cell_size
+
+    unit_cell_links = [(i, j) for i, j in syst.graph
+                       if i < cell_size and j < cell_size]
+    unit_cell_graph = distance_matrix(unit_cell_links,
+                                      pos=syst.pos,
+                                      shape=(cell_size, cell_size))
+
+    # Loops in the interior of the unit cell
+    spanning_tree = shortest_distance_forest(unit_cell_graph)
+    loops = find_loops(unit_cell_graph, spanning_tree)
+
+    # Construct an extended graph consisting of 2 unit cells connected
+    # by the inter-cell links.
+    extended_shape = (2 * cell_size, 2 * cell_size)
+    uc1 = shift_diagonally(unit_cell_graph, 0, shape=extended_shape)
+    uc2 = shift_diagonally(unit_cell_graph, cell_size, shape=extended_shape)
+    hop_links = [(i, j) for i, j in syst.graph if j >= cell_size]
+    hop = distance_matrix(hop_links,
+                          pos=syst.pos,
+                          shape=extended_shape)
+    graph = add_coo_matrices(uc1, uc2, hop, hop.T,
+                             shape=extended_shape)
+
+    # Construct a subgraph where only the shortest link between the
+    # 2 unit cells is added. The other links are added with infinite
+    # values, so that the subgraph has the same sparsity structure
+    # as 'graph'.
+    idx = np.argmin(hop.data)
+    data = np.full_like(hop.data, np.inf)
+    data[idx] = hop.data[idx]
+    smallest_edge = scipy.sparse.coo_matrix(
+        (data, (hop.row, hop.col)),
+        shape=extended_shape)
+    subgraph = add_coo_matrices(uc1, uc2, smallest_edge, smallest_edge.T,
+                                shape=extended_shape)
+
+    # Use these two graphs to find the loops between unit cells.
+    loops.extend(find_loops(graph, subgraph))
+
+    def extended_sites(i):
+        unit_cell = np.array([i // cell_size])
+        site = syst.sites[i % cell_size]
+        return syst.symmetry.act(-unit_cell, site)
+
+    return loops, extended_sites
+
+
+def loops_in_composite(syst):
+    """Find the loops in finite system with leads.
+
+    Parameters
+    ----------
+    syst : kwant.builder.FiniteSystem
+
+    Returns
+    -------
+    loops : sequence of sequences of integers
+        The sites in each loop belong to the extended system (see notes).
+        The first and last site in each loop are guaranteed to be in 'syst'.
+    which_patch : callable : int -> int
+        Given a site index in the extended scattering region (see notes),
+        returns the lead patch (see notes) to which the site belongs. Returns
+        -1 if the site is part of the reduced scattering region (see notes).
+    extended_sites : callable : int -> Site
+        Given a site index in the extended scattering region (see notes),
+        returns the associated high-level `kwant.builder.Site`.
+
+    Notes
+    -----
+    extended system
+        The scattering region with a single lead unit cell attached at
+        each interface. This unit cell is added so that we can "see" any
+        loops formed with sites in the lead (see 'check_infinite_syst'
+        for details). The sites for each lead are added in the same
+        order as the leads, and within a given added unit cell the sites
+        are ordered in the same way as the associated lead.
+    lead patch
+        Sites in the extended system that belong to the added unit cell
+        for a given lead, or the lead padding for a given lead are said
+        to be in the "lead patch" for that lead.
+    reduced scattering region
+        The sites of the extended system that are not in a lead patch.
+    """
+    # Check that we can consistently fix the gauge in the scattering region,
+    # given that we have independently fixed gauges in the leads.
+    check_composite_syst(syst)
+
+    # Get distance matrix for the extended system, a function that maps sites
+    # to their lead patches (-1 for sites in the reduced scattering region),
+    # and a function that maps sites to high-level 'kwant.builder.Site' objects.
+    distance_matrix, which_patch, extended_sites = extended_scattering_region(syst)
+
+    spanning_tree = spanning_tree_composite(distance_matrix, which_patch).tocsr()
+
+    # Fill in all links with at least 1 site in a lead patch;
+    # their gauge is fixed by the lead gauge.
+    for i, j, v in zip(distance_matrix.row, distance_matrix.col,
+                       distance_matrix.data):
+        if which_patch(i) > -1 or which_patch(j) > -1:
+            assign_csr(spanning_tree, v, (i, j))
+            assign_csr(spanning_tree, v, (j, i))
+
+    loops = find_loops(distance_matrix, spanning_tree)
+
+    return loops, which_patch, extended_sites
+
+
+def extended_scattering_region(syst):
+    """Return the distance matrix of a finite system with 1 unit cell
+       added to each lead interface.
+
+    Parameters
+    ----------
+    syst : kwant.builder.FiniteSystem
+
+    Returns
+    -------
+    extended_scattering_region: COO matrix
+        Distance matrix between connected sites in the extended
+        scattering region.
+    which_patch : callable : int -> int
+        Given a site index in the extended scattering region, returns
+        the lead patch to which the site belongs. Returns
+        -1 if the site is part of the reduced scattering region.
+    extended_sites : callable : int -> Site
+        Given a site index in the extended scattering region, returns
+        the associated high-level `kwant.builder.Site`.
+
+    Notes
+    -----
+    Definitions of the terms 'extended scatteringr region',
+    'lead patch' and 'reduced scattering region' are given
+    in the notes for `kwant.physics.gauge.loops_in_composite`.
+    """
+    extended_size = (syst.graph.num_nodes
+                     + sum(l.cell_size for l in syst.leads))
+    extended_shape = (extended_size, extended_size)
+
+    added_unit_cells = []
+    first_lead_site = syst.graph.num_nodes
+    for lead, interface in zip(syst.leads, syst.lead_interfaces):
+        # Here we assume that the distance between sites in the added
+        # unit cell and sites in the interface is the same as between sites
+        # in neighboring unit cells.
+        uc = distance_matrix(list(lead.graph),
+                             pos=lead.pos, shape=extended_shape)
+        # Map unit cell lead sites to their indices in the extended scattering,
+        # region and sites in next unit cell to their interface sites.
+        hop_from_syst = uc.row >= lead.cell_size
+        uc.row[~hop_from_syst] = uc.row[~hop_from_syst] + first_lead_site
+        uc.row[hop_from_syst] = interface[uc.row[hop_from_syst] - lead.cell_size]
+        # Same for columns
+        hop_to_syst = uc.col >= lead.cell_size
+        uc.col[~hop_to_syst] = uc.col[~hop_to_syst] + first_lead_site
+        uc.col[hop_to_syst] = interface[uc.col[hop_to_syst] - lead.cell_size]
+
+        added_unit_cells.append(uc)
+        first_lead_site += lead.cell_size
+
+    scattering_region = distance_matrix(list(syst.graph),
+                                        pos=syst.pos, shape=extended_shape)
+
+    extended_scattering_region = add_coo_matrices(scattering_region,
+                                                  *added_unit_cells,
+                                                  shape=extended_shape)
+
+    lead_starts = np.cumsum([syst.graph.num_nodes,
+                             *[lead.cell_size for lead in syst.leads]])
+    # Frozenset to quickly check 'is this site in the lead padding?'
+    extra_sites = [frozenset(sites) for sites in syst.lead_paddings]
+
+
+    def which_patch(i):
+        if i < len(syst.sites):
+            # In scattering region
+            for patch_num, sites in enumerate(extra_sites):
+                if i in sites:
+                    return patch_num
+            # If not in 'extra_sites' it is in the reduced scattering region.
+            return -1
+        else:
+            # Otherwise it's in an attached lead cell
+            which_lead = bisect.bisect(lead_starts, i) - 1
+            assert which_lead > -1
+            return which_lead
+
+
+    # Here we use the fact that all the sites in a lead interface belong
+    # to the same symmetry domain.
+    interface_domains = [lead.symmetry.which(syst.sites[interface[0]])
+                         for lead, interface in
+                         zip(syst.leads, syst.lead_interfaces)]
+
+    def extended_sites(i):
+        if i < len(syst.sites):
+            # In scattering region
+            return syst.sites[i]
+        else:
+            # Otherwise it's in an attached lead cell
+            which_lead = bisect.bisect(lead_starts, i) - 1
+            assert which_lead > -1
+            lead = syst.leads[which_lead]
+            domain = interface_domains[which_lead] + 1
+            # Map extended scattering region site index to site index in lead.
+            i = i - lead_starts[which_lead]
+            return lead.symmetry.act(domain, lead.sites[i])
+
+    return extended_scattering_region, which_patch, extended_sites
+
+
+def _interior_links(distance_matrix, which_patch):
+    """Return the indices of the links in 'distance_matrix' that
+       connect interface sites of the scattering region to other
+       sites (interface and non-interface) in the scattering region.
+    """
+
+    def _is_in_lead(i):
+        return which_patch(i) > -1
+
+    # Sites that connect to/from sites in a lead patch
+    interface_sites = {
+        (i if not _is_in_lead(i) else j)
+        for i, j in zip(distance_matrix.row, distance_matrix.col)
+        if _is_in_lead(i) ^ _is_in_lead(j)
+    }
+
+    def _we_want(i, j):
+        return i in interface_sites and not _is_in_lead(j)
+
+    # Links that connect interface sites to the rest of the scattering region.
+    return np.array([
+        k
+        for k, (i, j) in enumerate(zip(distance_matrix.row, distance_matrix.col))
+        if _we_want(i, j) or _we_want(j, i)
+    ])
+
+
+def _make_metatree(graph, links_to_delete):
+    """Make a tree of the components of 'graph' that are
+       disconnected by deleting 'links'. The values of
+       the returned tree are indices of edges in 'graph'
+       that connect components.
+    """
+    # Partition the graph into disconnected components
+    dl = partial(np.delete, obj=links_to_delete)
+    partitioned_graph = scipy.sparse.coo_matrix(
+        (dl(graph.data), (dl(graph.row), dl(graph.col)))
+    )
+    # Construct the "metagraph", where each component is reduced to
+    # a single node, and a representative (smallest) edge is chosen
+    # among the edges that connected the components in the original graph.
+    ncc, labels = csgraph.connected_components(partitioned_graph)
+    metagraph = scipy.sparse.dok_matrix((ncc, ncc), int)
+    for k in links_to_delete:
+        i, j = labels[graph.row[k]], labels[graph.col[k]]
+        if i == j:
+            continue  # Discard loop edges
+        # Add a representative (smallest) edge from each graph component.
+        if graph.data[k] < metagraph.get((i, j), np.inf):
+            metagraph[i, j] = k
+            metagraph[j, i] = k
+
+    return csgraph.minimum_spanning_tree(metagraph).astype(int)
+
+
+def spanning_tree_composite(distance_matrix, which_patch):
+    """Find a spanning tree for a composite system.
+
+    We cannot use a simple minimum-distance spanning tree because
+    we have the additional constraint that all links with at least
+    one end in a lead patch have their gauge fixed. See the notes
+    for details.
+
+    Parameters
+    ----------
+    distance_matrix : COO matrix
+        Distance matrix between connected sites in the extended
+        scattering region.
+    which_patch : callable : int -> int
+        Given a site index in the extended scattering region (see notes),
+        returns the lead patch (see notes) to which the site belongs. Returns
+        -1 if the site is part of the reduced scattering region (see notes).
+    Returns
+    -------
+    spanning_tree : CSR matrix
+        A spanning tree with the same sparsity structure as 'distance_matrix',
+        where missing links are denoted with infinite weights.
+
+    Notes
+    -----
+    Definitions of the terms 'extended scattering region', 'lead patch'
+    and 'reduced scattering region' are given in the notes for
+    `kwant.physics.gauge.loops_in_composite`.
+
+    We cannot use a simple minimum-distance spanning tree because
+    we have the additional constraint that all links with at least
+    one end in a lead patch have their gauge fixed.
+    Consider the following case using a minimum-distance tree
+    where 'x' are sites in the lead patch::
+
+        o-o-x      o-o-x
+        | | |  -->   | |
+        o-o-x      o-o x
+
+    The removed link on the lower right comes from the lead, and hence
+    is gauge-fixed, however the vertical link in the center is not in
+    the lead, but *is* in the tree, which means that we will fix its
+    gauge to 0. The loop on the right would thus not have the correct
+    gauge on all links.
+
+    Instead we first cut all links between *interface* sites and
+    sites in the scattering region (including other interface sites).
+    We then construct a minimum distance forest for these disconnected
+    graphs. Finally we add back links from the ones that were cut,
+    ensuring that we do not form any loops; we do this by contructing
+    a tree of representative links from the "metagraph" of components
+    that were disconnected by the link cutting.
+    """
+    # Links that connect interface sites to other sites in the
+    # scattering region (including other interface sites)
+    links_to_delete = _interior_links(distance_matrix, which_patch)
+    # Make a shortest distance tree for each of the components
+    # obtained by cutting the links.
+    cut_syst = distance_matrix.copy()
+    cut_syst.data[links_to_delete] = np.inf
+    forest = shortest_distance_forest(cut_syst)
+    # Connect the forest back up with representative links until
+    # we have a single tree (if the original system was not connected,
+    # we get a forest).
+    metatree = _make_metatree(distance_matrix, links_to_delete)
+    for k in np.unique(metatree.data):
+        value = distance_matrix.data[k]
+        i, j = distance_matrix.row[k], distance_matrix.col[k]
+        assign_csr(forest, value, (i, j))
+        assign_csr(forest, value, (j, i))
+
+    return forest
+
+
+def check_infinite_syst(syst):
+    r"""Check that the unit cell is a connected graph.
+
+    If the unit cell is not connected then we cannot be sure whether
+    there are loops or not just by inspecting the unit cell graph
+    (this may be a solved problem, but we could not find an algorithm
+    to do this).
+
+    To illustrate this, consider the following unit cell consisting
+    of 3 sites and 4 hoppings::
+
+        o-
+         \
+        o
+         \
+        o-
+
+    None of the sites are connected within the unit cell, however if we repeat
+    a few unit cells::
+
+        o-o-o-o
+         \ \ \
+        o o o o
+         \ \ \
+        o-o-o-o
+
+    we see that there is a loop crossing 4 unit cells. A connected unit cell
+    is a sufficient condition that all the loops can be found by inspecting
+    the graph consisting of two unit cells glued together.
+    """
+    assert isinstance(syst, system.InfiniteSystem)
+    n = syst.cell_size
+    rows, cols = np.array([(i, j) for i, j in syst.graph
+                            if i < n and j < n]).transpose()
+    data = np.ones(len(rows))
+    graph = scipy.sparse.coo_matrix((data, (rows, cols)), shape=(n, n))
+    if csgraph.connected_components(graph, return_labels=False) > 1:
+        raise ValueError(
+            'Infinite system unit cell is not connected: we cannot determine '
+            'if there are any loops in the system\n\n'
+            'If there are, then you must define your unit cell so that it is '
+            'connected. If there are not, then you must add zero-magnitude '
+            'hoppings to your system.'
+        )
+
+
+def check_composite_syst(syst):
+    """Check that we can consistently fix the gauge in a system with leads.
+
+    If not, raise an exception with an informative error message.
+    """
+    assert isinstance(syst, system.FiniteSystem) and syst.leads
+    # Frozenset to quickly check 'is this site in the lead padding?'
+    extras = [frozenset(sites) for sites in syst.lead_paddings]
+    interfaces = [set(iface) for iface in syst.lead_interfaces]
+    # Make interfaces between lead patches and the reduced scattering region.
+    for interface, extra in zip(interfaces, extras):
+        extra_interface = set()
+        if extra:
+            extra_interface = set()
+            for i, j in syst.graph:
+                if i in extra and j not in extra:
+                    extra_interface.add(j)
+            interface -= extra
+            interface |= extra_interface
+        assert not extra.intersection(interface)
+
+    pre_msg = (
+        'Attaching leads results in gauge-fixed loops in the extended '
+        'scattering region (scattering region plus one lead unit cell '
+        'from every lead). This does not allow consistent gauge-fixing.\n\n'
+    )
+    solution_msg = (
+        'To avoid this error, attach leads further away from each other.\n\n'
+        'Note: calling `attach_lead()` with `add_cells > 0` will not fix '
+        'this problem, as the added sites inherit the gauge from the lead. '
+        'To extend the scattering region, you must manually add sites '
+        'making sure that they use the scattering region gauge.'
+    )
+
+    # Check that there is at most one overlapping site between
+    # reduced interface of one lead and extra sites of another
+    num_leads = len(syst.leads)
+    metagraph = scipy.sparse.lil_matrix((num_leads, num_leads))
+    for i, j in permutations(range(num_leads), 2):
+        intersection = len(interfaces[i] & (interfaces[j] | extras[j]))
+        if intersection > 1:
+            raise ValueError(
+                pre_msg
+                + ('There is at least one gauge-fixed loop in the overlap '
+                   'of leads {} and {}.\n\n'.format(i, j))
+                + solution_msg
+            )
+        elif intersection == 1:
+            metagraph[i, j] = 1
+    # Check that there is no loop formed by gauge-fixed bonds of multiple leads.
+    num_components = scipy.sparse.csgraph.connected_components(metagraph, return_labels=False)
+    if metagraph.nnz // 2 + num_components != num_leads:
+        raise ValueError(
+            pre_msg
+            + ('There is at least one gauge-fixed loop formed by more than 2 leads. '
+               ' The connectivity matrix of the leads is:\n\n'
+               '{}\n\n'.format(metagraph.A))
+            + solution_msg
+        )
+
 ### Phase calculation
 
 def calculate_phases(loops, pos, previous_phase, flux):
@@ -335,6 +803,42 @@ def _previous_phase_finite(phases, path):
     return previous_phase
 
 
+def _previous_phase_infinite(cell_size, phases, path):
+    previous_phase = 0
+    for i, j in zip(path, path[1:]):
+        # i and j are only in the fundamental unit cell (0 <= i < cell_size)
+        # or the next one (cell_size <= i < 2 * cell_size).
+        if i >= cell_size and j >= cell_size:
+            assert i // cell_size == j // cell_size
+            i = i % cell_size
+            j = j % cell_size
+        previous_phase += phases.get((i, j), 0)
+        previous_phase -= phases.get((j, i), 0)
+    return previous_phase
+
+
+def _previous_phase_composite(which_patch, extended_sites, lead_phases,
+                              phases, path):
+    previous_phase = 0
+    for i, j in zip(path, path[1:]):
+        patch_i = which_patch(i)
+        patch_j = which_patch(j)
+        if patch_i == -1 and patch_j == -1:
+            # Both sites in reduced scattering region.
+            previous_phase += phases.get((i, j), 0)
+            previous_phase -= phases.get((j, i), 0)
+        else:
+            # At least one site in a lead patch; use the phase from the
+            # associated lead. Check that if both are in a patch, they
+            # are in the same patch.
+            assert patch_i * patch_j <= 0 or patch_i == patch_j
+            patch = max(patch_i, patch_j)
+            a, b = extended_sites(i), extended_sites(j)
+            previous_phase += lead_phases[patch](a, b)
+
+    return previous_phase
+
+
 ### High-level interface
 #
 # These functions glue all the above functionality together.
@@ -348,6 +852,19 @@ def _finite_wrapper(syst, phases, a, b):
     return phases.get((i, j), -phases.get((j, i), 0))
 
 
+def _infinite_wrapper(syst, phases, a, b):
+    sym = syst.symmetry
+    # Bring link to fundamental domain consistently with how
+    # we store the phases.
+    t = max(sym.which(a), sym.which(b))
+    a, b = sym.act(-t, a, b)
+    i = syst.id_by_site[a]
+    j = syst.id_by_site[b]
+    # We only store *either* (i, j) *or* (j, i). If not present
+    # then the phase is zero by definition.
+    return phases.get((i, j), -phases.get((j, i), 0))
+
+
 def _gauge_finite(syst):
     loops = loops_in_finite(syst)
 
@@ -365,6 +882,53 @@ def _gauge_finite(syst):
     return _gauge
 
 
+def _gauge_infinite(syst):
+    loops, extended_sites = loops_in_infinite(syst)
+
+    def _gauge(syst_field, tol=1E-8, average=False):
+        integrate = partial(surface_integral, syst_field,
+                            tol=tol, average=average)
+        phases = calculate_phases(
+            loops,
+            lambda i: extended_sites(i).pos,
+            partial(_previous_phase_infinite, syst.cell_size),
+            integrate,
+        )
+        return partial(_infinite_wrapper, syst, phases)
+
+    return _gauge
+
+
+def _gauge_composite(syst):
+    loops, which_patch, extended_sites = loops_in_composite(syst)
+    lead_gauges = [_gauge_infinite(lead) for lead in syst.leads]
+
+    def _gauge(syst_field, *lead_fields, tol=1E-8, average=False):
+        if len(lead_fields) != len(syst.leads):
+            raise ValueError('Magnetic fields must be provided for all leads.')
+
+        lead_phases = [gauge(B, tol=tol, average=average)
+                       for gauge, B in zip(lead_gauges, lead_fields)]
+
+        flux = partial(surface_integral, syst_field, tol=tol, average=average)
+
+        # NOTE: uses the scattering region magnetic field to set the phase
+        # of the inteface hoppings this choice is somewhat arbitrary,
+        # but it is consistent with the position defined in the scattering
+        # region coordinate system. the integrate functions for the leads
+        # may be defined far from the interface.
+        phases = calculate_phases(
+            loops,
+            lambda i: extended_sites(i).pos,
+            partial(_previous_phase_composite,
+                    which_patch, extended_sites, lead_phases),
+            flux,
+        )
+
+        return (partial(_finite_wrapper, syst, phases), *lead_phases)
+
+    return _gauge
+
 def magnetic_gauge(syst):
     """Fix the magnetic gauge for a finalized system.
 
@@ -377,20 +941,24 @@ def magnetic_gauge(syst):
 
     Parameters
     ----------
-    syst : kwant.builder.FiniteSystem
-        May not have leads attached (this restriction will
-        be lifted in the future).
+    syst : kwant.builder.FiniteSystem or kwant.builder.InfiniteSystem
 
     Returns
     -------
-    gauge : callable
-        When called with a magnetic field as argument, returns
-        another callable 'phase' that returns the Peierls phase to
-        apply to a given hopping.
+    gauge : a callable or sequence thereof
+        If 'syst' is an infinite system, or a finite system without
+        leads, returns a single callable. If 'syst' has leads, then
+        returns a sequence of callables, where the first callable
+        is the gauge of the scattering region, and the rest are
+        the gauges of the leads.
+        When the returned callable(s) is called with a magnetic field
+        as argument, returns another callable 'phase' that takes pairs
+        of sites and returns the Peierls phase to apply to the
+        corresponding hopping.
 
     Examples
     --------
-    The following illustrates basic usage:
+    The following illustrates basic usage for a finite system:
 
     >>> import numpy as np
     >>> import kwant
@@ -408,10 +976,10 @@ def magnetic_gauge(syst):
     """
     if isinstance(syst, builder.FiniteSystem):
         if syst.leads:
-            raise ValueError('Can only fix magnetic gauge for finite systems '
-                             'without leads')
+            return _gauge_composite(syst)
         else:
             return _gauge_finite(syst)
+    elif isinstance(syst, builder.InfiniteSystem):
+        return _gauge_infinite(syst)
     else:
-        raise TypeError('Can only fix magnetic gauge for finite systems '
-                        'without leads')
+        raise TypeError('Expected a finalized Builder')
diff --git a/kwant/physics/tests/test_gauge.py b/kwant/physics/tests/test_gauge.py
index 762c4a0c..c3917eb7 100644
--- a/kwant/physics/tests/test_gauge.py
+++ b/kwant/physics/tests/test_gauge.py
@@ -1,4 +1,5 @@
 from collections import namedtuple, Counter
+import warnings
 from math import sqrt
 import numpy as np
 import pytest
@@ -151,7 +152,7 @@ def translational_symmetry(lat, neighbors):
 ## Tests
 
 # Tests that phase around a loop is equal to the flux through the loop.
-# First we define the loops that we want to test, for various latticeutils.
+# First we define the loops that we want to test, for various lattices.
 # If a system does not support a particular kind of loop, they will simply
 # not be generated.
 
@@ -214,8 +215,8 @@ def _test_phase_loops(syst, phases, loops):
 
 
 @pytest.mark.parametrize("neighbors", [1, 2, 3])
-@pytest.mark.parametrize("symmetry", [no_symmetry],
-                         ids=['finite'])
+@pytest.mark.parametrize("symmetry", [no_symmetry, translational_symmetry],
+                         ids=['finite', 'infinite'])
 @pytest.mark.parametrize("lattice, loops", [square, honeycomb, cubic],
                          ids=['square', 'honeycomb', 'cubic'])
 def test_phases(lattice, neighbors, symmetry, loops):
@@ -235,6 +236,176 @@ def test_phases(lattice, neighbors, symmetry, loops):
     _test_phase_loops(syst, phases, loops)
 
 
+@pytest.mark.parametrize("neighbors", [1, 2, 3])
+@pytest.mark.parametrize("lat, loops", [square, honeycomb],
+                         ids=['square', 'honeycomb'])
+def test_phases_composite(neighbors, lat, loops):
+    """Check that the phases around common loops are equal to the flux, for
+    composite systems with uniform magnetic field.
+    """
+    W = 4
+    dim = len(lat.prim_vecs)
+    field = np.array([0, 0, 1]) if dim == 3 else 1
+
+    lead = Builder(lattice.TranslationalSymmetry(-lat.prim_vecs[0]))
+    lead.fill(model(lat, neighbors), *hypercube(dim, W))
+
+    # Case where extra sites are added and fields are same in
+    # scattering region and lead.
+    syst = Builder()
+    syst.fill(model(lat, neighbors), *ball(dim, W + 1))
+    extra_sites = syst.attach_lead(lead)
+    assert extra_sites  # make sure we're testing the edge case with added sites
+
+    this_gauge = gauge.magnetic_gauge(syst.finalized())
+    # same field in scattering region and lead
+    phases, lead_phases = this_gauge(field, field)
+
+    # When extra sites are added to the central region we need to select
+    # the correct phase function.
+    def combined_phases(a, b):
+        if a in extra_sites or b in extra_sites:
+            return lead_phases(a, b)
+        else:
+            return phases(a, b)
+
+    _test_phase_loops(syst, combined_phases, loops)
+    _test_phase_loops(lead, lead_phases, loops)
+
+
+@pytest.mark.parametrize("neighbors", [1, 2])
+def test_overlapping_interfaces(neighbors):
+    """Test composite systems with overlapping lead interfaces."""
+
+    lat = square_lattice
+
+    def _make_syst(edge, W=5):
+
+        syst = Builder()
+        syst.fill(model(lat, neighbors), *rectangle(W, W))
+
+        leadx = Builder(lattice.TranslationalSymmetry((-1, 0)))
+        leadx[(lat(0, j) for j in range(edge, W - edge))] = None
+        for n in range(1, neighbors + 1):
+            leadx[lat.neighbors(n)] = None
+
+        leady = Builder(lattice.TranslationalSymmetry((0, -1)))
+        leady[(lat(i, 0) for i in range(edge, W - edge))] = None
+        for n in range(1, neighbors + 1):
+            leady[lat.neighbors(n)] = None
+
+        assert not syst.attach_lead(leadx)  # sanity check; no sites added
+        assert not syst.attach_lead(leady)  # sanity check; no sites added
+
+        return syst, leadx, leady
+
+    # edge == 0: lead interfaces overlap
+    # edge == 1: lead interfaces partition scattering region
+    #            into 2 disconnected components
+    for edge in (0, 1):
+        syst, leadx, leady = _make_syst(edge)
+        this_gauge = gauge.magnetic_gauge(syst.finalized())
+        phases, leadx_phases, leady_phases = this_gauge(1, 1, 1)
+        _test_phase_loops(syst, phases, square_loops)
+        _test_phase_loops(leadx, leadx_phases, square_loops)
+        _test_phase_loops(leady, leady_phases, square_loops)
+
+
+def _make_square_syst(sym, neighbors=1):
+    lat = square_lattice
+    syst = Builder(sym)
+    syst[(lat(i, j) for i in (0, 1) for j in (0, 1))] = None
+    for n in range(1, neighbors + 1):
+        syst[lat.neighbors(n)] = None
+    return syst
+
+
+def test_unfixable_gauge():
+    """Check error is raised when we cannot fix the gauge."""
+
+    leadx = _make_square_syst(lattice.TranslationalSymmetry((-1, 0)))
+    leady = _make_square_syst(lattice.TranslationalSymmetry((0, -1)))
+
+    # 1x2 line with 2 leads
+    syst = _make_square_syst(NoSymmetry())
+    del syst[[square_lattice(1, 0), square_lattice(1, 1)]]
+    syst.attach_lead(leadx)
+    syst.attach_lead(leadx.reversed())
+    with pytest.raises(ValueError):
+        gauge.magnetic_gauge(syst.finalized())
+
+    # 2x2 square with leads attached from all 4 sides,
+    # and nearest neighbor hoppings
+    syst = _make_square_syst(NoSymmetry())
+    # Until we add the last lead we have enough gauge freedom
+    # to specify independent fields in the scattering region
+    # and each of the leads. We check that no extra sites are
+    # added as a sanity check.
+    assert not syst.attach_lead(leadx)
+    gauge.magnetic_gauge(syst.finalized())
+    assert not syst.attach_lead(leady)
+    gauge.magnetic_gauge(syst.finalized())
+    assert not syst.attach_lead(leadx.reversed())
+    gauge.magnetic_gauge(syst.finalized())
+    # Adding the last lead removes our gauge freedom.
+    assert not syst.attach_lead(leady.reversed())
+    with pytest.raises(ValueError):
+        gauge.magnetic_gauge(syst.finalized())
+
+    # 2x2 square with 2 leads, but 4rd nearest neighbor hoppings
+    syst = _make_square_syst(NoSymmetry())
+    del syst[(square_lattice(1, 0), square_lattice(1, 1))]
+    leadx = _make_square_syst(lattice.TranslationalSymmetry((-1, 0)))
+    leadx[square_lattice.neighbors(4)] = None
+    for lead in (leadx, leadx.reversed()):
+        syst.attach_lead(lead)
+
+    with pytest.raises(ValueError):
+        gauge.magnetic_gauge(syst.finalized())
+
+
+def _test_disconnected(syst):
+    with pytest.raises(ValueError) as excinfo:
+        gauge.magnetic_gauge(syst.finalized())
+        assert 'unit cell not connected' in str(excinfo.value)
+
+def test_invalid_lead():
+    """Check error is raised when a lead unit cell is not connected
+       within the unit cell itself.
+
+       In order for the algorithm to work we need to be able to close
+       loops within the lead. However we only add a single lead unit
+       cell, so not all paths can be closed, even if the lead is
+       connected.
+    """
+    lat = square_lattice
+
+    lead = _make_square_syst(lattice.TranslationalSymmetry((-1, 0)),
+                             neighbors=0)
+    # Truly disconnected system
+    # Ignore warnings to suppress Kwant's complaint about disconnected lead
+    with warnings.catch_warnings():
+        warnings.simplefilter('ignore')
+        _test_disconnected(lead)
+
+    # 2 disconnected chains (diagonal)
+    lead[(lat(0, 0), lat(1, 1))] = None
+    lead[(lat(0, 1), lat(1, 0))] = None
+    _test_disconnected(lead)
+
+    lead = _make_square_syst(lattice.TranslationalSymmetry((-1, 0)),
+                             neighbors=0)
+    # 2 disconnected chains (horizontal)
+    lead[(lat(0, 0), lat(1, 0))] = None
+    lead[(lat(0, 1), lat(1, 1))] = None
+    _test_disconnected(lead)
+
+    # System has loops, but need 3 unit cells
+    # to express them.
+    lead[(lat(0, 0), lat(1, 1))] = None
+    lead[(lat(0, 1), lat(1, 0))] = None
+    _test_disconnected(lead)
+
 # Test internal parts of magnetic_gauge
 
 @pytest.mark.parametrize("shape",
-- 
GitLab