From 54d82c08bdb0ba0ac27814be5252ff5dd9966f26 Mon Sep 17 00:00:00 2001
From: Tomas Rosdahl <torosdahl@gmail.com>
Date: Fri, 2 Dec 2016 16:01:08 +0100
Subject: [PATCH] properly symmetrize modes using the three fundamental
 discrete symmetries

---
 doc/source/pre/whatsnew/1.3.rst   |   9 +
 kwant/physics/leads.py            | 311 +++++++++++++++++++++++++++++-
 kwant/physics/tests/test_leads.py | 269 +++++++++++++++++++++++++-
 3 files changed, 570 insertions(+), 19 deletions(-)

diff --git a/doc/source/pre/whatsnew/1.3.rst b/doc/source/pre/whatsnew/1.3.rst
index 30c13c02..2a343550 100644
--- a/doc/source/pre/whatsnew/1.3.rst
+++ b/doc/source/pre/whatsnew/1.3.rst
@@ -53,3 +53,12 @@ configuration as specified in this file is now more general, allowing to
 modify any build parameter for any of the compiled extensions contained in
 Kwant.  See the :ref:`Installation instructions <build-configuration>` for
 details.
+
+Scattering states with discrete symmetries
+------------------------------------------
+Three discrete symmetries are especially relevant in condensed matter physics,
+namely time-reversal symmetry, particle-hole symmetry and chiral (or
+sublattice) symmetry. When one or more of these symmetries are present,
+and it can be useful to have leads with scattering states that reflect this.
+It is now possible to these discrete symmetries in Kwant, which then uses
+them to construct scattering states that are accordingly symmetric.
\ No newline at end of file
diff --git a/kwant/physics/leads.py b/kwant/physics/leads.py
index 32ebb8e2..aea5f9d3 100644
--- a/kwant/physics/leads.py
+++ b/kwant/physics/leads.py
@@ -402,7 +402,159 @@ def unified_eigenproblem(a, b=None, tol=1e6):
     return ev, select, propselect, vec_gen, ord_schur
 
 
-def make_proper_modes(lmbdainv, psi, extract, tol=1e6):
+def phs_symmetrization(wfs, particle_hole):
+    """Makes the wave functions that have the same velocity at a time-reversal
+    invariant momentum (TRIM) particle-hole symmetric.
+
+    If P is the particle-hole operator and P^2 = 1, then a particle-hole
+    symmetric wave function at a TRIM is an eigenstate of P with eigenvalue 1.
+    If P^2 = -1, wave functions with the same velocity at a TRIM come in pairs.
+    Such a pair is particle-hole symmetric if the wave functions are related by
+    P, i. e. the pair can be expressed as [psi_n, P psi_n] where psi_n is a wave
+    function.
+
+    To ensure proper ordering of modes, this function also returns an array
+    of indices which ensures that particle-hole partners are properly ordered
+    in this subspace of modes. These are later used with np.lexsort to ensure
+    proper ordering.
+
+    Parameters
+    ----------
+    wfs : numpy array
+        A matrix of propagating wave functions at a TRIM that all have the same
+        velocity. The wave functions form the columns of this matrix.
+    particle_hole : numpy array
+        The matrix representation of the unitary part of the particle-hole
+        operator, expressed in the tight binding basis.
+
+    Returns
+    -------
+    new_wfs : numpy array
+        The matrix of particle-hole symmetric wave functions.
+    TRIM_sort: numpy integer array
+        Index array that stores the proper sort order of particle-hole symmetric
+        wave functions in this subspace.
+    """
+
+    def Pdot(mat):
+        """Apply the particle-hole operator to an array. """
+        return particle_hole.dot(mat.conj())
+
+    # P always squares to 1 or -1.
+    P_squared = np.sign(particle_hole[0,:].dot(particle_hole[:,0].conj()))
+    # np.sign returns the same data type as its argument. Make sure
+    # that the comparison with integers is okay.
+    assert P_squared in (-1, 1)
+
+    if P_squared == 1:
+        # Make particle hole eigenstates.
+        # Phase factor ensures they are not numerically close.
+        phases = np.diag([np.exp(1j*np.angle(wf.T.conj().dot(
+                            Pdot(wf)))*0.5) for wf in wfs.T])
+        new_wfs = wfs.dot(phases) + Pdot(wfs.dot(phases))
+        # Orthonormalize the modes using QR on the matrix of eigenstates of P.
+        # So long as the matrix of coefficients R is purely real, any linear
+        # combination of these modes remains an eigenstate of P. From the way
+        # we construct eigenstates of P, the coefficients of R are real.
+        new_wfs, r = la.qr(new_wfs, mode='economic', pivoting=True)[:2]
+        if not np.allclose(r.imag, np.zeros(r.shape)):
+            raise RuntimeError("Numerical instability in finding particle-hole \
+                                symmetric modes.")
+        # If P^2 = 1, there is no need to sort the modes further.
+        TRIM_sort = np.zeros((wfs.shape[1],), dtype=int)
+    elif P_squared == -1:
+        # We start by trying to construct explicit particle-hole partners,
+        # and then using QR to make then orthonormal. If this does not yield
+        # correct particle-hole symmetric TRIM modes, we use a different
+        # algorithm.
+        ### Try the first algorithm ###
+        # Iterate over pairs of wave function vectors and make them
+        # particle-hole partners.
+        new_wfs = np.empty(wfs.shape, dtype=complex)
+        new_wfs[:, ::2] = wfs[:, ::2] + Pdot(wfs[:, 1::2])
+        new_wfs[:, 1::2] = Pdot(wfs[:, ::2]) - wfs[:, 1::2]
+        assert np.allclose(Pdot(new_wfs[:, ::2]), new_wfs[:, 1::2])
+        # Orthonormalize the modes. This usually gives pairs of orthonormal
+        # modes that are particle-hole partners up to a sign.
+        wf_mat = la.qr(new_wfs, mode='economic', pivoting=True)[0]
+        # Pick the correct sign to have particle-hole partners.
+        new_wfs[:, 1::2] = np.array([odd_col if np.allclose(
+            odd_col, Pdot(even_col)) else -odd_col for even_col, odd_col
+            in zip(wf_mat[:, ::2].T, wf_mat[:, 1::2].T)]).T
+        new_wfs[:, ::2] = wf_mat[:, ::2]
+        # At this stage, the modes within this subspace are properly ordered by
+        # particle-hole symmetry. Store the ordering in an array of indices
+        # that is also returned.
+        TRIM_sort = np.array(range(new_wfs.shape[1]))
+        ###############################
+        # We now check if the modes obey particle-hole symmetry. If they do, the
+        # QR factorization was sufficient, and we're done.
+        if not np.allclose(new_wfs[:, 1::2], Pdot(new_wfs[:, ::2])):
+            # If the modes are not particle-hole symmetric, we use a more
+            # expensive algorithm that definitely works.
+            ### Use the second algorithm ###
+            new_wfs = []
+            # Iterate over wave functions to construct
+            # particle-hole partners.
+            # The number of modes. This is always an even number >=2.
+            N_modes = wfs.shape[1]
+            # If there are only two modes in this subspace, they are orthogonal
+            # so we replace the second one with the P applied to the first one.
+            if N_modes == 2:
+                wf = wfs[:,0]
+                # Store psi_n and P psi_n.
+                new_wfs.append(wf)
+                new_wfs.append(Pdot(wf))
+            # If there are more than two modes, iterate over wave functions
+            # and construct their particle-hole partners one by one.
+            else:
+                # We construct pairs of modes that are particle-hole partners.
+                # Need to iterate over all pairs except the final one.
+                iterations = range((N_modes-2)//2)
+                for i in iterations:
+                    # Take a mode psi_n from the basis - the first column
+                    # of the matrix of remaining modes.
+                    wf = wfs[:,0]
+                    # Store psi_n and P psi_n.
+                    new_wfs.append(wf)
+                    P_wf = Pdot(wf)
+                    new_wfs.append(P_wf)
+                    # Remove psi_n and P psi_n from the basis matrix of modes.
+                    # First remove psi_n.
+                    wfs = wfs[:,1:]
+                    # Now we project the remaining modes onto the orthogonal
+                    # complement of P psi_n. Projector:
+                    Projector = wfs.dot(wfs.T.conj()) - \
+                                np.outer(P_wf, P_wf.T.conj())
+                    # After the projection, the mode matrix is rank deficient -
+                    # the span of the column space has dimension one less than
+                    # the number of columns.
+                    wfs = Projector.dot(wfs)
+                    wfs = la.qr(wfs, mode='economic', pivoting=True)[0]
+                    # Remove the redundant column.
+                    wfs = wfs[:, :-1]
+                    # If this is the final iteration, we only have two modes
+                    # left and can construct particle-hole partners without
+                    # the projection.
+                    if i == iterations[-1]:
+                        assert wfs.shape[1] == 2
+                        wf = wfs[:,0]
+                        # Store psi_n and P psi_n.
+                        new_wfs.append(wf)
+                        P_wf = Pdot(wf)
+                        new_wfs.append(P_wf)
+                    assert np.allclose(wfs.T.conj().dot(wfs),
+                                       np.eye(wfs.shape[1]))
+            new_wfs = np.hstack([col.reshape(len(col), 1)/npl.norm(col) for
+                                 col in new_wfs])
+            assert np.allclose(new_wfs[:, 1::2], Pdot(new_wfs[:, ::2]))
+            ############################
+    assert np.allclose(new_wfs.T.conj().dot(new_wfs), np.eye(new_wfs.shape[1]))
+    return new_wfs, TRIM_sort
+
+
+def make_proper_modes(lmbdainv, psi, extract, tol=1e6, *, particle_hole=None,
+                      time_reversal=None, chiral=None):
     """
     Find, normalize and sort the propagating eigenmodes.
 
@@ -418,6 +570,9 @@ def make_proper_modes(lmbdainv, psi, extract, tol=1e6):
     # Array for the velocities.
     velocities = np.empty(nmodes, dtype=float)
 
+    # Array of indices to sort modes at a TRIM by PHS.
+    TRIM_PHS_sort = np.zeros(nmodes, dtype=int)
+
     # Calculate the full wave function in real space.
     full_psi = extract(psi, lmbdainv)
 
@@ -496,29 +651,160 @@ def make_proper_modes(lmbdainv, psi, extract, tol=1e6):
         full_psi[:, indx] = dot(full_psi[:, indx], rot)
         velocities[indx] = vel_vals
 
+        # With particle-hole symmetry, treat TRIMs individually.
+        # Particle-hole conserves velocity.
+        # If P^2 = 1, we can pick modes at a TRIM as particle-hole eigenstates.
+        # If P^2 = -1, a mode at a TRIM and its particle-hole partner are
+        # orthogonal, and we pick modes such that they are related by
+        # particle-hole symmetry.
+
+        # At a TRIM, propagating translation eigenvalues are +1 or -1.
+        if (particle_hole is not None and
+            (np.abs(np.abs(lmbdainv[indx].real) - 1) < eps).all()):
+            assert not len(indx) % 2
+            # Set the eigenvalues to the exact TRIM values.
+            if (np.abs(lmbdainv[indx].real - 1) < eps).all():
+                lmbdainv[indx] = 1
+            else:
+            # Momenta are the negative arguments of the translation eigenvalues,
+            # as computed below using np.angle. np.angle of -1 is pi, so this
+            # assigns k = -pi to modes with translation eigenvalue -1.
+                lmbdainv[indx] = -1
+
+            # Original wave functions
+            orig_wf = full_psi[:, indx]
+
+            # Modes are already sorted by velocity in ascending order, as
+            # returned by la.eigh above. The first half is thus incident,
+            # and the second half outgoing.
+            # Here we work within a subspace of modes with a fixed velocity.
+            # Mostly, this is done to ensure that modes of different velocities
+            # are not mixed when particle-hole partners are constructed for
+            # P^2 = -1. First, we identify which modes have the same velocity.
+            # In each such subspace of modes, we construct wave functions that
+            # are particle-hole partners.
+            vels = velocities[indx]
+            # Velocities are sorted in ascending order. Find the indices of the
+            # last instance of each unique velocity.
+            inds = [ind+1 for ind, vel in enumerate(vels[:-1])
+                    if np.abs(vel-vels[ind+1])>vel_eps]
+            inds = [0] + inds + [len(vels)]
+            inds = zip(inds[:-1], inds[1:])
+            # Now possess an iterator over tuples, where each tuple (i,j)
+            # contains the starting and final indices i and j of a submatrix
+            # of the modes matrix, such that all modes in the submatrix
+            # have the same velocity.
+
+            # Iterate over all submatrices of modes with the same velocity.
+            new_wf = []
+            TRIM_sorts = []
+            for ind_tuple in inds:
+                # Pick out wave functions that have a given velocity
+                wfs = orig_wf[:, slice(*ind_tuple)]
+                # Make particle-hole symmetric modes
+                new_modes, TRIM_sort = phs_symmetrization(wfs, particle_hole)
+                new_wf.append(new_modes)
+                # Store sorting indices of the TRIM modes with the given
+                # velocity.
+                TRIM_sorts.append(TRIM_sort)
+            # Gather into a matrix of modes
+            new_wf = np.hstack(new_wf)
+            # Store the sort order of all modes at the TRIM.
+            # Used later with np.lexsort when the ordering
+            # of modes is done.
+            TRIM_PHS_sort[indx] = np.hstack(TRIM_sorts)
+            # Replace the old modes.
+            full_psi[:, indx] = new_wf
+            # For both cases P^2 = +-1, must rotate wave functions in the
+            # singular value basis. Find the rotation from new basis to old.
+            rot = new_wf.T.conj().dot(orig_wf)
+            # Rotate the wave functions in the singular value basis
+            psi[:, indx] = psi[:, indx].dot(rot.T.conj())
+
+        # Ensure proper usage of chiral symmetry.
+        if chiral is not None and time_reversal is None:
+            out_orig = full_psi[:, indx[len(indx)//2:]]
+            out = chiral.dot(full_psi[:, indx[:len(indx)//2]])
+            rot = out_orig.T.conj().dot(out)
+            full_psi[:, indx[len(indx)//2:]] = out
+            psi[:, indx[len(indx)//2:]] = psi[:, indx[len(indx)//2:]].dot(rot)
+
     if np.any(abs(velocities) < vel_eps):
         raise RuntimeError("Found a mode with zero or close to zero velocity.")
     if 2 * np.sum(velocities < 0) != len(velocities):
         raise RuntimeError("Numbers of left- and right-propagating "
                            "modes differ, possibly due to a numerical "
                            "instability.")
+
     momenta = -np.angle(lmbdainv)
-    order = np.lexsort([velocities, -np.sign(velocities) * momenta,
-                        np.sign(velocities)])
+    # Sort the modes. The modes are sorted first by velocity and momentum,
+    # and finally TRIM modes are properly ordered.
+    order = np.lexsort([TRIM_PHS_sort, velocities,
+                        -np.sign(velocities) * momenta, np.sign(velocities)])
 
-    # TODO: Remove the check once we depende on numpy>=1.8.
+    # TODO: Remove the check once we depend on numpy>=1.8.
     if not len(order):
         order = slice(None)
     velocities = velocities[order]
-    norm = np.sqrt(abs(velocities))
-    full_psi = full_psi[:, order] / norm
-    psi = psi[:, order] / norm
     momenta = momenta[order]
+    full_psi = full_psi[:, order]
+    psi = psi[:, order]
+
+    # Use particle-hole symmetry to relate modes that propagate in the
+    # same direction but at opposite momenta.
+    # Modes are sorted by velocity (first incident, then outgoing).
+    # Incident modes are then sorted by momentum in ascending order,
+    # and outgoing modes in descending order.
+    # Adopted convention is to get modes with negative k (both in and out)
+    # by applying particle-hole operator to modes with positive k.
+    if particle_hole is not None:
+        N = nmodes//2  # Number of incident or outgoing modes.
+        # With particle-hole symmetry, N must be an even number.
+        # Incident modes
+        positive_k = (np.pi - eps > momenta[:N]) * (momenta[:N] > eps)
+        # Original wave functions with negative values of k
+        orig_neg_k = full_psi[:, :N][:, positive_k[::-1]]
+        # For incident modes, ordering of wfs by momentum as returned by kwant
+        # is [-k2, -k1, k1, k2], if k2, k1 > 0 and k2 > k1.
+        # To maintain this ordering with ki and -ki as particle-hole partners,
+        # reverse the order of the product at the end.
+        wf_neg_k = particle_hole.dot((full_psi[:, :N][:, positive_k]).conj())[:, ::-1]
+        rot = orig_neg_k.T.conj().dot(wf_neg_k)
+        full_psi[:, :N][:, positive_k[::-1]] = wf_neg_k
+        psi[:, :N][:, positive_k[::-1]] = psi[:, :N][:, positive_k[::-1]].dot(rot)
+
+        # Outgoing modes
+        positive_k = (np.pi - eps > momenta[N:]) * (momenta[N:] > eps)
+        # Original wave functions with negative values of k
+        orig_neg_k = full_psi[:, N:][:, positive_k[::-1]]
+        # For outgoing modes, ordering of wfs by momentum as returned by kwant
+        # is like [k2, k1, -k1, -k2], if k2, k1 > 0 and k2 > k1.
+        # Reverse order with [::-1] at the end to match momenta of opposite sign.
+        wf_neg_k = particle_hole.dot(full_psi[:, N:][:, positive_k].conj())[:, ::-1]
+        rot = orig_neg_k.T.conj().dot(wf_neg_k)
+        full_psi[:, N:][:, positive_k[::-1]] = wf_neg_k
+        psi[:, N:][:, positive_k[::-1]] = psi[:, N:][:, positive_k[::-1]].dot(rot)
+
+    # Modes are ordered by velocity.
+    # Use time-reversal symmetry to relate modes of opposite velocity.
+    if time_reversal is not None:
+        # Note: within this function, nmodes refers to the total number
+        # of propagating modes, not either left or right movers.
+        out_orig = full_psi[:, nmodes//2:]
+        out = time_reversal.dot(full_psi[:, :nmodes//2].conj())
+        rot = out_orig.T.conj().dot(out)
+        full_psi[:, nmodes//2:] = out
+        psi[:, nmodes//2:] = psi[:, nmodes//2:].dot(rot)
+
+    norm = np.sqrt(abs(velocities))
+    full_psi = full_psi / norm
+    psi = psi / norm
 
     return psi, PropagatingModes(full_psi, velocities, momenta)
 
 
-def modes(h_cell, h_hop, tol=1e6, stabilization=None):
+def modes(h_cell, h_hop, tol=1e6, stabilization=None, *, particle_hole=None,
+          time_reversal=None, chiral=None):
     """Compute the eigendecomposition of a translation operator of a lead.
 
     Parameters
@@ -544,6 +830,12 @@ def modes(h_cell, h_hop, tol=1e6, stabilization=None):
         to the regular one.  If it is `False`, reduction to a regular problem
         is performed if possible.  Selecting the stabilization manually is
         mostly necessary for testing purposes.
+    particle_hole : sparse or dense square matrix
+        The unitary part of the particle-hole symmetry operator.
+    time_reversal : sparse or dense square matrix
+        The unitary part of the time-reversal symmetry operator.
+    chiral : sparse or dense square matrix
+        The chiral symmetry operator.
 
     Returns
     -------
@@ -599,7 +891,8 @@ def modes(h_cell, h_hop, tol=1e6, stabilization=None):
     prop_vecs = vec_gen(propselect)
     # Compute their velocity, and, if necessary, rotate them
     prop_vecs, real_space_data = make_proper_modes(
-        ev[propselect], prop_vecs, extract, tol)
+        ev[propselect], prop_vecs, extract, tol, particle_hole=particle_hole,
+        time_reversal=time_reversal, chiral=chiral)
 
     vecs = np.c_[prop_vecs[n:], evan_vecs[n:]]
     vecslmbdainv = np.c_[prop_vecs[:n], evan_vecs[:n]]
