Commit 848b460e by Kloss

### handle edgecase of flat band with tiny noise

parent 5f943a8e
 ... ... @@ -24,8 +24,8 @@ def _scale_estimate(onsite_hamiltonian, hopping_elements): emin : float """ norm = np.linalg.norm emax = norm(onsite_hamiltonian, ord=2) + 2*norm(hopping_elements, ord=2) emin = norm(onsite_hamiltonian, ord=-2) + 2*norm(hopping_elements, ord=-2) emax = norm(onsite_hamiltonian, ord=2) + 2 * norm(hopping_elements, ord=2) emin = norm(onsite_hamiltonian, ord=-2) + 2 * norm(hopping_elements, ord=-2) return emax, emin ... ... @@ -85,7 +85,7 @@ def _periodic(k, kmin=-np.pi, kmax=np.pi): def _intersection(a, b): """Finds the intersection :math:c = a \cap b of two intervals a and b. r"""Finds the intersection :math:c = a \cap b of two intervals a and b. Parameters ---------- ... ... @@ -230,7 +230,7 @@ def _cubic_coeffs(x, y, dy, axis=0): def _cubic_interpolation(x, y, dy, axis=0, ext=False): """Return a scipy.interpolate.PPoly instance for piecewise cubic r"""Return a scipy.interpolate.PPoly instance for piecewise cubic Hermite interpolation along one direction to a given set function and derivative values. ... ... @@ -274,6 +274,27 @@ def _cubic_interpolation(x, y, dy, axis=0, ext=False): return scipy.interpolate.PPoly(coeffs, x, extrapolate=ext) def remove_nan(roots): """Remove possible nan values from scipy.interpolate.PPoly.roots Notes ----- From PPoly.roots docstring: If the piecewise polynomial contains sections that are identically zero, the root list will contain the start point of the corresponding interval, followed by a nan value. This routine will remove the nan as well as the interval start point, if there are any. """ root_is_nan = np.isnan(roots) if root_is_nan.any(): for i in range(1, len(root_is_nan), 2): if root_is_nan[i]: root_is_nan[i - 1] = True return roots[~root_is_nan] def _machine_epsilon_reached(x0, x1): """Return True if relative difference between x0 and x1 is smaller than machine epsilon for floats""" ... ... @@ -374,7 +395,7 @@ def _order_left_to_right(fl, fr, xl, xr): def _cubic_interpolation_error(dx, fl, fc, fr): """Estimates the error of a cubic interpolant in an interval r"""Estimates the error of a cubic interpolant in an interval :math:[x_l, x_r] employing a 3-point rule: :math:fi = f(x_i),\, dx = x_r - x_l,\, x_c = (x_r + x_l) / 2 ... ... @@ -397,9 +418,9 @@ def _cubic_interpolation_error(dx, fl, fc, fr): delta : float error estimate for the cubic interpolation function """ fm = (fl[0] + fr[0] - 2 * fc[0])/2 + dx/8*(fl[1] - fr[1]) dfm = 3/4*(fr[0] - fl[0]) - dx/8*(fl[1] + fr[1] + 4*fc[1]) return np.sqrt(39*fm*fm + dfm*dfm) fm = (fl[0] + fr[0] - 2 * fc[0]) / 2 + dx / 8 * (fl[1] - fr[1]) dfm = 3 / 4 * (fr[0] - fl[0]) - dx / 8 * (fl[1] + fr[1] + 4 * fc[1]) return np.sqrt(39 * fm * fm + dfm * dfm) def _save_ordering(func): ... ... @@ -517,7 +538,7 @@ def _match_functions(func, xmin=-1, xmax=1, tol=1E-8, min_iter=10, try: xs = np.linspace(xmin, xmax, min_iter + 1) for i in range(min_iter): order(xs[i], xs[i+1]) order(xs[i], xs[i + 1]) except ValueError as err: print(err) ... ... @@ -534,7 +555,7 @@ def _match_functions(func, xmin=-1, xmax=1, tol=1E-8, min_iter=10, def spectrum(syst, args=(), *, params=None, kmin=-np.pi, kmax=np.pi, orderpoint=0, tol=1E-8, match=_match_functions): """Interpolate the dispersion function and provide methods to r"""Interpolate the dispersion function and provide methods to simplify curve sketching and analyzation the periodic spectrum. Parameters ... ... @@ -599,7 +620,7 @@ def spectrum(syst, args=(), *, params=None, kmin=-np.pi, kmax=np.pi, x, y, dy, ordering = match(array_function(bands), kmin, kmax, tol_eff) # order bands according to their energy at momentum orderpoint band_order = np.argsort(y[np.abs(x-orderpoint).argmin()]) band_order = np.argsort(y[np.abs(x - orderpoint).argmin()]) y = y[:, band_order] dy = dy[:, band_order] ordering = ordering[:, band_order] ... ... @@ -695,7 +716,7 @@ class BandSketching: self.nbands = len(y[0]) # the total number of bands def __call__(self, k, band=None, derivative_order=0): """Calculate energies :math:E (or optionally higher momentum r"""Calculate energies :math:E (or optionally higher momentum derivatives) for a list of momenta :math:k Parameters ... ... @@ -824,8 +845,8 @@ class BandSketching: return self.momentum_to_scattering_mode(k[0], band) def intersect(self, f, band, derivative_order=0, kmin=None, kmax=None, tol=None): """Returns all momentum (k) points, that solves the equation: kmin=None, kmax=None, tol=None, ytol=None): r"""Returns all momentum (k) points, that solves the equation: :math:\partial_k^{n} E(k) = f(k),\, k_{min} \leq k \leq k_{max}. Parameters ... ... @@ -845,6 +866,13 @@ class BandSketching: tol : float, optional Numerical tolerance, k points closer tol are merged to the same point. Default is the tol from initialization. ytol : float, optional Numerical tolerance to remove noise if the spectrum :math:\partial_k^{n} E(k) is almost flat. Values for the spectrum are set to thier mean value (averaged over all momentum points where the band is sampled), if they flucutate more than ytol. Default is the tol from initialization. Returns ------- ... ... @@ -863,6 +891,8 @@ class BandSketching: kmax = self.kmax if tol is None: tol = self.tol if ytol is None: ytol = self.tol # type and input checks assert _is_type(band, 'integer') ... ... @@ -870,9 +900,11 @@ class BandSketching: assert _is_type(kmin, 'real_number') assert _is_type(kmax, 'real_number') assert _is_type(tol, 'real_number') assert _is_type(ytol, 'real_number') assert 0 <= band < self.nbands, 'band index out of range' assert kmin <= kmax, 'bounds swapped' assert tol > 0 assert ytol > 0 # check if the interval [kmin, kmax] is entirely inside the samling # interval [self.kmin, self.kmax]. if not, apply periodic mapping. # we have to check however if the required period is smaller or ... ... @@ -889,16 +921,25 @@ class BandSketching: dy = self.dy[:, band] else: y = self._func(self.x, derivative_order)[:, band] dy = self._func(self.x, derivative_order+1)[:, band] dy = self._func(self.x, derivative_order + 1)[:, band] # filter numerical noise if curve is flat y_mean = np.mean(y) if _is_zero(y_mean, ytol): y_mean = 0 y[np.abs(y - y_mean) < ytol] = y_mean dy[np.abs(dy) < ytol] = 0 roots = self.interpolation(self.x, y - f, dy).roots() roots = remove_nan(roots) if self.period: roots = _periodic(roots, *self.period) try: period_len = self.period[1] - self.period[0] npe = int((kmax - kmin) // period_len) # number of periods images = period_len * np.arange(-npe+1, npe).reshape(2*npe-1, -1) images = period_len * np.arange(-npe + 1, npe).reshape(2 * npe - 1, -1) roots = _unique(np.sort(np.ravel(roots + images)), tol) except: # no periodic image, npe = 0 roots = _unique(np.sort(roots), tol) ... ... @@ -907,7 +948,7 @@ class BandSketching: def intervals(self, band, derivative_order=0, lower=None, upper=None, kmin=None, kmax=None, tol=None): """returns a list of momentum (k) intervals that solves the equation r"""returns a list of momentum (k) intervals that solves the equation :math:\text{lower} \leq \partial^m_k E_n(k) \leq \text{upper} Parameters ... ... @@ -971,5 +1012,5 @@ class BandSketching: crossings = _unique(np.sort(crossings), tol) return [x for x in _pairwise(crossings) if (not _is_zero(x[1]-x[0], tol) if (not _is_zero(x[1] - x[0], tol) and _is_inside(f(x), lower, upper))]
 ... ... @@ -261,6 +261,33 @@ def test_band_analysis_methods(): assert_array_almost_equal(ba.y, energies) def test_spectrum_with_flat_band(): # edgecase with tree bands, where the second band is flat but with tiny noise def make_lead_with_flat_band(l=3): lat = kwant.lattice.square(norbs=1) lead = kwant.Builder(kwant.TranslationalSymmetry((-1, 1))) lead[[lat(0, j) for j in range(l)]] = 0 lead[lat.neighbors()] = -1 return lead lead = make_lead_with_flat_band().finalized() spectrum = ks.spectrum(lead) assert_array_almost_equal([0], spectrum.intersect(f=0, band=0, derivative_order=1)) assert spectrum.intersect(f=0, band=1, derivative_order=1).size == 0 assert_array_almost_equal([0], spectrum.intersect(f=0, band=2, derivative_order=1)) # the spectrum has also no wendepunkt assert spectrum.intersect(f=0, band=0, derivative_order=2).size == 0 assert spectrum.intersect(f=0, band=1, derivative_order=2).size == 0 assert spectrum.intersect(f=0, band=2, derivative_order=2).size == 0 # one can lower the tolerance to recover the result with the numerical noise assert spectrum.intersect(f=0, band=1, derivative_order=1, ytol=1E-20).size > 0 assert spectrum.intersect(f=0, band=1, derivative_order=2, ytol=1E-20).size > 0 def test_function_mapping(): def func(xx): # the two functions in f(x) cross at ... ...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!