Commit 20e622e0 authored by Kloss's avatar Kloss
Browse files

rename error operator to error_op

parent d6bcfcbf
......@@ -589,7 +589,7 @@ can change to a different observable as follows:
.. jupyter-execute::
current_operator = kwant.operator.Current(syst)
state = manybody.State(syst, tmax=10, error_estimate_operator=current_operator)
state = manybody.State(syst, tmax=10, error_op=current_operator)
.. jupyter-execute::
......
......@@ -1299,7 +1299,7 @@ class State:
def __init__(self, syst, tmax=None, occupations=None, params=None,
spectra=None, boundaries=None, intervals=Interval,
refine=True, combine=False, error_estimate_operator=None,
refine=True, combine=False, error_op=None,
scattering_state_type=onebody.ScatteringStates,
manybody_wavefunction_type=WaveFunction,
mpi_distribute=mpi.round_robin, comm=None):
......@@ -1382,7 +1382,7 @@ class State:
If `True`, intervals are refined at the initial time.
combine : bool, optional
If `True`, intervals are grouped by lead indices.
error_estimate_operator : callable or `kwant.operator`, optional
error_op : callable or `kwant.operator`, optional
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Default: Error estimate with density expectation value.
......@@ -1508,12 +1508,12 @@ class State:
# our convention has to be the last element of the weight array
self.return_element = -1
if error_estimate_operator is None:
if error_op is None:
logger.info('set default error estimate based on density')
error_estimate_operator = kwant.operator.Density(syst)
error_op = kwant.operator.Density(syst)
else:
logger.info('set error estimate based on user given operator')
self.error_estimate_operator = error_estimate_operator
self.error_op = error_op
if refine:
self.refine_intervals()
......@@ -1687,7 +1687,7 @@ class State:
return sorted(intervals, key=lambda x: (x.lead, x.band, x.kmin))
def refine_intervals(self, atol=1E-5, rtol=1E-5, limit=200,
operator=None, intervals=None):
error_op=None, intervals=None):
r"""Refine intervals until the quadrature error is below tolerance.
Parameters
......@@ -1699,10 +1699,10 @@ class State:
limit : integer, optional
Maximum number of intervals stored in the solver. A warning is
raised and the refinement stops if limit is reached.
operator : callable or `kwant.operator`, optional
error_op : callable or `kwant.operator`, optional
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Default: ``error_estimate_operator`` from initialization.
Default: ``error_op`` from initialization.
intervals : sequence of `tkwant.manybody.Interval`, optional
Apply the refinement process only to the intervals
given in the sequence. Note that in this case, all intervals
......@@ -1714,7 +1714,7 @@ class State:
abserr : float
Estimate of the modulus of the absolute error,
which should equal or exceed abs(i-result), where i is the exact
integral value. If ``operator`` has an array-like output,
integral value. If ``error_op`` has an array-like output,
we report the maximal value of the error.
(sum ``errors`` over all intervals and take the maximum element).
intervals : list of `tkwant.manybody.Interval`
......@@ -1722,9 +1722,9 @@ class State:
according to `errors`.
errors : `~numpy.ndarray` of floats
Error estimates *E(J)* on the intervals in descending order.
If ``operator`` has an array-like output, the error is
If ``error_op`` has an array-like output, the error is
returned on all array points.
The shape of ``errors`` is like ``operator`` (its expectation value)
The shape of ``errors`` is like ``error_op`` (its expectation value)
with an additional first dimension for the interval index.
Notes
......@@ -1760,11 +1760,11 @@ class State:
\frac{dk}{2 \pi} v_{\alpha}(k) \theta(v_{\alpha}(k)) f_\alpha(k)
[\psi_{\alpha, k}(t)]_i \hat{A} [\psi^\dagger_{\alpha, k}(t)]_j`
where :math:`\hat{A}` corresponds to the ``operator``
where :math:`\hat{A}` corresponds to the ``error_op``
and the error is estimated for the expectation value
:math:`\langle \hat{A}_{ij}(t) \rangle`.
Note that above inequality condition must be fulfilled on each site *i*
and *j* individually. This is the case if ``operator`` generates an
and *j* individually. This is the case if ``error_op`` generates an
array-like output. However, the inequality condition must be fulfilled
only at the current time *t* of the solver.
......@@ -1788,15 +1788,15 @@ class State:
if limit <= 1:
raise ValueError('limit={} must be > 1.'.format(limit))
if operator is None:
operator = self.error_estimate_operator
if error_op is None:
error_op = self.error_op
def observable_with_error(_intervals):
observable = []
errors = []
for interval in _intervals:
error, kronrod = self._error_estimate_quadpack(interval,
operator,
error_op,
return_estimate=True)
observable.append(kronrod)
errors.append(error)
......@@ -1809,6 +1809,10 @@ class State:
if not isinstance(intervals, collections.abc.Iterable):
intervals = [intervals]
# refine only intervals with Gauss-Kronrod quadrature
intervals[:] = [interval for interval in intervals
if interval.quadrature=='kronrod']
results, errors = observable_with_error(intervals)
result = np.sum(results, axis=0) # sum of the integrals over the subintervals
......@@ -1876,7 +1880,7 @@ class State:
return np.max(errsum), intervals, errors
def refine_intervals_local(self, atol=1E-5, rtol=1E-5, limit=200,
operator=None):
error_op=None):
r"""Refine intervals until the quadrature error is below tolerance.
Parameters
......@@ -1888,10 +1892,10 @@ class State:
limit : integer, optional
Maximum number of intervals stored in the solver. A warning is
raised and the refinement stops if limit is reached.
operator : callable or `kwant.operator`, optional
error_op : callable or `kwant.operator`, optional
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Default: ``error_estimate_operator`` from initialization.
Default: ``error_op`` from initialization.
Notes
......@@ -1902,7 +1906,7 @@ class State:
:math:`|I_n - I_{2 n+1}| <= atol + rtol |I_{ges}|`, where
:math:`I_n` is the integral estimate over an interval with order *n*.
Moreover, :math:`I_{ges} = \sum{I_{2 n +1}}` and the sum runs over
all stored intervals. Note that if the operator has a site dependent
all stored intervals. Note that if `error_op` has a site dependent
array output, the criterion must be fulfilled at each site
individually. One-body states that are not part of an interval
are not altered by this method.
......@@ -1914,21 +1918,21 @@ class State:
if rtol < 0:
raise ValueError('rtol={} is negative.'.format(rtol))
if operator is None:
operator = self.error_estimate_operator
if error_op is None:
error_op = self.error_op
# loop until converged
i = 0
while True:
i += 1
tol = atol + rtol * np.abs(self.evaluate(operator, root=None))
tol = atol + rtol * np.abs(self.evaluate(error_op, root=None))
intervals_to_refine = []
intervals = self.get_intervals()
for interval in intervals:
error = self._error_estimate_gauss_kronrod(interval, operator)
error = self._error_estimate_gauss_kronrod(interval, error_op)
if np.where(error > tol, True, False).any():
logger.info("refine step={}, max error={}, interval={}".
......@@ -1954,29 +1958,30 @@ class State:
assert self.manybody_wavefunction._check_consistency()
self.evolve(time=self.time)
def _error_estimate_quadpack(self, interval, operator, return_estimate=False):
def _error_estimate_quadpack(self, interval, error_op, return_estimate=False):
r"""Error estimate for an integration quadrature.
Parameters
----------
intervals : `tkwant.manybody.Interval`
interval : `tkwant.manybody.Interval`
Integration interval with momentum boundaries *[a,b]*.
operator : callable or `kwant.operator`
The attribute `interval.quadrature` attribute must be "kronrod".
error_op : callable or `kwant.operator`
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
return_estimate :bool, optional
If true, return also the expectation value calculated with
``operator``. It corresponds to the Kronrod :math:`K_{2n + 1}`
``error_op``. It corresponds to the Kronrod :math:`K_{2n + 1}`
estimate of the integral.
Returns
-------
error : `~numpy.ndarray` of floats
The quadrature error :math:`\varepsilon` estimated
for the expectation value of the operator.
The output array has the same shape as evaluating ``operator``.
for the expectation value of the error_op.
The output array has the same shape as evaluating ``error_op``.
kronrod : `~numpy.ndarray` of floats, optional
The expectation value of the operator. Only returned
The expectation value of the error operator. Only returned
if ``return_estimate`` is true.
Notes
......@@ -1996,7 +2001,7 @@ class State:
where :math:`G_n[a, b]` is the Gauss and :math:`K_{2n + 1}[a, b]`
the corresponding Kronrod estimate of the integral of *f(x)* over
interval *[a, b]*. In our case, *f(x)* corresponds to the expectation
value of the ``operator`` at the current time of the solver.
value of the ``error_op`` at the current time of the solver.
If the expectation value is an array,
above expression for the error is evaluated point-wise on that array.
......@@ -2008,7 +2013,11 @@ class State:
.. [2] Gonnet, P., A Review of Error Estimation in Adaptive Quadrature,
ACM Computing Surveys, Vol. 44, No. 4, Article 22, (2012).
"""
(gauss, kronrod), func = self._evaluate_interval(interval, operator,
if not interval.quadrature=='kronrod':
raise ValueError('quadpack error estimate works only on '
'Gauss-Kronrod quadrature intervals')
(gauss, kronrod), func = self._evaluate_interval(interval, error_op,
return_integrand=True)
dk = interval.kmax - interval.kmin
......@@ -2028,16 +2037,16 @@ class State:
return error, kronrod
return error
def _error_estimate_gauss_kronrod(self, interval, operator):
def _error_estimate_gauss_kronrod(self, interval, error_op):
r"""Maximal absolute error for the solver result with `quadrature="kronrod"`
Parameters
----------
intervals : `tkwant.manybody.Interval`
Integration interval with momentum boundaries *[a,b]*.
operator : callable or `kwant.operator`
error_op : callable or `kwant.error_op`
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Must have the calling signature of `kwant.error_op`.
Returns
-------
......@@ -2047,10 +2056,10 @@ class State:
*n* point Gauss and :math:`K_{2 n+1}` the corresponding
*2*n +1* Kronrod estimate.
"""
gauss, kronrod = self._evaluate_interval(interval, operator)
gauss, kronrod = self._evaluate_interval(interval, error_op)
return np.abs(gauss - kronrod)
def estimate_error(self, intervals=None, operator=None, estimate=None,
def estimate_error(self, intervals=None, error_op=None, estimate=None,
full_output=False):
r"""Estimate the numerical error of the integration quadrature.
......@@ -2059,15 +2068,15 @@ class State:
intervals : `tkwant.manybody.Interval` or sequence thereof, optional
If present, the error is estimated on given intervals :math:`I_n`,
otherwise the total error over all integrals is evaluated.
operator : callable or `kwant.operator`, optional
error_op : callable or `kwant.operator`, optional
Observable used for the quadrature error estimate.
Must have the calling signature of `kwant.operator`.
Default: ``error_estimate_operator`` from initialization.
Default: ``error_op`` from initialization.
estimate : callable, optional
Function to estimate an error on an interval :math:`I_n`.
Default: `_error_estimate_quadpack`
full_output : bool, optional
If the expectation value of ``operator`` is an array and
If the expectation value of ``error_op`` is an array and
if ``full_output`` is true, then the error is estimated on each
point of the array.
By default, we only return the maximum value of the array error.
......@@ -2083,7 +2092,7 @@ class State:
in order to get the error from all intervals together.
If ``full_output`` is true, the shape of ``error`` is (N, M) in
the first or (M, ) in the second case, where *M* is the shape of
the array resulting from evaluating ``operator``.
the array resulting from evaluating ``error_op``.
Notes
-----
......@@ -2094,25 +2103,25 @@ class State:
if estimate is None:
estimate = self._error_estimate_quadpack
sum_up = False
if operator is None:
operator = self.error_estimate_operator
if error_op is None:
error_op = self.error_op
if intervals is None:
intervals = self.get_intervals()
sum_up = True
if isinstance(intervals, collections.abc.Iterable):
if full_output:
errors = np.array([estimate(interval, operator)
errors = np.array([estimate(interval, error_op)
for interval in intervals])
else:
errors = np.array([np.max(estimate(interval, operator))
errors = np.array([np.max(estimate(interval, error_op))
for interval in intervals])
if sum_up:
errors = np.sum(errors, axis=0)
else:
if full_output:
errors = estimate(intervals, operator)
errors = estimate(intervals, error_op)
else:
errors = np.max(estimate(intervals, operator))
errors = np.max(estimate(intervals, error_op))
return errors
def evaluate(self, observable, root=0):
......
......@@ -1591,8 +1591,8 @@ def test_manybody_state():
# test for missing input
raises(ValueError, manybody.State, syst=syst, params=params)
# test for system not present
raises(AttributeError, manybody.State, syst=[], tmax=5, params=params) # TODO: maybe catch that error
# test for system not finalized
raises(TypeError, manybody.State, syst=make_system(), tmax=5, params=params)
# test for mutual exclusive arguments
raises(ValueError, manybody.State, syst, tmax, boundaries=boundaries, params=params)
......@@ -1688,43 +1688,43 @@ def test_state_estimate_error():
# --- measure the error with the current operator
# maximum error, summed over all intervals
errors = state.estimate_error(operator=current_operator)
errors = state.estimate_error(error_op=current_operator)
assert errors.shape == ()
# error over array, summed over all intervals
errors = state.estimate_error(operator=current_operator, full_output=True)
errors = state.estimate_error(error_op=current_operator, full_output=True)
assert errors.shape == shape_current
# maximum error per interval
errors = state.estimate_error(operator=current_operator, intervals=intervals)
errors = state.estimate_error(error_op=current_operator, intervals=intervals)
assert errors.shape == (len(intervals),)
# error over array per interval
errors = state.estimate_error(operator=current_operator, intervals=intervals, full_output=True)
errors = state.estimate_error(error_op=current_operator, intervals=intervals, full_output=True)
assert errors.shape == (len(intervals),) + shape_current
# -- an operator with a scalar output
# maximum error, summed over all intervals
errors = state.estimate_error(operator=density_sum)
errors = state.estimate_error(error_op=density_sum)
assert errors.shape == ()
# error over array, summed over all intervals
errors = state.estimate_error(operator=density_sum, full_output=True)
errors = state.estimate_error(error_op=density_sum, full_output=True)
assert errors.shape == shape_density_sum
# maximum error per interval
errors = state.estimate_error(operator=density_sum, intervals=intervals)
errors = state.estimate_error(error_op=density_sum, intervals=intervals)
assert errors.shape == (len(intervals),)
# error over array per interval
errors = state.estimate_error(operator=density_sum, intervals=intervals, full_output=True)
errors = state.estimate_error(error_op=density_sum, intervals=intervals, full_output=True)
assert errors.shape == (len(intervals),) + shape_density_sum
# --- set current operator as default from the beginning
state = manybody.State(syst, tmax, occupation, params=params,
error_estimate_operator=current_operator, refine=False)
error_op=current_operator, refine=False)
errors = state.estimate_error()
assert errors.shape == ()
......
......@@ -637,10 +637,10 @@ def test_quadpack_error_estimate():
state.evolve(time)
# compare if array valued operator and single-valued operator give the same result
error = state._error_estimate_quadpack(interval=interval, operator=density_operator)
error = state._error_estimate_quadpack(interval=interval, error_op=density_operator)
error_array = state._error_estimate_quadpack(interval=interval,
operator=density_operator_array)
error_op=density_operator_array)
assert np.allclose(error, error_array[element])
......@@ -724,7 +724,7 @@ def test_quadpack_adaptive_refinement_of_manybody_state():
result_ref, errors_ref, max_error_per_interval_ref, intervals_ref = tmp
# ----- test for an operator with an array output
tmp = state.refine_intervals(operator=density_operator_array, intervals=interval,
tmp = state.refine_intervals(error_op=density_operator_array, intervals=interval,
atol=1.0e-3, rtol=1.0e-3)
error_sum, intervals, errors = tmp
......@@ -756,7 +756,7 @@ def test_quadpack_adaptive_refinement_of_manybody_state():
# set up a new state, as the old one has refined intervals
state = manybody.State(syst, tmax, occupation, params=params, refine=False)
tmp = state.refine_intervals(operator=density_operator, intervals=interval,
tmp = state.refine_intervals(error_op=density_operator, intervals=interval,
atol=1.0e-3, rtol=1.0e-3)
error_sum, intervals, errors = tmp
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment