Commit 5a7fd483 authored by Kloss's avatar Kloss
Browse files

update default argument onebody solver

parent bc63427c
...@@ -612,8 +612,10 @@ example for ``manybody.State()``: ...@@ -612,8 +612,10 @@ example for ``manybody.State()``:
.. jupyter-execute:: .. jupyter-execute::
solver_type = functools.partial(tkwant.onebody.solvers.default, rtol=1E-5) solver_type = functools.partial(tkwant.onebody.solvers.default, rtol=1E-5)
onebody_wavefunction = functools.partial(tkwant.onebody.WaveFunction.from_kwant, solver_type=solver_type) onebody_wavefunction_type = functools.partial(tkwant.onebody.WaveFunction.from_kwant, solver_type=solver_type)
state = manybody.State(syst, tmax=10, onebody_wavefunction_type=onebody_wavefunction) scattering_state_type = functools.partial(tkwant.onebody.ScatteringStates,
wavefunction_type=onebody_wavefunction_type)
state = manybody.State(syst, tmax=10, scattering_state_type=scattering_state_type)
A similar strategy is possible to change the onebody kernels A similar strategy is possible to change the onebody kernels
``onebody.kernels`` that evaluate the right-hand-side of the one-body ``onebody.kernels`` that evaluate the right-hand-side of the one-body
......
[pytest] [pytest]
# add coverage report, source analyzer, and pep8 compliance checks # add coverage report, source analyzer, and pep8 compliance checks
addopts = --cov=tkwant --cov-config=.coveragerc --flakes --pep8 addopts = --cov=tkwant --cov-config=.coveragerc --flake8
testpaths = tkwant testpaths = tkwant
pep8ignore = flake8-ignore =
E266 # multiple comment characters E266 # multiple comment characters
E501 # lines too long E501 # lines too long
E402 # module level import not at top of file E402 # module level import not at top of file
W503 # line break before binary operator
W504 # line break after binary operator
tkwant/line_segment.py ALL # ignore third party code tkwant/line_segment.py ALL # ignore third party code
flakes-ignore = tkwant/__init__.py F401 # __version__ import confuse flakes
tkwant/onebody/__init__.py UndefinedName # programatic imports confuse flakes tkwant/manybody.py E722 # do not use bare 'except'
tkwant/line_segment.py ALL # ignore third party code tkwant/onebody/__init__.py F821 # programatic imports confuse flakes
markers = markers =
integtest: marks tests as (slow) integration tests (run them with '--integtest') integtest: marks tests as (slow) integration tests (run them with '--integtest')
mpitest: marks tests as mpi tests (run them with '--mpitest') mpitest: marks tests as mpi tests (run them with '--mpitest')
......
...@@ -36,5 +36,5 @@ del module # remove cruft from namespace ...@@ -36,5 +36,5 @@ del module # remove cruft from namespace
def test(verbose=True): def test(verbose=True):
"""Run tkwant's unit tests.""" """Run tkwant's unit tests."""
import pytest import pytest
return pytest.main([os.path.dirname(os.path.abspath(__file__)), '-s'] + return pytest.main([os.path.dirname(os.path.abspath(__file__)), '-s']
(['-v'] if verbose else [])) + (['-v'] if verbose else []))
...@@ -14,7 +14,11 @@ import inspect ...@@ -14,7 +14,11 @@ import inspect
import numpy as np import numpy as np
__all__ = ['version', 'TkwantDeprecationWarning', 'is_type', 'is_type_array', __all__ = ['version', 'TkwantDeprecationWarning', 'is_type', 'is_type_array',
'is_not_empty', 'is_zero'] 'is_not_empty', 'is_zero', 'time_start', 'time_name']
# tkwant's default initial time and time argument name
time_start = 0
time_name = 'time'
package_root = os.path.dirname(os.path.realpath(__file__)) package_root = os.path.dirname(os.path.realpath(__file__))
distr_root = os.path.dirname(package_root) distr_root = os.path.dirname(package_root)
......
...@@ -145,6 +145,7 @@ def log_func(logger, funcname=''): ...@@ -145,6 +145,7 @@ def log_func(logger, funcname=''):
return wrapper return wrapper
return decorate return decorate
# two predefined handlers for tkwant # two predefined handlers for tkwant
"""logging handler with format: level:module-name:line-number:MPI-rank: message""" """logging handler with format: level:module-name:line-number:MPI-rank: message"""
......
...@@ -36,12 +36,7 @@ __all__ = ['add_voltage', 'SimpleBoundary', 'MonomialAbsorbingBoundary', ...@@ -36,12 +36,7 @@ __all__ = ['add_voltage', 'SimpleBoundary', 'MonomialAbsorbingBoundary',
# set module logger # set module logger
def _get_mpi_rank(): logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
"""Return the mpi rank of tkwants global communicator as a dict value"""
rank = mpi.get_communicator().rank
return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
log_func = _logging.log_func(logger) log_func = _logging.log_func(logger)
...@@ -256,8 +251,8 @@ class SimpleBoundary(BoundaryBase): ...@@ -256,8 +251,8 @@ class SimpleBoundary(BoundaryBase):
tmax = self.tmax or self.num_total_cells / max_velocity tmax = self.tmax or self.num_total_cells / max_velocity
# add 1 to prevent off-by-one error due to inter-cell hoppings # add 1 to prevent off-by-one error due to inter-cell hoppings
num_cells = (self.num_total_cells or num_cells = (self.num_total_cells
int(np.ceil(self.tmax * max_velocity)) + 1) or int(np.ceil(self.tmax * max_velocity)) + 1)
assert tmax is not None assert tmax is not None
......
...@@ -543,7 +543,7 @@ def _gauss_from_coefficients_numpy(alpha, beta): ...@@ -543,7 +543,7 @@ def _gauss_from_coefficients_numpy(alpha, beta):
beta = beta.astype(numpy.float64) beta = beta.astype(numpy.float64)
x, V = eig_banded(numpy.vstack((numpy.sqrt(beta), alpha)), lower=False) x, V = eig_banded(numpy.vstack((numpy.sqrt(beta), alpha)), lower=False)
w = beta[0] * scipy.real(scipy.power(V[0, :], 2)) w = beta[0] * numpy.real(numpy.lib.scimath.power(V[0, :], 2))
# eigh_tridiagonal is only available from scipy 1.0.0, and has problems # eigh_tridiagonal is only available from scipy 1.0.0, and has problems
# with precision. TODO find out how/why/what # with precision. TODO find out how/why/what
# try: # try:
......
...@@ -30,12 +30,7 @@ __all__ = ['Occupation', 'Interval', 'lead_occupation', ...@@ -30,12 +30,7 @@ __all__ = ['Occupation', 'Interval', 'lead_occupation',
# set module logger # set module logger
def _get_mpi_rank(): logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
"""Return the mpi rank of tkwants global communicator as a dict value"""
rank = mpi.get_communicator().rank
return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
log_func = _logging.log_func(logger) log_func = _logging.log_func(logger)
# TODO: reintroduce dataclasses at some point when python 3.7 is stable # TODO: reintroduce dataclasses at some point when python 3.7 is stable
...@@ -211,11 +206,11 @@ def lead_occupation(chemical_potential=0, temperature=0, energy_range=None, ...@@ -211,11 +206,11 @@ def lead_occupation(chemical_potential=0, temperature=0, energy_range=None,
def _check_energies(emin, emax): def _check_energies(emin, emax):
if emin is not None and not _common.is_type(emin, 'real_number'): if emin is not None and not _common.is_type(emin, 'real_number'):
raise TypeError('emin={} in energy_range not a real number.' raise TypeError('emin={} in energy_range not a real number.'
.format(emin)) .format(emin))
if emax is not None and not _common.is_type(emax, 'real_number'): if emax is not None and not _common.is_type(emax, 'real_number'):
raise TypeError('emax={} in energy_range not a real number.' raise TypeError('emax={} in energy_range not a real number.'
.format(emax)) .format(emax))
if emin is not None and emax is not None: if emin is not None and emax is not None:
if emax < emin: if emax < emin:
raise ValueError('emin={} > emax={}'.format(emin, emax)) raise ValueError('emin={} > emax={}'.format(emin, emax))
...@@ -866,7 +861,7 @@ def _calc_modes_and_weights(interval, distribution, spectrum): ...@@ -866,7 +861,7 @@ def _calc_modes_and_weights(interval, distribution, spectrum):
@log_func @log_func
def calc_initial_state(syst, tasks, boundaries, params=None, def calc_initial_state(syst, tasks, boundaries, params=None,
onebody_wavefunction_type=onebody.WaveFunction.from_kwant, scattering_state_type=onebody.ScatteringStates,
mpi_distribute=mpi.round_robin, comm=None): mpi_distribute=mpi.round_robin, comm=None):
"""Calculate the initial manybody scattering wave function using MPI. """Calculate the initial manybody scattering wave function using MPI.
...@@ -888,8 +883,9 @@ def calc_initial_state(syst, tasks, boundaries, params=None, ...@@ -888,8 +883,9 @@ def calc_initial_state(syst, tasks, boundaries, params=None,
params : dict, optional params : dict, optional
Extra arguments to pass to the Hamiltonian of ``syst``, Extra arguments to pass to the Hamiltonian of ``syst``,
excluding time. excluding time.
onebody_wavefunction_type : `tkwant.onebody.WaveFunction`, optional scattering_state_type : `tkwant.onebody.ScatteringStates`, optional
Class to evolve a single-particle wavefunction in time. Class to calculate time-dependent onebody wavefunctions starting in an
equilibrium scattering state.
mpi_distribute : callable, optional mpi_distribute : callable, optional
Function to distribute the tasks dict keys over all MPI ranks. Function to distribute the tasks dict keys over all MPI ranks.
By default, keys must be integer and are distributed round-robin like. By default, keys must be integer and are distributed round-robin like.
...@@ -909,9 +905,8 @@ def calc_initial_state(syst, tasks, boundaries, params=None, ...@@ -909,9 +905,8 @@ def calc_initial_state(syst, tasks, boundaries, params=None,
"""Calculate a one-body scattering state that can be evolved in time""" """Calculate a one-body scattering state that can be evolved in time"""
logger.debug('calc scattering state: energy={}, lead={}'.format(energy, lead)) logger.debug('calc scattering state: energy={}, lead={}'.format(energy, lead))
try: try:
return onebody.ScatteringStates(syst, energy, lead, return scattering_state_type(syst, energy, lead, params=params,
params=params, boundaries=boundaries, boundaries=boundaries)
wavefunction_type=onebody_wavefunction_type)
except Exception: except Exception:
raise RuntimeError('scattering state calculation failed for ' raise RuntimeError('scattering state calculation failed for '
'energy={}, lead={}'.format(energy, lead)) 'energy={}, lead={}'.format(energy, lead))
...@@ -1303,7 +1298,7 @@ class State: ...@@ -1303,7 +1298,7 @@ class State:
def __init__(self, syst, tmax=None, occupations=None, params=None, def __init__(self, syst, tmax=None, occupations=None, params=None,
spectra=None, boundaries=None, intervals=Interval, spectra=None, boundaries=None, intervals=Interval,
refine=True, combine=False, error_estimate_operator=None, refine=True, combine=False, error_estimate_operator=None,
onebody_wavefunction_type=onebody.WaveFunction.from_kwant, scattering_state_type=onebody.ScatteringStates,
manybody_wavefunction_type=WaveFunction, manybody_wavefunction_type=WaveFunction,
mpi_distribute=mpi.round_robin, comm=None): mpi_distribute=mpi.round_robin, comm=None):
r""" r"""
...@@ -1352,7 +1347,9 @@ class State: ...@@ -1352,7 +1347,9 @@ class State:
excluding time. excluding time.
spectra : sequence of `~kwant_spectrum.spectrum`, optional spectra : sequence of `~kwant_spectrum.spectrum`, optional
Energy dispersion :math:`E_n(k)` for the leads. Must have Energy dispersion :math:`E_n(k)` for the leads. Must have
the same length as ``syst.leads``. the same length as ``syst.leads``. Required only if
no ``boundaries`` are provided. If needed but not present,
it will be calculated on the fly from `syst.leads`.
boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional
The boundary conditions for each lead attached to ``syst``. The boundary conditions for each lead attached to ``syst``.
Must have the same length as ``syst.leads``. Must have the same length as ``syst.leads``.
...@@ -1387,8 +1384,11 @@ class State: ...@@ -1387,8 +1384,11 @@ class State:
Observable used for the quadrature error estimate. Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`. Must have the calling signature of `kwant.operator`.
Default: Error estimate with density expectation value. Default: Error estimate with density expectation value.
onebody_wavefunction_type : `tkwant.onebody.WaveFunction`, optional scattering_state_type : `tkwant.onebody.ScatteringStates`, optional
Class to evolve a single-particle wavefunction in time. Class to calculate time-dependent onebody wavefunctions starting in
an equilibrium scattering state. Name of the time argument and
initial time are taken from this class. If this is not possible,
default values are used as a fallback.
manybody_wavefunction_type : `tkwant.manybody.WaveFunction`, optional manybody_wavefunction_type : `tkwant.manybody.WaveFunction`, optional
Class to evolve a many-particle wavefunction in time. Class to evolve a many-particle wavefunction in time.
mpi_distribute : callable, optional mpi_distribute : callable, optional
...@@ -1412,11 +1412,22 @@ class State: ...@@ -1412,11 +1412,22 @@ class State:
if tmax is not None and boundaries is not None: if tmax is not None and boundaries is not None:
raise ValueError("'boundaries' and 'tmax' are mutually exclusive.") raise ValueError("'boundaries' and 'tmax' are mutually exclusive.")
# get initial time and time argument name from WaveFunction # get initial time and time argument name from the onebody wavefunction
time_name = _common.get_default_function_argument(onebody_wavefunction_type, try:
'time_name') default_arg = _common.get_default_function_argument
time_start = _common.get_default_function_argument(onebody_wavefunction_type, onebody_wavefunction_type = default_arg(scattering_state_type,
'time_start') 'wavefunction_type')
time_name = default_arg(onebody_wavefunction_type, 'time_name')
time_start = default_arg(onebody_wavefunction_type, 'time_start')
except Exception:
time_name = _common.time_name
time_start = _common.time_start
onebody_wavefunction_type = None
logger.warning('retrieving initial time and time argument name from',
'the onebody wavefunction failed, use default values: ',
'"time_name"={}, "time_start"={}'.format(time_name,
time_start))
# add initial time to the params dict # add initial time to the params dict
tparams = add_time_to_params(params, time_name=time_name, tparams = add_time_to_params(params, time_name=time_name,
time=time_start, check_numeric_type=True) time=time_start, check_numeric_type=True)
...@@ -1461,6 +1472,7 @@ class State: ...@@ -1461,6 +1472,7 @@ class State:
self.occupations = occupations self.occupations = occupations
self.mpi_distribute = mpi_distribute self.mpi_distribute = mpi_distribute
self.onebody_wavefunction_type = onebody_wavefunction_type self.onebody_wavefunction_type = onebody_wavefunction_type
self.scattering_state_type = scattering_state_type
# no public params attribute exists for the manybody state. # no public params attribute exists for the manybody state.
# each individual one-body state holds its own parameters # each individual one-body state holds its own parameters
...@@ -1543,8 +1555,8 @@ class State: ...@@ -1543,8 +1555,8 @@ class State:
initial manybody state. initial manybody state.
""" """
return calc_initial_state(self.syst, tasks, self.boundaries, return calc_initial_state(self.syst, tasks, self.boundaries,
self._params, self.onebody_wavefunction_type, self._params, self.scattering_state_type,
self.mpi_distribute, comm) mpi_distribute=self.mpi_distribute, comm=comm)
def _get_keys_from_interval(self, interval): def _get_keys_from_interval(self, interval):
"""Return a list of all keys corresponding to an interval. """Return a list of all keys corresponding to an interval.
...@@ -2188,6 +2200,11 @@ class State: ...@@ -2188,6 +2200,11 @@ class State:
``boundstate_psi`` (all keys must be identical) ``boundstate_psi`` (all keys must be identical)
and must be the same on all MPI ranks. and must be the same on all MPI ranks.
""" """
# if onebody_wavefunction_type cannot be extracted from
# scattering_state_type, this routine fails, do at least some warning
if self.onebody_wavefunction_type is None:
logger.warning("wavefunction type is None, ",
"provided boundstates must be time dependent")
make_boundstates_time_dependent(boundstate_psi, boundstate_tasks, make_boundstates_time_dependent(boundstate_psi, boundstate_tasks,
self.syst, self.boundaries, self.syst, self.boundaries,
self._params, self._params,
......
...@@ -13,16 +13,17 @@ import dill ...@@ -13,16 +13,17 @@ import dill
from . import _logging from . import _logging
__all__ = ['communicator_init', 'communicator_free', 'get_communicator', __all__ = ['communicator_init', 'communicator_free', 'get_communicator',
'distribute_dict', 'DistributedDict', 'round_robin'] 'distribute_dict', 'DistributedDict', 'round_robin', 'mpi_rank']
# set module logger # set module logger
def _get_mpi_rank(): def mpi_rank():
"""Return the mpi rank of tkwants global communicator as a dict value""" """Return the mpi rank of tkwants global communicator as a dict value"""
rank = _COMM.rank rank = get_communicator().rank
return {'rank': rank} return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
logger = _logging.make_logger(name=__name__, info=mpi_rank)
_COMM = None # the global MPI communicator used by tkwant by default. _COMM = None # the global MPI communicator used by tkwant by default.
......
...@@ -16,7 +16,7 @@ import scipy.sparse as sp ...@@ -16,7 +16,7 @@ import scipy.sparse as sp
import kwant import kwant
import kwant_spectrum import kwant_spectrum
from .. import leads, _common from .. import leads, _common, mpi, _logging
from ..system import (extract_perturbation, hamiltonian_with_boundaries, from ..system import (extract_perturbation, hamiltonian_with_boundaries,
add_time_to_params) add_time_to_params)
from . import kernels, solvers from . import kernels, solvers
...@@ -24,8 +24,12 @@ from . import kernels, solvers ...@@ -24,8 +24,12 @@ from . import kernels, solvers
__all__ = ['WaveFunction', 'ScatteringStates', 'Task'] __all__ = ['WaveFunction', 'ScatteringStates', 'Task']
# data formats # set module logger
logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
log_func = _logging.log_func(logger)
# data formats
class Task: class Task:
"""Data format to store the set of quantum numbers that uniquely indentifies """Data format to store the set of quantum numbers that uniquely indentifies
...@@ -117,7 +121,7 @@ class WaveFunction: ...@@ -117,7 +121,7 @@ class WaveFunction:
""" """
def __init__(self, H0, W, psi_init, energy=None, params=None, def __init__(self, H0, W, psi_init, energy=None, params=None,
solution_is_valid=None, time_is_valid=None, solution_is_valid=None, time_is_valid=None,
time_start=0, time_name='time', time_start=_common.time_start, time_name=_common.time_name,
kernel_type=kernels.default, solver_type=solvers.default): kernel_type=kernels.default, solver_type=solvers.default):
r""" r"""
Parameters Parameters
...@@ -204,8 +208,8 @@ class WaveFunction: ...@@ -204,8 +208,8 @@ class WaveFunction:
self.energy = energy self.energy = energy
@classmethod @classmethod
def from_kwant(cls, syst, psi_init, boundaries=None, energy=None, def from_kwant(cls, syst, psi_init, boundaries=None, energy=None, params=None,
params=None, time_start=0, time_name='time', time_start=_common.time_start, time_name=_common.time_name,
kernel_type=kernels.default, solver_type=solvers.default): kernel_type=kernels.default, solver_type=solvers.default):
"""Set up a time-dependent onebody wavefunction from a kwant system. """Set up a time-dependent onebody wavefunction from a kwant system.
...@@ -379,15 +383,20 @@ class ScatteringStates(collections.abc.Iterable): ...@@ -379,15 +383,20 @@ class ScatteringStates(collections.abc.Iterable):
exclusive with 'boundaries'. exclusive with 'boundaries'.
params : dict, optional params : dict, optional
Extra arguments to pass to the Hamiltonian of ``syst``, excluding time. Extra arguments to pass to the Hamiltonian of ``syst``, excluding time.
spectra : sequence of `kwant_spectrum.spectrum` spectra : sequence of `kwant_spectrum.spectrum`, optional
Energy dispersion :math:`E_n(k)` for all leads. Energy dispersion :math:`E_n(k)` for the leads. Must have
the same length as ``syst.leads``. Required only if
no ``boundaries`` are provided. If needed but not present,
it will be calculated on the fly from `syst.leads`.
boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional
The boundary conditions for each lead attached to ``syst``. Mutually The boundary conditions for each lead attached to ``syst``. Mutually
exclusive with 'tmax'. exclusive with 'tmax'.
equilibrium_solver : `kwant.wave_function`, optional equilibrium_solver : `kwant.wave_function`, optional
Solver for initial equilibrium scattering problem. Solver for initial equilibrium scattering problem.
WaveFunction_type : `WaveFunction`, optional wavefunction_type : `WaveFunction`, optional
One-body time-dependent wave function. One-body time-dependent wave function. Name of the time argument and
initial time are taken from this class. If this is not possible,
default values are used as a fallback.
Notes Notes
----- -----
...@@ -410,11 +419,19 @@ class ScatteringStates(collections.abc.Iterable): ...@@ -410,11 +419,19 @@ class ScatteringStates(collections.abc.Iterable):
if lead >= len(syst.leads): if lead >= len(syst.leads):
raise ValueError("lead index must be smaller than {}.".format(len(syst.leads))) raise ValueError("lead index must be smaller than {}.".format(len(syst.leads)))
# get initial time and time argument name from WaveFunction # get initial time and time argument name from the onebody wavefunction
time_name = _common.get_default_function_argument(wavefunction_type, try:
'time_name') time_name = _common.get_default_function_argument(wavefunction_type,
time_start = _common.get_default_function_argument(wavefunction_type, 'time_name')
'time_start') time_start = _common.get_default_function_argument(wavefunction_type,
'time_start')
except Exception:
time_name = _common.time_name
time_start = _common.time_start
logger.warning('retrieving initial time and time argument name from',
'the onebody wavefunction failed, use default values: ',
'"time_name"={}, "time_start"={}'.format(time_name,
time_start))
# add initial time to the params dict # add initial time to the params dict
tparams = add_time_to_params(params, time_name=time_name, tparams = add_time_to_params(params, time_name=time_name,
......
...@@ -83,43 +83,43 @@ def all_solvers(): # test only with the pure python kernel ...@@ -83,43 +83,43 @@ def all_solvers(): # test only with the pure python kernel
def make_simple_lead(lat, N): def make_simple_lead(lat, N):
I = ta.identity(lat.norbs) I0 = ta.identity(lat.norbs)
syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0))) syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
syst[(lat(0, j) for j in range(N))] = 4 * I syst[(lat(0, j) for j in range(N))] = 4 * I0
syst[lat.neighbors()] = -1 * I syst[lat.neighbors()] = -1 * I0
return syst return syst
def make_complex_lead(lat, N): def make_complex_lead(lat, N):
I = ta.identity(lat.norbs) I0 = ta.identity(lat.norbs)
syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0))) syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
syst[(lat(0, j) for j in range(N))] = 4 * I syst[(lat(0, j) for j in range(N))] = 4 * I0
syst[kwant.HoppingKind((0, 1), lat)] = -1 * I syst[kwant.HoppingKind((0, 1), lat)] = -1 * I0
syst[(lat(0, 0), lat(1, 0))] = -1j * I syst[(lat(0, 0), lat(1, 0))] = -1j * I0
syst[(lat(0, 1), lat(1, 0))] = -1 * I syst[(lat(0, 1), lat(1, 0))] = -1 * I0
syst[(lat(0, 2), lat(1, 2))] = (-1 + 1j) * I syst[(lat(0, 2), lat(1, 2))] = (-1 + 1j) * I0
return syst return syst
def make_system(lat, N): def make_system(lat, N):
I = ta.identity(lat.norbs) I0 = ta.identity(lat.norbs)
def random_onsite(site, time, salt): def random_onsite(site, time, salt):
return (4 + kwant.digest.uniform(site.tag, salt=salt)) * I return (4 + kwant.digest.uniform(site.tag, salt=salt)) * I0
syst = kwant.Builder() syst = kwant.Builder()
syst[(lat(i, j) for i in range(N) for j in range(N))] = random_onsite syst[(lat(i, j) for i in range(N) for j in range(N))] = random_onsite
syst[lat.neighbors()] = -1 * I syst[lat.neighbors()] = -1 * I0
return syst return syst
def make_td_system(lat, N, td_onsite): def make_td_system(lat, N, td_onsite):
I = ta.identity(2) I0 = ta.identity(2)
syst = kwant.Builder() syst = kwant.Builder()
square = it.product(range(N), range(N)) square = it.product(range(N), range(N))
syst[(lat(i, j) for i, j in square)] = td_onsite syst[(lat(i, j) for i, j in square)] = td_onsite
syst[lat.neighbors()] = -1 * I syst[lat.neighbors()] = -1 * I0
return syst return syst
...@@ -214,12 +214,12 @@ def test_finite_time_dependent(solver_type, kernel_type): ...@@ -214,12 +214,12 @@ def test_finite_time_dependent(solver_type, kernel_type):
N = 10 N = 10
salt = '1' salt = '1'
lat = kwant.lattice.square(norbs=2) lat = kwant.lattice.square(norbs=2)
I = ta.identity(2) I0 = ta.identity(2)
SX = ta.array([[0, 1], [1, 0]]) SX = ta.array([[0, 1], [1, 0]])
uniform = kwant.digest.uniform uniform = kwant.digest.uniform
def td_onsite(site, time, salt): def td_onsite(site, time, salt):
static_part = (4 + uniform(site.tag, salt=salt)) * I static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = uniform(site.tag, salt=salt + '1') * SX * cos(time) td_part = uniform(site.tag, salt=salt + '1') * SX * cos(time)
return static_part + td_part return static_part + td_part
...@@ -341,12 +341,12 @@ def test_infinite