Commit 848b460e authored by Kloss's avatar Kloss

handle edgecase of flat band with tiny noise

parent 5f943a8e
......@@ -24,8 +24,8 @@ def _scale_estimate(onsite_hamiltonian, hopping_elements):
emin : float
"""
norm = np.linalg.norm
emax = norm(onsite_hamiltonian, ord=2) + 2*norm(hopping_elements, ord=2)
emin = norm(onsite_hamiltonian, ord=-2) + 2*norm(hopping_elements, ord=-2)
emax = norm(onsite_hamiltonian, ord=2) + 2 * norm(hopping_elements, ord=2)
emin = norm(onsite_hamiltonian, ord=-2) + 2 * norm(hopping_elements, ord=-2)
return emax, emin
......@@ -85,7 +85,7 @@ def _periodic(k, kmin=-np.pi, kmax=np.pi):
def _intersection(a, b):
"""Finds the intersection :math:`c = a \cap b` of two intervals `a` and `b`.
r"""Finds the intersection :math:`c = a \cap b` of two intervals `a` and `b`.
Parameters
----------
......@@ -230,7 +230,7 @@ def _cubic_coeffs(x, y, dy, axis=0):
def _cubic_interpolation(x, y, dy, axis=0, ext=False):
"""Return a `scipy.interpolate.PPoly` instance for piecewise cubic
r"""Return a `scipy.interpolate.PPoly` instance for piecewise cubic
Hermite interpolation along one direction to a given set function and
derivative values.
......@@ -274,6 +274,27 @@ def _cubic_interpolation(x, y, dy, axis=0, ext=False):
return scipy.interpolate.PPoly(coeffs, x, extrapolate=ext)
def remove_nan(roots):
"""Remove possible `nan` values from `scipy.interpolate.PPoly.roots`
Notes
-----
From `PPoly.roots` docstring:
If the piecewise polynomial contains sections that are identically zero,
the root list will contain the start point of the corresponding interval,
followed by a nan value.
This routine will remove the nan as well as the interval start point,
if there are any.
"""
root_is_nan = np.isnan(roots)
if root_is_nan.any():
for i in range(1, len(root_is_nan), 2):
if root_is_nan[i]:
root_is_nan[i - 1] = True
return roots[~root_is_nan]
def _machine_epsilon_reached(x0, x1):
"""Return `True` if relative difference between `x0` and `x1` is smaller
than machine epsilon for floats"""
......@@ -374,7 +395,7 @@ def _order_left_to_right(fl, fr, xl, xr):
def _cubic_interpolation_error(dx, fl, fc, fr):
"""Estimates the error of a cubic interpolant in an interval
r"""Estimates the error of a cubic interpolant in an interval
:math:`[x_l, x_r]` employing a 3-point rule:
:math:`fi = f(x_i),\, dx = x_r - x_l,\, x_c = (x_r + x_l) / 2`
......@@ -397,9 +418,9 @@ def _cubic_interpolation_error(dx, fl, fc, fr):
delta : float
error estimate for the cubic interpolation function
"""
fm = (fl[0] + fr[0] - 2 * fc[0])/2 + dx/8*(fl[1] - fr[1])
dfm = 3/4*(fr[0] - fl[0]) - dx/8*(fl[1] + fr[1] + 4*fc[1])
return np.sqrt(39*fm*fm + dfm*dfm)
fm = (fl[0] + fr[0] - 2 * fc[0]) / 2 + dx / 8 * (fl[1] - fr[1])
dfm = 3 / 4 * (fr[0] - fl[0]) - dx / 8 * (fl[1] + fr[1] + 4 * fc[1])
return np.sqrt(39 * fm * fm + dfm * dfm)
def _save_ordering(func):
......@@ -517,7 +538,7 @@ def _match_functions(func, xmin=-1, xmax=1, tol=1E-8, min_iter=10,
try:
xs = np.linspace(xmin, xmax, min_iter + 1)
for i in range(min_iter):
order(xs[i], xs[i+1])
order(xs[i], xs[i + 1])
except ValueError as err:
print(err)
......@@ -534,7 +555,7 @@ def _match_functions(func, xmin=-1, xmax=1, tol=1E-8, min_iter=10,
def spectrum(syst, args=(), *, params=None, kmin=-np.pi, kmax=np.pi,
orderpoint=0, tol=1E-8, match=_match_functions):
"""Interpolate the dispersion function and provide methods to
r"""Interpolate the dispersion function and provide methods to
simplify curve sketching and analyzation the periodic spectrum.
Parameters
......@@ -599,7 +620,7 @@ def spectrum(syst, args=(), *, params=None, kmin=-np.pi, kmax=np.pi,
x, y, dy, ordering = match(array_function(bands), kmin, kmax, tol_eff)
# order bands according to their energy at momentum `orderpoint`
band_order = np.argsort(y[np.abs(x-orderpoint).argmin()])
band_order = np.argsort(y[np.abs(x - orderpoint).argmin()])
y = y[:, band_order]
dy = dy[:, band_order]
ordering = ordering[:, band_order]
......@@ -695,7 +716,7 @@ class BandSketching:
self.nbands = len(y[0]) # the total number of bands
def __call__(self, k, band=None, derivative_order=0):
"""Calculate energies :math:`E` (or optionally higher momentum
r"""Calculate energies :math:`E` (or optionally higher momentum
derivatives) for a list of momenta :math:`k`
Parameters
......@@ -824,8 +845,8 @@ class BandSketching:
return self.momentum_to_scattering_mode(k[0], band)
def intersect(self, f, band, derivative_order=0,
kmin=None, kmax=None, tol=None):
"""Returns all momentum (k) points, that solves the equation:
kmin=None, kmax=None, tol=None, ytol=None):
r"""Returns all momentum (k) points, that solves the equation:
:math:`\partial_k^{n} E(k) = f(k),\, k_{min} \leq k \leq k_{max}`.
Parameters
......@@ -845,6 +866,13 @@ class BandSketching:
tol : float, optional
Numerical tolerance, `k` points closer tol are merged to the same
point. Default is the `tol` from initialization.
ytol : float, optional
Numerical tolerance to remove noise if the
spectrum :math:`\partial_k^{n} E(k)` is almost flat.
Values for the spectrum are set to thier mean value
(averaged over all momentum points where the band is sampled), if
they flucutate more than `ytol`.
Default is the `tol` from initialization.
Returns
-------
......@@ -863,6 +891,8 @@ class BandSketching:
kmax = self.kmax
if tol is None:
tol = self.tol
if ytol is None:
ytol = self.tol
# type and input checks
assert _is_type(band, 'integer')
......@@ -870,9 +900,11 @@ class BandSketching:
assert _is_type(kmin, 'real_number')
assert _is_type(kmax, 'real_number')
assert _is_type(tol, 'real_number')
assert _is_type(ytol, 'real_number')
assert 0 <= band < self.nbands, 'band index out of range'
assert kmin <= kmax, 'bounds swapped'
assert tol > 0
assert ytol > 0
# check if the interval [kmin, kmax] is entirely inside the samling
# interval [self.kmin, self.kmax]. if not, apply periodic mapping.
# we have to check however if the required period is smaller or
......@@ -889,16 +921,25 @@ class BandSketching:
dy = self.dy[:, band]
else:
y = self._func(self.x, derivative_order)[:, band]
dy = self._func(self.x, derivative_order+1)[:, band]
dy = self._func(self.x, derivative_order + 1)[:, band]
# filter numerical noise if curve is flat
y_mean = np.mean(y)
if _is_zero(y_mean, ytol):
y_mean = 0
y[np.abs(y - y_mean) < ytol] = y_mean
dy[np.abs(dy) < ytol] = 0
roots = self.interpolation(self.x, y - f, dy).roots()
roots = remove_nan(roots)
if self.period:
roots = _periodic(roots, *self.period)
try:
period_len = self.period[1] - self.period[0]
npe = int((kmax - kmin) // period_len) # number of periods
images = period_len * np.arange(-npe+1, npe).reshape(2*npe-1, -1)
images = period_len * np.arange(-npe + 1, npe).reshape(2 * npe - 1, -1)
roots = _unique(np.sort(np.ravel(roots + images)), tol)
except: # no periodic image, npe = 0
roots = _unique(np.sort(roots), tol)
......@@ -907,7 +948,7 @@ class BandSketching:
def intervals(self, band, derivative_order=0, lower=None, upper=None,
kmin=None, kmax=None, tol=None):
"""returns a list of momentum (`k`) intervals that solves the equation
r"""returns a list of momentum (`k`) intervals that solves the equation
:math:`\text{lower} \leq \partial^m_k E_n(k) \leq \text{upper}`
Parameters
......@@ -971,5 +1012,5 @@ class BandSketching:
crossings = _unique(np.sort(crossings), tol)
return [x for x in _pairwise(crossings)
if (not _is_zero(x[1]-x[0], tol)
if (not _is_zero(x[1] - x[0], tol)
and _is_inside(f(x), lower, upper))]
......@@ -261,6 +261,33 @@ def test_band_analysis_methods():
assert_array_almost_equal(ba.y, energies)
def test_spectrum_with_flat_band():
# edgecase with tree bands, where the second band is flat but with tiny noise
def make_lead_with_flat_band(l=3):
lat = kwant.lattice.square(norbs=1)
lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 1)))
lead[[lat(0, j) for j in range(l)]] = 0
lead[lat.neighbors()] = -1
return lead
lead = make_lead_with_flat_band().finalized()
spectrum = ks.spectrum(lead)
assert_array_almost_equal([0], spectrum.intersect(f=0, band=0, derivative_order=1))
assert spectrum.intersect(f=0, band=1, derivative_order=1).size == 0
assert_array_almost_equal([0], spectrum.intersect(f=0, band=2, derivative_order=1))
# the spectrum has also no wendepunkt
assert spectrum.intersect(f=0, band=0, derivative_order=2).size == 0
assert spectrum.intersect(f=0, band=1, derivative_order=2).size == 0
assert spectrum.intersect(f=0, band=2, derivative_order=2).size == 0
# one can lower the tolerance to recover the result with the numerical noise
assert spectrum.intersect(f=0, band=1, derivative_order=1, ytol=1E-20).size > 0
assert spectrum.intersect(f=0, band=1, derivative_order=2, ytol=1E-20).size > 0
def test_function_mapping():
def func(xx):
# the two functions in f(x) cross at
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment