From 474fe4b7e6b05db1ad87af5cdaa3cfe97ca8d1ab Mon Sep 17 00:00:00 2001
From: Christoph Groth <christoph.groth@cea.fr>
Date: Mon, 26 Nov 2012 15:21:54 +0100
Subject: [PATCH] implement calculation of wave functions inside the scattering
 region

---
 doc/source/whatsnew/0.2.rst         |  6 +++++
 kwant/solvers/common.py             | 42 +++++++++++++++++++++++++++++
 kwant/solvers/mumps.py              |  3 ++-
 kwant/solvers/sparse.py             |  3 ++-
 kwant/solvers/tests/_test_sparse.py | 36 +++++++++++++++++++++++++
 kwant/solvers/tests/test_mumps.py   | 10 ++++++-
 kwant/solvers/tests/test_sparse.py  |  7 ++++-
 7 files changed, 103 insertions(+), 4 deletions(-)

diff --git a/doc/source/whatsnew/0.2.rst b/doc/source/whatsnew/0.2.rst
index a7a90b0a..ae7465fa 100644
--- a/doc/source/whatsnew/0.2.rst
+++ b/doc/source/whatsnew/0.2.rst
@@ -77,6 +77,12 @@ Calculation of the local density of states
 The new function of sparse solvers `~kwant.solvers.common.SparseSolver.ldos`
 allows the calculation of the local density of states.
 
+Calculation of wave functions in the scattering region
+------------------------------------------------------
+The new function of sparse solvers
+`~kwant.solvers.common.SparseSolver.wave_func` allows the calculation of the
+wave function in the scattering region due to any mode of any lead.
+
 Return value of sparse solver
 -----------------------------
 The function `~kwant.solvers.common.SparseSolver.solve` of sparse solvers now
diff --git a/kwant/solvers/common.py b/kwant/solvers/common.py
index 2b4662bf..b829edf7 100644
--- a/kwant/solvers/common.py
+++ b/kwant/solvers/common.py
@@ -377,6 +377,48 @@ class SparseSolver(object):
 
         return ldos * (0.5 / np.pi)
 
+    def wave_func(self, sys, energy=0):
+        """
+        Return a callable object for the computation of the wave function
+        inside the scattering region.
+
+        Parameters
+        ----------
+        sys : `kwant.system.FiniteSystem`
+            The low level system for which the wave functions are to be
+            calculated.
+
+        Notes
+        -----
+        The returned object can be itself called like a function.  Given a lead
+        number, it returns a 2d NumPy array containing the wave function within
+        the scattering region due to each mode of the given lead.  Index 0 is
+        the mode number, index 1 is the orbital number.
+        """
+        return WaveFunc(self, sys, energy)
+
+
+class WaveFunc(object):
+    def __init__(self, solver, sys, energy=0):
+        for lead in sys.leads:
+            if not isinstance(lead, system.InfiniteSystem):
+                # TODO: fix this
+                msg = 'All leads must be tight binding systems.'
+                raise ValueError(msg)
+        (h, self.rhs, kept_vars), lead_info = \
+            solver._make_linear_sys(sys, [], xrange(len(sys.leads)), energy)
+        Modes = physics.Modes
+        num_extra_vars = sum(li.vecs.shape[1] - li.nmodes
+                             for li in lead_info if isinstance(li, Modes))
+        self.solver = solver
+        self.num_orb = h.shape[0] - num_extra_vars
+        self.factorized_h = solver._factorized(h)
+
+    def __call__(self, lead):
+        result = self.solver._solve_linear_sys(
+            self.factorized_h, [self.rhs[lead]], slice(self.num_orb))
+        return result.transpose()
+
 
 class BlockResult(namedtuple('BlockResultTuple', ['data', 'lead_info'])):
     """
diff --git a/kwant/solvers/mumps.py b/kwant/solvers/mumps.py
index f2fabe46..e290d755 100644
--- a/kwant/solvers/mumps.py
+++ b/kwant/solvers/mumps.py
@@ -17,7 +17,7 @@ control options that may affect performance:
 For more details see `~Solver.options`.
 """
 
-__all__ = ['solve', 'ldos', 'options', 'Solver']
+__all__ = ['solve', 'ldos', 'wave_func', 'options', 'Solver']
 
 import numpy as np
 import scipy.sparse as sp
@@ -179,5 +179,6 @@ default_solver = Solver()
 
 solve = default_solver.solve
 ldos = default_solver.ldos
+wave_func = default_solver.wave_func
 options = default_solver.options
 reset_options = default_solver.reset_options