diff --git a/kwant/physics/tests/test_leads.py b/kwant/physics/tests/test_leads.py
index 2d1fd6de..d8527959 100644
--- a/kwant/physics/tests/test_leads.py
+++ b/kwant/physics/tests/test_leads.py
@@ -9,6 +9,8 @@
 
 import numpy as np
 from numpy.testing import assert_almost_equal
+import scipy.linalg as la
+import numpy.linalg as npl
 from kwant.physics import leads
 import kwant
 
@@ -244,19 +246,16 @@ def test_modes_bearded_ribbon():
     assert leads.modes(h, t)[1].nmodes == 8
 
 
-def test_algorithm_equivalence():
-    np.random.seed(400)
-    n = 12
-    h = np.random.randn(n, n) + 1j * np.random.randn(n, n)
-    h += h.T.conj()
-    t = np.random.randn(n, n) + 1j * np.random.randn(n, n)
+def check_equivalence(h, t, n, sym='', particle_hole=None, chiral=None, time_reversal=None):
+    """Compare modes stabilization algorithms for a given Hamiltonian."""
     u, s, vh = np.linalg.svd(t)
     u, v = u * np.sqrt(s), vh.T.conj() * np.sqrt(s)
     prop_vecs = []
     evan_vecs = []
     algos = [None, (True, True), (True, False), (False, True), (False, False)]
     for algo in algos:
