From 474fe4b7e6b05db1ad87af5cdaa3cfe97ca8d1ab Mon Sep 17 00:00:00 2001 From: Christoph Groth <christoph.groth@cea.fr> Date: Mon, 26 Nov 2012 15:21:54 +0100 Subject: [PATCH] implement calculation of wave functions inside the scattering region --- doc/source/whatsnew/0.2.rst | 6 +++++ kwant/solvers/common.py | 42 +++++++++++++++++++++++++++++ kwant/solvers/mumps.py | 3 ++- kwant/solvers/sparse.py | 3 ++- kwant/solvers/tests/_test_sparse.py | 36 +++++++++++++++++++++++++ kwant/solvers/tests/test_mumps.py | 10 ++++++- kwant/solvers/tests/test_sparse.py | 7 ++++- 7 files changed, 103 insertions(+), 4 deletions(-) diff --git a/doc/source/whatsnew/0.2.rst b/doc/source/whatsnew/0.2.rst index a7a90b0a..ae7465fa 100644 --- a/doc/source/whatsnew/0.2.rst +++ b/doc/source/whatsnew/0.2.rst @@ -77,6 +77,12 @@ Calculation of the local density of states The new function of sparse solvers `~kwant.solvers.common.SparseSolver.ldos` allows the calculation of the local density of states. +Calculation of wave functions in the scattering region +------------------------------------------------------ +The new function of sparse solvers +`~kwant.solvers.common.SparseSolver.wave_func` allows the calculation of the +wave function in the scattering region due to any mode of any lead. + Return value of sparse solver ----------------------------- The function `~kwant.solvers.common.SparseSolver.solve` of sparse solvers now diff --git a/kwant/solvers/common.py b/kwant/solvers/common.py index 2b4662bf..b829edf7 100644 --- a/kwant/solvers/common.py +++ b/kwant/solvers/common.py @@ -377,6 +377,48 @@ class SparseSolver(object): return ldos * (0.5 / np.pi) + def wave_func(self, sys, energy=0): + """ + Return a callable object for the computation of the wave function + inside the scattering region. + + Parameters + ---------- + sys : `kwant.system.FiniteSystem` + The low level system for which the wave functions are to be + calculated. + + Notes + ----- + The returned object can be itself called like a function. Given a lead + number, it returns a 2d NumPy array containing the wave function within + the scattering region due to each mode of the given lead. Index 0 is + the mode number, index 1 is the orbital number. + """ + return WaveFunc(self, sys, energy) + + +class WaveFunc(object): + def __init__(self, solver, sys, energy=0): + for lead in sys.leads: + if not isinstance(lead, system.InfiniteSystem): + # TODO: fix this + msg = 'All leads must be tight binding systems.' + raise ValueError(msg) + (h, self.rhs, kept_vars), lead_info = \ + solver._make_linear_sys(sys, [], xrange(len(sys.leads)), energy) + Modes = physics.Modes + num_extra_vars = sum(li.vecs.shape[1] - li.nmodes + for li in lead_info if isinstance(li, Modes)) + self.solver = solver + self.num_orb = h.shape[0] - num_extra_vars + self.factorized_h = solver._factorized(h) + + def __call__(self, lead): + result = self.solver._solve_linear_sys( + self.factorized_h, [self.rhs[lead]], slice(self.num_orb)) + return result.transpose() + class BlockResult(namedtuple('BlockResultTuple', ['data', 'lead_info'])): """ diff --git a/kwant/solvers/mumps.py b/kwant/solvers/mumps.py index f2fabe46..e290d755 100644 --- a/kwant/solvers/mumps.py +++ b/kwant/solvers/mumps.py @@ -17,7 +17,7 @@ control options that may affect performance: For more details see `~Solver.options`. """ -__all__ = ['solve', 'ldos', 'options', 'Solver'] +__all__ = ['solve', 'ldos', 'wave_func', 'options', 'Solver'] import numpy as np import scipy.sparse as sp @@ -179,5 +179,6 @@ default_solver = Solver() solve = default_solver.solve ldos = default_solver.ldos +wave_func = default_solver.wave_func options = default_solver.options reset_options = default_solver.reset_options diff --git a/kwant/solvers/sparse.py b/kwant/solvers/sparse.py index 1abff83c..9d6dfa4f 100644 --- a/kwant/solvers/sparse.py +++ b/kwant/solvers/sparse.py @@ -11,7 +11,7 @@ variable `uses_umfpack` can be checked to determine if UMFPACK is being used. sparse solver framework. """ -__all__ = ['solve', 'ldos', 'Solver'] +__all__ = ['solve', 'ldos', 'wave_func', 'Solver'] import warnings import numpy as np @@ -163,3 +163,4 @@ default_solver = Solver() solve = default_solver.solve ldos = default_solver.ldos +wave_func = default_solver.wave_func diff --git a/kwant/solvers/tests/_test_sparse.py b/kwant/solvers/tests/_test_sparse.py index 2059edf3..94527fdb 100644 --- a/kwant/solvers/tests/_test_sparse.py +++ b/kwant/solvers/tests/_test_sparse.py @@ -365,3 +365,39 @@ def test_ldos(ldos): fsys = sys.finalized() assert_almost_equal(ldos(fsys, 0), np.array([1, 1]) / (2 * np.pi)) + + +def test_wavefunc_ldos_consistency(wave_func, ldos): + L = 2 + W = 3 + energy = 0 + + np.random.seed(31) + sys = kwant.Builder() + left_lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 0))) + top_lead = kwant.Builder(kwant.TranslationalSymmetry((1, 0))) + for b, sites in [(sys, [square(x, y) + for x in range(L) for y in range(W)]), + (left_lead, [square(0, y) for y in range(W)]), + (top_lead, [square(x, 0) for x in range(L)])]: + for site in sites: + h = np.random.rand(n, n) + 1j * np.random.rand(n, n) + h += h.conjugate().transpose() + b[site] = h + for kind in square.nearest: + for hop in b.possible_hoppings(*kind): + b[hop] = 10 * np.random.rand(n, n) + 1j * np.random.rand(n, n) + sys.attach_lead(left_lead) + sys.attach_lead(top_lead) + sys = sys.finalized() + + wf = wave_func(sys, energy) + ldos2 = np.zeros(wf.num_orb, float) + for lead in xrange(len(sys.leads)): + temp = abs(wf(lead)) + temp **= 2 + print type(temp), temp.shape + ldos2 += temp.sum(axis=0) + ldos2 *= (0.5 / np.pi) + + assert_almost_equal(ldos2, ldos(sys, energy)) diff --git a/kwant/solvers/tests/test_mumps.py b/kwant/solvers/tests/test_mumps.py index 78cd8f18..81529592 100644 --- a/kwant/solvers/tests/test_mumps.py +++ b/kwant/solvers/tests/test_mumps.py @@ -1,7 +1,8 @@ from nose.plugins.skip import Skip, SkipTest from numpy.testing.decorators import skipif try: - from kwant.solvers.mumps import solve, ldos, options, reset_options + from kwant.solvers.mumps import \ + solve, ldos, wave_func, options, reset_options import _test_sparse _no_mumps = False except ImportError: @@ -101,3 +102,10 @@ def test_ldos(): reset_options() options(**opts) _test_sparse.test_ldos(ldos) + + +@skipif(_no_mumps) +def test_wavefunc_ldos_consistency(): + for opts in opt_list: + options(**opts) + _test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos) diff --git a/kwant/solvers/tests/test_sparse.py b/kwant/solvers/tests/test_sparse.py index 9e8080d8..9a5c9083 100644 --- a/kwant/solvers/tests/test_sparse.py +++ b/kwant/solvers/tests/test_sparse.py @@ -1,5 +1,5 @@ from nose.plugins.skip import Skip, SkipTest -from kwant.solvers.sparse import solve, ldos +from kwant.solvers.sparse import solve, ldos, wave_func import kwant.solvers.sparse import _test_sparse @@ -50,5 +50,10 @@ def test_umfpack_del(): else: raise SkipTest + def test_ldos(): _test_sparse.test_ldos(ldos) + + +def test_wavefunc_ldos_consistency(): + _test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos) -- GitLab