diff --git a/kwant/builder.py b/kwant/builder.py index 7d8d3738ba68909379bd1969170b3bd39ea4f90c..abf1cc69541ffc5964771085d6b5d6d3d13b90bf 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -3025,3 +3025,16 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite i -= cs j -= cs return super().hamiltonian(i, j, *args, params=params) + + +def is_finite_system(syst): + return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem)) + + +def is_infinite_system(syst): + return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem)) + + +def is_system(syst): + return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem, + InfiniteSystem, InfiniteVectorizedSystem)) diff --git a/kwant/operator.pyx b/kwant/operator.pyx index 5133ebb522994d288ef2c5344cb9d016366f6fa9..8de1cad8aae740a77bdf36b39ee7c0175cb2f02e 100644 --- a/kwant/operator.pyx +++ b/kwant/operator.pyx @@ -24,7 +24,7 @@ from .graph.core cimport EdgeIterator from .graph.core import DisabledFeatureError, NodeDoesNotExistError from .graph.defs cimport gint from .graph.defs import gint_dtype -from .system import InfiniteSystem, Site +from .system import is_infinite, Site from ._common import UserCodeError, get_parameters, deprecate_args @@ -149,7 +149,7 @@ def _get_all_orbs(gint[:, :] where, gint[:, :] site_ranges): def _get_tot_norbs(syst): cdef gint _unused, tot_norbs - is_infinite_system = isinstance(syst, InfiniteSystem) + is_infinite_system = is_infinite(syst) n_sites = syst.cell_size if is_infinite_system else syst.graph.num_nodes _get_orbs(np.asarray(syst.site_ranges, dtype=gint_dtype), n_sites, &tot_norbs, &_unused) @@ -165,7 +165,7 @@ def _normalize_site_where(syst, where): otherwise it should contain integers. """ if where is None: - if isinstance(syst, InfiniteSystem): + if is_infinite(syst): where = list(range(syst.cell_size)) else: where = list(range(syst.graph.num_nodes)) @@ -173,7 +173,7 @@ def _normalize_site_where(syst, where): try: where = [syst.id_by_site[s] for s in filter(where, syst.sites)] except AttributeError: - if isinstance(syst, InfiniteSystem): + if is_infinite(syst): where = [s for s in range(syst.cell_size) if where(s)] else: where = [s for s in range(syst.graph.num_nodes) if where(s)] @@ -187,7 +187,7 @@ def _normalize_site_where(syst, where): where = np.asarray(where, dtype=gint_dtype).reshape(-1, 1) - if isinstance(syst, InfiniteSystem) and np.any(where >= syst.cell_size): + if is_infinite(syst) and np.any(where >= syst.cell_size): raise ValueError('Only sites in the fundamental domain may be ' 'specified using `where`.') if np.any(np.logical_or(where < 0, where >= syst.graph.num_nodes)): @@ -208,7 +208,7 @@ def _normalize_hopping_where(syst, where): if where is None: # we cannot extract the hoppings in the same order as they are in the # graph while simultaneously excluding all inter-cell hoppings - if isinstance(syst, InfiniteSystem): + if is_infinite(syst): raise ValueError('`where` must be provided when calculating ' 'current in an InfiniteSystem.') where = list(syst.graph) @@ -241,7 +241,7 @@ def _normalize_hopping_where(syst, where): where = np.asarray(where, dtype=gint_dtype) - if isinstance(syst, InfiniteSystem) and np.any(where > syst.cell_size): + if is_infinite(syst) and np.any(where > syst.cell_size): raise ValueError('Only intra-cell hoppings may be specified ' 'using `where`.') diff --git a/kwant/plotter.py b/kwant/plotter.py index a9e5ec42793fb37eee4e5908624443cafc1db59d..31a0aba6658b18eb3c79fccc3c0ffcdb2804bedf 100644 --- a/kwant/plotter.py +++ b/kwant/plotter.py @@ -427,7 +427,7 @@ def sys_leads_sites(sys, num_lead_cells=2): lead.builder.sites() for i in range(num_lead_cells))) lead_cells.append(slice(start, len(sites))) - elif isinstance(syst, system.FiniteSystem): + elif system.is_finite(syst): sites = [(i, None, 0) for i in range(syst.graph.num_nodes)] for leadnr, lead in enumerate(syst.leads): start = len(sites) @@ -961,7 +961,7 @@ def plot(sys, num_lead_cells=2, unit='nn', if site_color is None: cycle = _color_cycle() - if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)): + if builder.is_system(syst): # Skipping the leads for brevity. families = sorted({site.family for site in syst.sites}) color_mapping = dict(zip(families, cycle)) @@ -1291,7 +1291,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None, if callable(value): value = [value(site[0]) for site in sites] else: - if not isinstance(syst, system.FiniteSystem): + if not system.is_finite(syst): raise ValueError('List of values is only allowed as input ' 'for finalized systems.') value = np.array(value) @@ -1407,7 +1407,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None, "for bands()") syst = sys # for naming consistency inside function bodies - _common.ensure_isinstance(syst, system.InfiniteSystem) + _common.ensure_isinstance(syst, (system.InfiniteSystem, system.InfiniteVectorizedSystem)) momenta = np.array(momenta) if momenta.ndim != 1: @@ -1483,7 +1483,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None, if y is not None and not _p.has3d: raise RuntimeError("Installed matplotlib does not support 3d plotting") - if isinstance(syst, system.FiniteSystem): + if system.is_finite(syst): def ham(**kwargs): return syst.hamiltonian_submatrix(params=kwargs, sparse=False) elif callable(syst): @@ -1751,7 +1751,7 @@ def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9): the extents of `field`: ((x0, x1), (y0, y1), ...) """ - if not isinstance(syst, builder.FiniteSystem): + if not builder.is_finite_system(syst): raise TypeError("The system needs to be finalized.") if len(current) != syst.graph.num_edges: @@ -1844,7 +1844,7 @@ def interpolate_density(syst, density, relwidth=None, abswidth=None, n=9, the extents of ``field``: ((x0, x1), (y0, y1), ...) """ - if not isinstance(syst, builder.FiniteSystem): + if not builder.is_finite_system(syst): raise TypeError("The system needs to be finalized.") if len(density) != len(syst.sites): diff --git a/kwant/system.py b/kwant/system.py index a65f9259e03f9e048dfa44f43bf74337e6f84815..a1c287bb97a982353d79410b8c2cd4fc2ca4bbad 100644 --- a/kwant/system.py +++ b/kwant/system.py @@ -572,6 +572,10 @@ class FiniteVectorizedSystem(VectorizedSystem, FiniteSystemMixin, metaclass=abc. pass +def is_finite(syst): + return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem)) + + class InfiniteSystemMixin(metaclass=abc.ABCMeta): """Abstract infinite low-level system. @@ -731,6 +735,10 @@ class InfiniteVectorizedSystem(VectorizedSystem, InfiniteSystemMixin, metaclass= inter_cell_hopping = _system.vectorized_inter_cell_hopping +def is_infinite(syst): + return isinstance(syst, (InfiniteSystem, InfiniteVectorizedSystem)) + + class PrecalculatedLead: def __init__(self, modes=None, selfenergy=None): """A general lead defined by its self energy.