-        result = leads.modes(h, t, stabilization=algo)[1]
+        result = leads.modes(h, t, stabilization=algo, chiral=chiral,
+                             particle_hole=particle_hole, time_reversal=time_reversal)[1]
 
         vecs, vecslmbdainv = result.vecs, result.vecslmbdainv
 
@@ -279,12 +278,51 @@ def test_algorithm_equivalence():
         # By a phase
         np.testing.assert_allclose(np.abs(np.sum(vecs/prop_vecs[0],
                                                  axis=0)), vecs.shape[0],
-                                   err_msg=msg.format(algo))
+                                   err_msg=msg.format(algo)+' in symmetry class '+sym)
 
     for vecs, algo in zip(evan_vecs, algos):
         # Evanescent modes must span the same linear space.
-        assert (np.linalg.matrix_rank(np.c_[vecs, evan_vecs[0]], tol=1e-12) ==
-                vecs.shape[1]), msg.format(algo)
+        mat = np.c_[vecs, evan_vecs[0]]
+        # Scale largest singular value to 1 if the array is not empty
+        mat = mat/np.linalg.norm(mat, ord=2)
+        # As a tolerance, take the square root of machine precision times the largest
+        # matrix dimension.
+        tol = np.abs(np.sqrt(max(mat.shape)*np.finfo(mat.dtype).eps))
+        assert (np.linalg.matrix_rank(mat, tol=tol) ==
+                vecs.shape[1]), msg.format(algo)+' in symmetry class '+sym
+
+
+def test_symm_algorithm_equivalence():
+    """Test different stabilization methods in the computation of modes,
+    in the presence and/or absence of the discrete symmetries."""
+    np.random.seed(400)
+    for n in (12, 20, 40, 60):
+        for sym in kwant.rmt.sym_list:
+            # Random onsite and hopping matrices in symmetry class
+            h_cell = kwant.rmt.gaussian(n, sym)
+            # Hopping is an offdiagonal block of a Hamiltonian. We rescale it
+            # to ensure that there are modes at the Fermi level.
+            h_hop = 10 * kwant.rmt.gaussian(2*n, sym)[:n, n:]
+
+            if kwant.rmt.p(sym):
+                p_mat = np.array(kwant.rmt.h_p_matrix[sym])
+                p_mat = np.kron(np.identity(n // len(p_mat)), p_mat)
+            else:
+                p_mat = None
+
+            if kwant.rmt.t(sym):
+                t_mat = np.array(kwant.rmt.h_t_matrix[sym])
+                t_mat = np.kron(np.identity(n // len(t_mat)), t_mat)
+            else:
+                t_mat = None
+
+            if kwant.rmt.c(sym):
+                c_mat = np.kron(np.identity(n // 2), np.diag([1, -1]))
+            else:
+                c_mat = None
+
+            check_equivalence(h_cell, h_hop, n, sym=sym, particle_hole=p_mat,
+                              chiral=c_mat, time_reversal=t_mat)
 
 
 def test_for_all_evs_equal():
@@ -356,3 +394,214 @@ def test_momenta():
     these should not change when the Hamiltonian is scaled."""
     momenta = [make_clean_lead(10, s, s).modes()[0].momenta for s in [1, 1e20]]
     assert_almost_equal(*momenta)
+
+
+def check_PHS(TRIM, moms, velocities, wfs, pmat):
+    """Check PHS of incident or outgoing modes at a TRIM momentum. Input are momenta,
+    velocities and wave functions of incident or outgoing modes. """
+    # Pick out TRIM momenta - in this test, all momenta are either 0 or -pi,
+    # so the boundaries of the interval here are not important.
+    TRIM_moms = (TRIM-0.1 < moms) * (moms < TRIM+0.1)
+    assert np.allclose(moms[TRIM_moms], TRIM)
+    # At a given momentum, incident modes are sorted in ascending order by velocity.
+    # Pick out modes with the same velocity.
+    vels = velocities[TRIM_moms]
+    inds = [ind+1 for ind, vel in enumerate(vels[:-1])
+            if np.abs(vel-vels[ind+1])>1e-8]
+    inds = [0] + inds + [len(vels)]
+    inds = zip(inds[:-1], inds[1:])
+    for ind_tuple in inds:
+        vel_wfs = wfs[:, slice(*ind_tuple)]
+        assert np.allclose(vels[slice(*ind_tuple)], vels[slice(*ind_tuple)][0])
+        assert_almost_equal(vel_wfs[:, 1::2], pmat.dot(vel_wfs[:, ::2].conj()),
+                            err_msg='Particle-hole symmetry broken at a TRIM')
+
+def test_PHS_TRIM_degenerate_ordering():
+    """ Test PHS at a TRIM, both when it squares to 1 and -1.
+
+    Take a Hamiltonian with 3 degenerate bands, the degeneracy of each is given
+    in the tuple dims. The bands have different velocities. All bands intersect
+    zero energy only at k = 0 and at the edge of the BZ, so all momenta are 0
+    or -pi. We thus have multiple TRIM modes, both with the same and different
+    velocities.
+
+    If P^2 = 1, all TRIM modes are eigenmodes of P.
+    If P^2 = -1, TRIM modes come in pairs of particle-hole partners, ordered by
+    a predefined convention."""
+    sy = np.array([[0,-1j],[1j,0]])
+    sz = np.array([[1,0],[0,-1]])
+    ### P squares to 1 ###
+    np.random.seed(42)
+    dims = (4, 10, 20)
+    ts = (1.0, 1.7, 13.8)
+    rand_hop = 1j*(0.1+np.random.rand())
+    hop = la.block_diag(*[t*rand_hop*np.eye(dim) for t, dim in zip(ts, dims)])
+
+    # Particle-hole operator
+    pmat = np.eye(sum(dims))
+    onsite = np.zeros(hop.shape, dtype=complex)
+    prop, stab = leads.modes(onsite, hop, particle_hole=pmat)
+    # All momenta are either 0 or -pi.
+    assert np.all([np.any(ele - np.array([0, -np.pi])) for ele in prop.momenta])
+    # All modes are eigenmodes of P.
+    assert np.all([np.allclose(wf, pmat.dot(wf.conj())) for wf in prop.wave_functions.T])
+    ###########
+
+    ### P squares to -1 ###
+    np.random.seed(1337)
+    dims = (1, 4, 40)
+    ts = (1.0, 1.7, 13.4)
+
+    hop_mat = np.kron(sz, 1j*(0.1+np.random.rand())*np.eye(2))
+    blocks = []
+    for t, dim in zip(ts, dims):
+        blocks += dim*[t*hop_mat]
+    hop = la.block_diag(*blocks)
+    # Particle-hole operator
+    pmat = np.kron(np.eye(sum(dims)), 1j*np.kron(sz, sy))
+
+    # P squares to -1
+    assert np.allclose(pmat.dot(pmat.conj()), -np.eye(pmat.shape[0]))
+    # The Hamiltonian anticommutes with P
+    assert np.allclose(pmat.dot(hop.conj()).dot(npl.inv(pmat)), -hop)
+
+    onsite = np.zeros(hop.shape, dtype=complex)
+    prop, stab = leads.modes(onsite, hop, particle_hole=pmat)
+    # By design, all momenta are either 0 or -pi.
+    assert np.all([np.any(ele - np.array([0, -np.pi])) for ele in prop.momenta])
+
+    wfs = prop.wave_functions
+    momenta = prop.momenta
+    velocities = prop.velocities
+    nmodes = stab.nmodes
+
+    # By design, all modes are at a TRIM here. Each must thus have a particle-hole
+    # partner at the same TRIM and with the same velocity.
+    # Incident modes
+    check_PHS(0, momenta[:nmodes], velocities[:nmodes], wfs[:, :nmodes], pmat)
+    check_PHS(-np.pi, momenta[:nmodes], velocities[:nmodes], wfs[:, :nmodes], pmat)
+    # Outgoing modes
+    check_PHS(0, momenta[nmodes:], velocities[nmodes:], wfs[:, nmodes:], pmat)
+    check_PHS(-np.pi, momenta[nmodes:], velocities[nmodes:], wfs[:, nmodes:], pmat)
+    ###########
+
+
+def test_modes_symmetries():
+    np.random.seed(10)
+    for n in (4, 8, 40, 100):
+        for sym in kwant.rmt.sym_list:
+            # Random onsite and hopping matrices in symmetry class
+            h_cell = kwant.rmt.gaussian(n, sym)
+            # Hopping is an offdiagonal block of a Hamiltonian. We rescale it
+            # to ensure that there are modes at the Fermi level.
+            h_hop = 10 * kwant.rmt.gaussian(2*n, sym)[:n, n:]
+
+            if kwant.rmt.p(sym):
+                p_mat = np.array(kwant.rmt.h_p_matrix[sym])
+                p_mat = np.kron(np.identity(n // len(p_mat)), p_mat)
+            else:
+                p_mat = None
+
+            if kwant.rmt.t(sym):
+                t_mat = np.array(kwant.rmt.h_t_matrix[sym])
+                t_mat = np.kron(np.identity(n // len(t_mat)), t_mat)
+            else:
+                t_mat = None
+
+            if kwant.rmt.c(sym):
+                c_mat = np.kron(np.identity(n // 2), np.diag([1, -1]))
+            else:
+                c_mat = None
+
+            prop_modes, stab_modes = leads.modes(h_cell, h_hop, particle_hole=p_mat,
+                                                 time_reversal=t_mat, chiral=c_mat)
+            wave_functions = prop_modes.wave_functions
+            momenta = prop_modes.momenta
+            nmodes = stab_modes.nmodes
+
+            if t_mat is not None:
+                assert_almost_equal(wave_functions[:, nmodes:],
+                        t_mat.dot(wave_functions[:, :nmodes].conj()),
+                        err_msg='TRS broken in ' + sym)
+
+            if c_mat is not None:
+                assert_almost_equal(wave_functions[:, nmodes:],
+                        c_mat.dot(wave_functions[:, :nmodes:-1]),
+                        err_msg='SLS broken in ' + sym)
+
+            if p_mat is not None:
+                # If P^2 = -1, then P psi(-k) = -psi(k) for k>0, so one must look at
+                # positive and negative momenta separately.
+                # Test positive momenta.
+                in_positive_k = (np.pi > momenta[:nmodes]) * (momenta[:nmodes] > 0)
+                out_positive_k = (np.pi > momenta[nmodes:]) * (momenta[nmodes:] > 0)
+
+                assert_almost_equal(wave_functions[:, :nmodes][:, in_positive_k[::-1]],
+                        p_mat.dot((wave_functions[:, :nmodes][:, in_positive_k][:, ::-1]).conj()),
+                        err_msg='PHS broken in ' + sym)
+                assert_almost_equal(wave_functions[:, nmodes:][:, out_positive_k[::-1]],
+                        p_mat.dot((wave_functions[:, nmodes:][:, out_positive_k][:, ::-1]).conj()),
+                        err_msg='PHS broken in ' + sym)
+
+                # Test negative momenta. Need the sign of P^2 here.
+                p_squared_sign = np.sign(p_mat.dot(p_mat.conj())[0, 0].real)
+                in_neg_k = (-np.pi < momenta[:nmodes]) * (momenta[:nmodes] < 0)
+                out_neg_k = (-np.pi < momenta[nmodes:]) * (momenta[nmodes:] < 0)
+
+                assert_almost_equal(p_squared_sign*wave_functions[:, :nmodes][:, in_neg_k[::-1]],
+                        p_mat.dot((wave_functions[:, :nmodes][:, in_neg_k][:, ::-1]).conj()),
+                        err_msg='PHS broken in ' + sym)
+                assert_almost_equal(p_squared_sign*wave_functions[:, nmodes:][:, out_neg_k[::-1]],
+                        p_mat.dot((wave_functions[:, nmodes:][:, out_neg_k][:, ::-1]).conj()),
+                        err_msg='PHS broken in ' + sym)
+
+
+def test_PHS_TRIM():
+    """Test the function that makes particle-hole symmetric modes at a TRIM. """
+    np.random.seed(10)
+    for n in (4, 8, 16, 40, 100):
+        for sym in kwant.rmt.sym_list:
+            if kwant.rmt.p(sym):
+                p_mat = np.array(kwant.rmt.h_p_matrix[sym])
+                p_mat = np.kron(np.identity(n // len(p_mat)), p_mat)
+                P_squared = 1 if np.all(np.abs(p_mat.conj().dot(p_mat) -
+                                               np.eye(*p_mat.shape)) < 1e-10) else -1
+                if P_squared == 1:
+                    for nmodes in (1, 3, n//4, n//2, n):
+                        # Random matrix of 'modes.' Take part of a unitary matrix to
+                        # ensure that the modes form a basis.
+                        modes = np.random.rand(n, n) + 1j*np.random.rand(n, n)
+                        modes = la.expm(1j*(modes + modes.T.conj()))[:n, :nmodes]
+                        # Ensure modes are particle-hole symmetric and normalized
+                        modes = modes + p_mat.dot(modes.conj())
+                        modes = np.array([col/np.linalg.norm(col) for col in modes.T]).T
+                        # Mix the modes with a random unitary transformation
+                        U = np.random.rand(nmodes, nmodes) + 1j*np.random.rand(nmodes, nmodes)
+                        U = la.expm(1j*(U + U.T.conj()))
+                        modes = modes.dot(U)
+                        # Make the modes PHS symmetric using the method for a TRIM.
+                        phs_modes = leads.phs_symmetrization(modes, p_mat)[0]
+                        assert_almost_equal(phs_modes, p_mat.dot(phs_modes.conj()),
+                                            err_msg='PHS broken at a TRIM in ' + sym)
+                        assert_almost_equal(phs_modes.T.conj().dot(phs_modes), np.eye(phs_modes.shape[1]),
+                                           err_msg='Modes are not orthonormal, TRIM PHS in ' + sym)
+                elif P_squared == -1:
+                    # Need even number of modes =< n
+                    for nmodes in (2, 4, n//2, n):
+                        # Random matrix of 'modes.' Take part of a unitary matrix to
+                        # ensure that the modes form a basis.
+                        modes = np.random.rand(n, n) + 1j*np.random.rand(n, n)
+                        modes = la.expm(1j*(modes + modes.T.conj()))[:n, :nmodes]
+                        # Ensure modes are particle-hole symmetric and orthonormal.
+                        modes[:, nmodes//2:] = p_mat.dot(modes[:, :nmodes//2].conj())
+                        modes = la.qr(modes, mode='economic')[0]
+                        # Mix the modes with a random unitary transformation
+                        U = np.random.rand(nmodes, nmodes) + 1j*np.random.rand(nmodes, nmodes)
+                        U = la.expm(1j*(U + U.T.conj()))
+                        modes = modes.dot(U)
+                        # Make the modes PHS symmetric using the method for a TRIM.
+                        phs_modes = leads.phs_symmetrization(modes, p_mat)[0]
+                        assert_almost_equal(phs_modes[:, 1::2], p_mat.dot(phs_modes[:, ::2].conj()),
+                                            err_msg='PHS broken at a TRIM in ' + sym)
+                        assert_almost_equal(phs_modes.T.conj().dot(phs_modes), np.eye(phs_modes.shape[1]),
+                                           err_msg='Modes are not orthonormal, TRIM PHS in ' + sym)
\ No newline at end of file
-- 
GitLab