Skip to content
Snippets Groups Projects
Commit 474fe4b7 authored by Christoph Groth's avatar Christoph Groth
Browse files

implement calculation of wave functions inside the scattering region

parent c7f53956
Branches
Tags
No related merge requests found
...@@ -77,6 +77,12 @@ Calculation of the local density of states ...@@ -77,6 +77,12 @@ Calculation of the local density of states
The new function of sparse solvers `~kwant.solvers.common.SparseSolver.ldos` The new function of sparse solvers `~kwant.solvers.common.SparseSolver.ldos`
allows the calculation of the local density of states. allows the calculation of the local density of states.
Calculation of wave functions in the scattering region
------------------------------------------------------
The new function of sparse solvers
`~kwant.solvers.common.SparseSolver.wave_func` allows the calculation of the
wave function in the scattering region due to any mode of any lead.
Return value of sparse solver Return value of sparse solver
----------------------------- -----------------------------
The function `~kwant.solvers.common.SparseSolver.solve` of sparse solvers now The function `~kwant.solvers.common.SparseSolver.solve` of sparse solvers now
......
...@@ -377,6 +377,48 @@ class SparseSolver(object): ...@@ -377,6 +377,48 @@ class SparseSolver(object):
return ldos * (0.5 / np.pi) return ldos * (0.5 / np.pi)
def wave_func(self, sys, energy=0):
"""
Return a callable object for the computation of the wave function
inside the scattering region.
Parameters
----------
sys : `kwant.system.FiniteSystem`
The low level system for which the wave functions are to be
calculated.
Notes
-----
The returned object can be itself called like a function. Given a lead
number, it returns a 2d NumPy array containing the wave function within
the scattering region due to each mode of the given lead. Index 0 is
the mode number, index 1 is the orbital number.
"""
return WaveFunc(self, sys, energy)
class WaveFunc(object):
def __init__(self, solver, sys, energy=0):
for lead in sys.leads:
if not isinstance(lead, system.InfiniteSystem):
# TODO: fix this
msg = 'All leads must be tight binding systems.'
raise ValueError(msg)
(h, self.rhs, kept_vars), lead_info = \
solver._make_linear_sys(sys, [], xrange(len(sys.leads)), energy)
Modes = physics.Modes
num_extra_vars = sum(li.vecs.shape[1] - li.nmodes
for li in lead_info if isinstance(li, Modes))
self.solver = solver
self.num_orb = h.shape[0] - num_extra_vars
self.factorized_h = solver._factorized(h)
def __call__(self, lead):
result = self.solver._solve_linear_sys(
self.factorized_h, [self.rhs[lead]], slice(self.num_orb))
return result.transpose()
class BlockResult(namedtuple('BlockResultTuple', ['data', 'lead_info'])): class BlockResult(namedtuple('BlockResultTuple', ['data', 'lead_info'])):
""" """
......
...@@ -17,7 +17,7 @@ control options that may affect performance: ...@@ -17,7 +17,7 @@ control options that may affect performance:
For more details see `~Solver.options`. For more details see `~Solver.options`.
""" """
__all__ = ['solve', 'ldos', 'options', 'Solver'] __all__ = ['solve', 'ldos', 'wave_func', 'options', 'Solver']
import numpy as np import numpy as np
import scipy.sparse as sp import scipy.sparse as sp
...@@ -179,5 +179,6 @@ default_solver = Solver() ...@@ -179,5 +179,6 @@ default_solver = Solver()
solve = default_solver.solve solve = default_solver.solve
ldos = default_solver.ldos ldos = default_solver.ldos
wave_func = default_solver.wave_func
options = default_solver.options options = default_solver.options
reset_options = default_solver.reset_options reset_options = default_solver.reset_options
...@@ -11,7 +11,7 @@ variable `uses_umfpack` can be checked to determine if UMFPACK is being used. ...@@ -11,7 +11,7 @@ variable `uses_umfpack` can be checked to determine if UMFPACK is being used.
sparse solver framework. sparse solver framework.
""" """
__all__ = ['solve', 'ldos', 'Solver'] __all__ = ['solve', 'ldos', 'wave_func', 'Solver']
import warnings import warnings
import numpy as np import numpy as np
...@@ -163,3 +163,4 @@ default_solver = Solver() ...@@ -163,3 +163,4 @@ default_solver = Solver()
solve = default_solver.solve solve = default_solver.solve
ldos = default_solver.ldos ldos = default_solver.ldos
wave_func = default_solver.wave_func
...@@ -365,3 +365,39 @@ def test_ldos(ldos): ...@@ -365,3 +365,39 @@ def test_ldos(ldos):
fsys = sys.finalized() fsys = sys.finalized()
assert_almost_equal(ldos(fsys, 0), assert_almost_equal(ldos(fsys, 0),
np.array([1, 1]) / (2 * np.pi)) np.array([1, 1]) / (2 * np.pi))
def test_wavefunc_ldos_consistency(wave_func, ldos):
L = 2
W = 3
energy = 0
np.random.seed(31)
sys = kwant.Builder()
left_lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
top_lead = kwant.Builder(kwant.TranslationalSymmetry((1, 0)))
for b, sites in [(sys, [square(x, y)
for x in range(L) for y in range(W)]),
(left_lead, [square(0, y) for y in range(W)]),
(top_lead, [square(x, 0) for x in range(L)])]:
for site in sites:
h = np.random.rand(n, n) + 1j * np.random.rand(n, n)
h += h.conjugate().transpose()
b[site] = h
for kind in square.nearest:
for hop in b.possible_hoppings(*kind):
b[hop] = 10 * np.random.rand(n, n) + 1j * np.random.rand(n, n)
sys.attach_lead(left_lead)
sys.attach_lead(top_lead)
sys = sys.finalized()
wf = wave_func(sys, energy)
ldos2 = np.zeros(wf.num_orb, float)
for lead in xrange(len(sys.leads)):
temp = abs(wf(lead))
temp **= 2
print type(temp), temp.shape
ldos2 += temp.sum(axis=0)
ldos2 *= (0.5 / np.pi)
assert_almost_equal(ldos2, ldos(sys, energy))
from nose.plugins.skip import Skip, SkipTest from nose.plugins.skip import Skip, SkipTest
from numpy.testing.decorators import skipif from numpy.testing.decorators import skipif
try: try:
from kwant.solvers.mumps import solve, ldos, options, reset_options from kwant.solvers.mumps import \
solve, ldos, wave_func, options, reset_options
import _test_sparse import _test_sparse
_no_mumps = False _no_mumps = False
except ImportError: except ImportError:
...@@ -101,3 +102,10 @@ def test_ldos(): ...@@ -101,3 +102,10 @@ def test_ldos():
reset_options() reset_options()
options(**opts) options(**opts)
_test_sparse.test_ldos(ldos) _test_sparse.test_ldos(ldos)
@skipif(_no_mumps)
def test_wavefunc_ldos_consistency():
for opts in opt_list:
options(**opts)
_test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos)
from nose.plugins.skip import Skip, SkipTest from nose.plugins.skip import Skip, SkipTest
from kwant.solvers.sparse import solve, ldos from kwant.solvers.sparse import solve, ldos, wave_func
import kwant.solvers.sparse import kwant.solvers.sparse
import _test_sparse import _test_sparse
...@@ -50,5 +50,10 @@ def test_umfpack_del(): ...@@ -50,5 +50,10 @@ def test_umfpack_del():
else: else:
raise SkipTest raise SkipTest
def test_ldos(): def test_ldos():
_test_sparse.test_ldos(ldos) _test_sparse.test_ldos(ldos)
def test_wavefunc_ldos_consistency():
_test_sparse.test_wavefunc_ldos_consistency(wave_func, ldos)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment