Commit 5a7fd483 authored by Kloss's avatar Kloss
Browse files

update default argument onebody solver

parent bc63427c
......@@ -612,8 +612,10 @@ example for ``manybody.State()``:
.. jupyter-execute::
solver_type = functools.partial(tkwant.onebody.solvers.default, rtol=1E-5)
onebody_wavefunction = functools.partial(tkwant.onebody.WaveFunction.from_kwant, solver_type=solver_type)
state = manybody.State(syst, tmax=10, onebody_wavefunction_type=onebody_wavefunction)
onebody_wavefunction_type = functools.partial(tkwant.onebody.WaveFunction.from_kwant, solver_type=solver_type)
scattering_state_type = functools.partial(tkwant.onebody.ScatteringStates,
wavefunction_type=onebody_wavefunction_type)
state = manybody.State(syst, tmax=10, scattering_state_type=scattering_state_type)
A similar strategy is possible to change the onebody kernels
``onebody.kernels`` that evaluate the right-hand-side of the one-body
......
[pytest]
# add coverage report, source analyzer, and pep8 compliance checks
addopts = --cov=tkwant --cov-config=.coveragerc --flakes --pep8
addopts = --cov=tkwant --cov-config=.coveragerc --flake8
testpaths = tkwant
pep8ignore =
flake8-ignore =
E266 # multiple comment characters
E501 # lines too long
E402 # module level import not at top of file
W503 # line break before binary operator
W504 # line break after binary operator
tkwant/line_segment.py ALL # ignore third party code
flakes-ignore =
tkwant/onebody/__init__.py UndefinedName # programatic imports confuse flakes
tkwant/line_segment.py ALL # ignore third party code
tkwant/__init__.py F401 # __version__ import confuse flakes
tkwant/manybody.py E722 # do not use bare 'except'
tkwant/onebody/__init__.py F821 # programatic imports confuse flakes
markers =
integtest: marks tests as (slow) integration tests (run them with '--integtest')
mpitest: marks tests as mpi tests (run them with '--mpitest')
......
......@@ -36,5 +36,5 @@ del module # remove cruft from namespace
def test(verbose=True):
"""Run tkwant's unit tests."""
import pytest
return pytest.main([os.path.dirname(os.path.abspath(__file__)), '-s'] +
(['-v'] if verbose else []))
return pytest.main([os.path.dirname(os.path.abspath(__file__)), '-s']
+ (['-v'] if verbose else []))
......@@ -14,7 +14,11 @@ import inspect
import numpy as np
__all__ = ['version', 'TkwantDeprecationWarning', 'is_type', 'is_type_array',
'is_not_empty', 'is_zero']
'is_not_empty', 'is_zero', 'time_start', 'time_name']
# tkwant's default initial time and time argument name
time_start = 0
time_name = 'time'
package_root = os.path.dirname(os.path.realpath(__file__))
distr_root = os.path.dirname(package_root)
......
......@@ -145,6 +145,7 @@ def log_func(logger, funcname=''):
return wrapper
return decorate
# two predefined handlers for tkwant
"""logging handler with format: level:module-name:line-number:MPI-rank: message"""
......
......@@ -36,12 +36,7 @@ __all__ = ['add_voltage', 'SimpleBoundary', 'MonomialAbsorbingBoundary',
# set module logger
def _get_mpi_rank():
"""Return the mpi rank of tkwants global communicator as a dict value"""
rank = mpi.get_communicator().rank
return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
log_func = _logging.log_func(logger)
......@@ -256,8 +251,8 @@ class SimpleBoundary(BoundaryBase):
tmax = self.tmax or self.num_total_cells / max_velocity
# add 1 to prevent off-by-one error due to inter-cell hoppings
num_cells = (self.num_total_cells or
int(np.ceil(self.tmax * max_velocity)) + 1)
num_cells = (self.num_total_cells
or int(np.ceil(self.tmax * max_velocity)) + 1)
assert tmax is not None
......
......@@ -543,7 +543,7 @@ def _gauss_from_coefficients_numpy(alpha, beta):
beta = beta.astype(numpy.float64)
x, V = eig_banded(numpy.vstack((numpy.sqrt(beta), alpha)), lower=False)
w = beta[0] * scipy.real(scipy.power(V[0, :], 2))
w = beta[0] * numpy.real(numpy.lib.scimath.power(V[0, :], 2))
# eigh_tridiagonal is only available from scipy 1.0.0, and has problems
# with precision. TODO find out how/why/what
# try:
......
......@@ -30,12 +30,7 @@ __all__ = ['Occupation', 'Interval', 'lead_occupation',
# set module logger
def _get_mpi_rank():
"""Return the mpi rank of tkwants global communicator as a dict value"""
rank = mpi.get_communicator().rank
return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
log_func = _logging.log_func(logger)
# TODO: reintroduce dataclasses at some point when python 3.7 is stable
......@@ -211,11 +206,11 @@ def lead_occupation(chemical_potential=0, temperature=0, energy_range=None,
def _check_energies(emin, emax):
if emin is not None and not _common.is_type(emin, 'real_number'):
raise TypeError('emin={} in energy_range not a real number.'
.format(emin))
raise TypeError('emin={} in energy_range not a real number.'
.format(emin))
if emax is not None and not _common.is_type(emax, 'real_number'):
raise TypeError('emax={} in energy_range not a real number.'
.format(emax))
raise TypeError('emax={} in energy_range not a real number.'
.format(emax))
if emin is not None and emax is not None:
if emax < emin:
raise ValueError('emin={} > emax={}'.format(emin, emax))
......@@ -866,7 +861,7 @@ def _calc_modes_and_weights(interval, distribution, spectrum):
@log_func
def calc_initial_state(syst, tasks, boundaries, params=None,
onebody_wavefunction_type=onebody.WaveFunction.from_kwant,
scattering_state_type=onebody.ScatteringStates,
mpi_distribute=mpi.round_robin, comm=None):
"""Calculate the initial manybody scattering wave function using MPI.
......@@ -888,8 +883,9 @@ def calc_initial_state(syst, tasks, boundaries, params=None,
params : dict, optional
Extra arguments to pass to the Hamiltonian of ``syst``,
excluding time.
onebody_wavefunction_type : `tkwant.onebody.WaveFunction`, optional
Class to evolve a single-particle wavefunction in time.
scattering_state_type : `tkwant.onebody.ScatteringStates`, optional
Class to calculate time-dependent onebody wavefunctions starting in an
equilibrium scattering state.
mpi_distribute : callable, optional
Function to distribute the tasks dict keys over all MPI ranks.
By default, keys must be integer and are distributed round-robin like.
......@@ -909,9 +905,8 @@ def calc_initial_state(syst, tasks, boundaries, params=None,
"""Calculate a one-body scattering state that can be evolved in time"""
logger.debug('calc scattering state: energy={}, lead={}'.format(energy, lead))
try:
return onebody.ScatteringStates(syst, energy, lead,
params=params, boundaries=boundaries,
wavefunction_type=onebody_wavefunction_type)
return scattering_state_type(syst, energy, lead, params=params,
boundaries=boundaries)
except Exception:
raise RuntimeError('scattering state calculation failed for '
'energy={}, lead={}'.format(energy, lead))
......@@ -1303,7 +1298,7 @@ class State:
def __init__(self, syst, tmax=None, occupations=None, params=None,
spectra=None, boundaries=None, intervals=Interval,
refine=True, combine=False, error_estimate_operator=None,
onebody_wavefunction_type=onebody.WaveFunction.from_kwant,
scattering_state_type=onebody.ScatteringStates,
manybody_wavefunction_type=WaveFunction,
mpi_distribute=mpi.round_robin, comm=None):
r"""
......@@ -1352,7 +1347,9 @@ class State:
excluding time.
spectra : sequence of `~kwant_spectrum.spectrum`, optional
Energy dispersion :math:`E_n(k)` for the leads. Must have
the same length as ``syst.leads``.
the same length as ``syst.leads``. Required only if
no ``boundaries`` are provided. If needed but not present,
it will be calculated on the fly from `syst.leads`.
boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional
The boundary conditions for each lead attached to ``syst``.
Must have the same length as ``syst.leads``.
......@@ -1387,8 +1384,11 @@ class State:
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Default: Error estimate with density expectation value.
onebody_wavefunction_type : `tkwant.onebody.WaveFunction`, optional
Class to evolve a single-particle wavefunction in time.
scattering_state_type : `tkwant.onebody.ScatteringStates`, optional
Class to calculate time-dependent onebody wavefunctions starting in
an equilibrium scattering state. Name of the time argument and
initial time are taken from this class. If this is not possible,
default values are used as a fallback.
manybody_wavefunction_type : `tkwant.manybody.WaveFunction`, optional
Class to evolve a many-particle wavefunction in time.
mpi_distribute : callable, optional
......@@ -1412,11 +1412,22 @@ class State:
if tmax is not None and boundaries is not None:
raise ValueError("'boundaries' and 'tmax' are mutually exclusive.")
# get initial time and time argument name from WaveFunction
time_name = _common.get_default_function_argument(onebody_wavefunction_type,
'time_name')
time_start = _common.get_default_function_argument(onebody_wavefunction_type,
'time_start')
# get initial time and time argument name from the onebody wavefunction
try:
default_arg = _common.get_default_function_argument
onebody_wavefunction_type = default_arg(scattering_state_type,
'wavefunction_type')
time_name = default_arg(onebody_wavefunction_type, 'time_name')
time_start = default_arg(onebody_wavefunction_type, 'time_start')
except Exception:
time_name = _common.time_name
time_start = _common.time_start
onebody_wavefunction_type = None
logger.warning('retrieving initial time and time argument name from',
'the onebody wavefunction failed, use default values: ',
'"time_name"={}, "time_start"={}'.format(time_name,
time_start))
# add initial time to the params dict
tparams = add_time_to_params(params, time_name=time_name,
time=time_start, check_numeric_type=True)
......@@ -1461,6 +1472,7 @@ class State:
self.occupations = occupations
self.mpi_distribute = mpi_distribute
self.onebody_wavefunction_type = onebody_wavefunction_type
self.scattering_state_type = scattering_state_type
# no public params attribute exists for the manybody state.
# each individual one-body state holds its own parameters
......@@ -1543,8 +1555,8 @@ class State:
initial manybody state.
"""
return calc_initial_state(self.syst, tasks, self.boundaries,
self._params, self.onebody_wavefunction_type,
self.mpi_distribute, comm)
self._params, self.scattering_state_type,
mpi_distribute=self.mpi_distribute, comm=comm)
def _get_keys_from_interval(self, interval):
"""Return a list of all keys corresponding to an interval.
......@@ -2188,6 +2200,11 @@ class State:
``boundstate_psi`` (all keys must be identical)
and must be the same on all MPI ranks.
"""
# if onebody_wavefunction_type cannot be extracted from
# scattering_state_type, this routine fails, do at least some warning
if self.onebody_wavefunction_type is None:
logger.warning("wavefunction type is None, ",
"provided boundstates must be time dependent")
make_boundstates_time_dependent(boundstate_psi, boundstate_tasks,
self.syst, self.boundaries,
self._params,
......
......@@ -13,16 +13,17 @@ import dill
from . import _logging
__all__ = ['communicator_init', 'communicator_free', 'get_communicator',
'distribute_dict', 'DistributedDict', 'round_robin']
'distribute_dict', 'DistributedDict', 'round_robin', 'mpi_rank']
# set module logger
def _get_mpi_rank():
def mpi_rank():
"""Return the mpi rank of tkwants global communicator as a dict value"""
rank = _COMM.rank
rank = get_communicator().rank
return {'rank': rank}
logger = _logging.make_logger(name=__name__, info=_get_mpi_rank)
logger = _logging.make_logger(name=__name__, info=mpi_rank)
_COMM = None # the global MPI communicator used by tkwant by default.
......
......@@ -16,7 +16,7 @@ import scipy.sparse as sp
import kwant
import kwant_spectrum
from .. import leads, _common
from .. import leads, _common, mpi, _logging
from ..system import (extract_perturbation, hamiltonian_with_boundaries,
add_time_to_params)
from . import kernels, solvers
......@@ -24,8 +24,12 @@ from . import kernels, solvers
__all__ = ['WaveFunction', 'ScatteringStates', 'Task']
# data formats
# set module logger
logger = _logging.make_logger(name=__name__, info=mpi.mpi_rank)
log_func = _logging.log_func(logger)
# data formats
class Task:
"""Data format to store the set of quantum numbers that uniquely indentifies
......@@ -117,7 +121,7 @@ class WaveFunction:
"""
def __init__(self, H0, W, psi_init, energy=None, params=None,
solution_is_valid=None, time_is_valid=None,
time_start=0, time_name='time',
time_start=_common.time_start, time_name=_common.time_name,
kernel_type=kernels.default, solver_type=solvers.default):
r"""
Parameters
......@@ -204,8 +208,8 @@ class WaveFunction:
self.energy = energy
@classmethod
def from_kwant(cls, syst, psi_init, boundaries=None, energy=None,
params=None, time_start=0, time_name='time',
def from_kwant(cls, syst, psi_init, boundaries=None, energy=None, params=None,
time_start=_common.time_start, time_name=_common.time_name,
kernel_type=kernels.default, solver_type=solvers.default):
"""Set up a time-dependent onebody wavefunction from a kwant system.
......@@ -379,15 +383,20 @@ class ScatteringStates(collections.abc.Iterable):
exclusive with 'boundaries'.
params : dict, optional
Extra arguments to pass to the Hamiltonian of ``syst``, excluding time.
spectra : sequence of `kwant_spectrum.spectrum`
Energy dispersion :math:`E_n(k)` for all leads.
spectra : sequence of `kwant_spectrum.spectrum`, optional
Energy dispersion :math:`E_n(k)` for the leads. Must have
the same length as ``syst.leads``. Required only if
no ``boundaries`` are provided. If needed but not present,
it will be calculated on the fly from `syst.leads`.
boundaries : sequence of `~tkwant.leads.BoundaryBase`, optional
The boundary conditions for each lead attached to ``syst``. Mutually
exclusive with 'tmax'.
equilibrium_solver : `kwant.wave_function`, optional
Solver for initial equilibrium scattering problem.
WaveFunction_type : `WaveFunction`, optional
One-body time-dependent wave function.
wavefunction_type : `WaveFunction`, optional
One-body time-dependent wave function. Name of the time argument and
initial time are taken from this class. If this is not possible,
default values are used as a fallback.
Notes
-----
......@@ -410,11 +419,19 @@ class ScatteringStates(collections.abc.Iterable):
if lead >= len(syst.leads):
raise ValueError("lead index must be smaller than {}.".format(len(syst.leads)))
# get initial time and time argument name from WaveFunction
time_name = _common.get_default_function_argument(wavefunction_type,
'time_name')
time_start = _common.get_default_function_argument(wavefunction_type,
'time_start')
# get initial time and time argument name from the onebody wavefunction
try:
time_name = _common.get_default_function_argument(wavefunction_type,
'time_name')
time_start = _common.get_default_function_argument(wavefunction_type,
'time_start')
except Exception:
time_name = _common.time_name
time_start = _common.time_start
logger.warning('retrieving initial time and time argument name from',
'the onebody wavefunction failed, use default values: ',
'"time_name"={}, "time_start"={}'.format(time_name,
time_start))
# add initial time to the params dict
tparams = add_time_to_params(params, time_name=time_name,
......
......@@ -83,43 +83,43 @@ def all_solvers(): # test only with the pure python kernel
def make_simple_lead(lat, N):
I = ta.identity(lat.norbs)
I0 = ta.identity(lat.norbs)
syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
syst[(lat(0, j) for j in range(N))] = 4 * I
syst[lat.neighbors()] = -1 * I
syst[(lat(0, j) for j in range(N))] = 4 * I0
syst[lat.neighbors()] = -1 * I0
return syst
def make_complex_lead(lat, N):
I = ta.identity(lat.norbs)
I0 = ta.identity(lat.norbs)
syst = kwant.Builder(kwant.TranslationalSymmetry((-1, 0)))
syst[(lat(0, j) for j in range(N))] = 4 * I
syst[kwant.HoppingKind((0, 1), lat)] = -1 * I
syst[(lat(0, 0), lat(1, 0))] = -1j * I
syst[(lat(0, 1), lat(1, 0))] = -1 * I
syst[(lat(0, 2), lat(1, 2))] = (-1 + 1j) * I
syst[(lat(0, j) for j in range(N))] = 4 * I0
syst[kwant.HoppingKind((0, 1), lat)] = -1 * I0
syst[(lat(0, 0), lat(1, 0))] = -1j * I0
syst[(lat(0, 1), lat(1, 0))] = -1 * I0
syst[(lat(0, 2), lat(1, 2))] = (-1 + 1j) * I0
return syst
def make_system(lat, N):
I = ta.identity(lat.norbs)
I0 = ta.identity(lat.norbs)
def random_onsite(site, time, salt):
return (4 + kwant.digest.uniform(site.tag, salt=salt)) * I
return (4 + kwant.digest.uniform(site.tag, salt=salt)) * I0
syst = kwant.Builder()
syst[(lat(i, j) for i in range(N) for j in range(N))] = random_onsite
syst[lat.neighbors()] = -1 * I
syst[lat.neighbors()] = -1 * I0
return syst
def make_td_system(lat, N, td_onsite):
I = ta.identity(2)
I0 = ta.identity(2)
syst = kwant.Builder()
square = it.product(range(N), range(N))
syst[(lat(i, j) for i, j in square)] = td_onsite
syst[lat.neighbors()] = -1 * I
syst[lat.neighbors()] = -1 * I0
return syst
......@@ -214,12 +214,12 @@ def test_finite_time_dependent(solver_type, kernel_type):
N = 10
salt = '1'
lat = kwant.lattice.square(norbs=2)
I = ta.identity(2)
I0 = ta.identity(2)
SX = ta.array([[0, 1], [1, 0]])
uniform = kwant.digest.uniform
def td_onsite(site, time, salt):
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = uniform(site.tag, salt=salt + '1') * SX * cos(time)
return static_part + td_part
......@@ -341,12 +341,12 @@ def test_infinite_time_dependent(solver_type, kernel_type):
N = 10
salt = '1'
lat = kwant.lattice.square(norbs=2)
I = ta.identity(2)
I0 = ta.identity(2)
SZ = ta.array([[1, 0], [0, -1]])
uniform = kwant.digest.uniform
def td_onsite(site, time, salt):
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = SZ * sin(time)
return static_part + td_part
......@@ -997,7 +997,7 @@ def test_time_argument_name_and_initial_time_change(solver_type, kernel_type):
N = 10
lat = kwant.lattice.square(norbs=2)
I = ta.identity(2)
I0 = ta.identity(2)
SX = ta.array([[0, 1], [1, 0]])
uniform = kwant.digest.uniform
......@@ -1007,7 +1007,7 @@ def test_time_argument_name_and_initial_time_change(solver_type, kernel_type):
# and the initial time starts at zeit = 42
def td_onsite_time(site, time, salt):
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = uniform(site.tag, salt=salt + '1') * SX * cos(time)
return static_part + td_part
......@@ -1015,7 +1015,7 @@ def test_time_argument_name_and_initial_time_change(solver_type, kernel_type):
def td_onsite_zeit(site, zeit, salt):
time = zeit - zeit_offset
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = uniform(site.tag, salt=salt + '1') * SX * cos(time)
return static_part + td_part
......@@ -1080,7 +1080,7 @@ def test_time_argument_name_and_initial_time_change(solver_type, kernel_type):
def test_time_name_and_start_default_change_scattering_state():
N = 10
lat = kwant.lattice.square(norbs=2)
I = ta.identity(2)
I0 = ta.identity(2)
SZ = ta.array([[1, 0], [0, -1]])
uniform = kwant.digest.uniform
......@@ -1092,7 +1092,7 @@ def test_time_name_and_start_default_change_scattering_state():
params = {'salt': '1'}
def td_onsite_time(site, time, salt):
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = SZ * sin(time)
return static_part + td_part
......@@ -1103,7 +1103,7 @@ def test_time_name_and_start_default_change_scattering_state():
def td_onsite_zeit(site, zeit, salt):
time = zeit - zeit_offset
static_part = (4 + uniform(site.tag, salt=salt)) * I
static_part = (4 + uniform(site.tag, salt=salt)) * I0
td_part = SZ * sin(time)
return static_part + td_part
......
......@@ -567,10 +567,10 @@ def test_automatic_boundary_for_lead_with_flat_band():
# whereas for long times, the algorithm should select absorbing boundaries
# edgecase with tree bands, where the second band is flat but with tiny noise
def make_lead_with_flat_band(l=3):
def make_lead_with_flat_band(ll=3):
lat = kwant.lattice.square(norbs=1)
lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 1)))
lead[[lat(0, j) for j in range(l)]] = 0
lead[[lat(0, j) for j in range(ll)]] = 0
lead[lat.neighbors()] = -1
return lead
......
......@@ -79,11 +79,11 @@ def make_system(a=1, gamma=1.0, W=1, L=30):
def make_td_system(lat, N, td_onsite):
I = ta.identity(2)
I0 = ta.identity(2)
syst = kwant.Builder()
square = itertools.product(range(N), range(N))
syst[(lat(i, j) for i, j in square)] = td_onsite
syst[lat.neighbors()] = -1 * I
syst[lat.neighbors()] = -1 * I0
return syst
......@@ -103,10 +103,10 @@ def two_band_spectrum():
@pytest.fixture
def flat_band_spectrum():
def make_lead_with_flat_band(l=3):
def make_lead_with_flat_band(ll=3):
lat = kwant.lattice.square(norbs=1)
lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 1)))
lead[[lat(0, j) for j in range(l)]] = 0
lead[[lat(0, j) for j in range(ll)]] = 0
lead[lat.neighbors()] = -1
return lead
......@@ -1199,46 +1199,22 @@ def test_calc_initial_state():
@pytest.mark.integtest
def test_calc_initial_state_with_kwant_failing():
# a system a flat band (example from Adel Kara Slimane)
def make_system(L=5, l=3):
lat = kwant.lattice.square(norbs=1)
syst = kwant.Builder()
for i in range(L):
for j in range(L):
syst[lat(i, j)] = 0
syst[lat.neighbors()] = -1
# Define and attach the leads
lead1 = kwant.Builder(kwant.TranslationalSymmetry((-1, 1)))
lead1[[lat(0, j) for j in range(l)]] = 0
lead1[lat.neighbors()] = -1
lead2 = kwant.Builder(kwant.TranslationalSymmetry((1, 1)))
lead2[[lat(j - 4, 0) for j in range(l)]] = 0
lead2[lat.neighbors()] = -1
lead3 = kwant.Builder(kwant.TranslationalSymmetry((1, -1)))
lead3[[lat(0, j + 2) for j in range(l)]] = 0
lead3[lat.neighbors()] = -1
syst.attach_lead(lead1, add_cells=2)
syst.attach_lead(lead2, add_cells=2)
syst.attach_lead(lead3, add_cells=2)
return syst
syst = make_system().finalized()
occupation = manybody.lead_occupation(chemical_potential=0.5)
occupation = manybody.lead_occupation(chemical_potential=3)
spectrum = kwant_spectrum.spectra(syst.leads)
intervals = manybody.calc_intervals(spectrum, occupation)
# we have to set the tol value so low that states, for which
# the kwant solver will fail
tasks = manybody.calc_tasks(intervals, spectrum, occupation, tol=1E-30)
tasks = manybody.calc_tasks(intervals, spectrum, occupation)
boundaries = [leads.SimpleBoundary(tmax=1)] * len(syst.leads)
class ScatteringStatesRaise:
"""raise with a runtime error"""
def __init__(self, *args, **kwargs):
raise RuntimeError()
# we just mock that state calculation fails
with pytest.raises(RuntimeError) as exc:
manybody.calc_initial_state(syst, tasks, boundaries)
manybody.calc_initial_state(syst, tasks, boundaries,