diff --git a/kwant/solvers/sparse.py b/kwant/solvers/sparse.py
index 1abff83c..9d6dfa4f 100644
--- a/kwant/solvers/sparse.py
+++ b/kwant/solvers/sparse.py
@@ -11,7 +11,7 @@ variable `uses_umfpack` can be checked to determine if UMFPACK is being used.
 sparse solver framework.
 """
 
-__all__ = ['solve', 'ldos', 'Solver']
+__all__ = ['solve', 'ldos', 'wave_func', 'Solver']
 
 import warnings
 import numpy as np
@@ -163,3 +163,4 @@ default_solver = Solver()
 
 solve = default_solver.solve
 ldos = default_solver.ldos
+wave_func = default_solver.wave_func
diff --git a/kwant/solvers/tests/_test_sparse.py b/kwant/solvers/tests/_test_sparse.py
index 2059edf3..94527fdb 100644
--- a/kwant/solvers/tests/_test_sparse.py
+++ b/kwant/solvers/tests/_test_sparse.py
@@ -365,3 +365,39 @@ def test_ldos(ldos):
     fsys = sys.finalized()
     assert_almost_equal(ldos(fsys, 0),
                         np.array([1, 1]) / (2 * np.pi))
+
+
+def test_wavefunc_ldos_consistency(wave_func, ldos):
+    L = 2
+    W = 3
+    energy = 0
+
+    np.random.seed(31)
+    sys = kwant.Builder()
+    left_lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
+    top_lead = kwant.Builder(kwant.TranslationalSymmetry((1, 0)))
+    for b, sites in [(sys, [square(x, y)
+                               for x in range(L) for y in range(W)]),
+                     (left_lead, [square(0, y) for y in range(W)]),
+                     (top_lead, [square(x, 0) for x in range(L)])]:
+        for site in sites:
+            h = np.random.rand(n, n) + 1j * np.random.rand(n, n)
+            h += h.conjugate().transpose()
+            b[site] = h
+        for kind in square.nearest:
+            for hop in b.possible_hoppings(*kind):
+                b[hop] = 10 * np.random.rand(n, n) + 1j * np.random.rand(n, n)
+    sys.attach_lead(left_lead)
+    sys.attach_lead(top_lead)
+    sys = sys.finalized()
+
+    wf = wave_func(sys, energy)
+    ldos2 = np.zeros(wf.num_orb, float)
+    for lead in xrange(len(sys.leads)):
+        temp = abs(wf(lead))
+        temp **= 2
+        print type(temp), temp.shape
+        ldos2 += temp.sum(axis=0)
+    ldos2 *= (0.5 / np.pi)
+
+    assert_almost_equal(ldos2, ldos(sys, energy))
diff --git a/kwant/solvers/tests/test_mumps.py b/kwant/solvers/tests/test_mumps.py
index 78cd8f18..81529592 100644
--- a/kwant/solvers/tests/test_mumps.py
+++ b/kwant/solvers/tests/test_mumps.py
@@ -1,7 +1,8 @@
 from nose.plugins.skip import Skip, SkipTest
 from numpy.testing.decorators import skipif
 try:
-    from  kwant.solvers.mumps import solve, ldos, options, reset_options
+    from kwant.solvers.mumps import \
+        solve, ldos, wave_func, options, reset_options
     import _test_sparse
     _no_mumps = False
 except ImportError:
@@ -101,3 +102,10 @@ def test_ldos():
         reset_options()
         options(**opts)
         _test_sparse.test_ldos(ldos)
+
+
+@skipif(_no_mumps)
+def test_wavefunc_ldos_consistency():
+    for opts in opt_list:
+        options(**opts)
+        _test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos)
diff --git a/kwant/solvers/tests/test_sparse.py b/kwant/solvers/tests/test_sparse.py
index 9e8080d8..9a5c9083 100644
--- a/kwant/solvers/tests/test_sparse.py
+++ b/kwant/solvers/tests/test_sparse.py
@@ -1,5 +1,5 @@
 from nose.plugins.skip import Skip, SkipTest
-from  kwant.solvers.sparse import solve, ldos
+from  kwant.solvers.sparse import solve, ldos, wave_func
 import kwant.solvers.sparse
 import _test_sparse
 
@@ -50,5 +50,10 @@ def test_umfpack_del():
     else:
         raise SkipTest
 
+
 def test_ldos():
     _test_sparse.test_ldos(ldos)
+
+
+def test_wavefunc_ldos_consistency():
+    _test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos)
-- 
GitLab