From 4c0593c4d08c73e3d2519a0bf4b16e68117f1189 Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph@weston.cloud>
Date: Tue, 1 Oct 2019 11:47:44 +0200
Subject: [PATCH] update all relevant 'isinstance' checks on systems

We now have *System and *System2, so the old 'isinstance' checks
will not be complete. We add convenience functions in 'kwant.system'
and 'kwant.builder' to check against all the relvant classes.
We leave the original 'isinstance' checks in 'kwant.physics.gauge'
because that module does not work with vectorize systems yet.
---
 kwant/builder.py   | 13 +++++++++++++
 kwant/operator.pyx | 14 +++++++-------
 kwant/plotter.py   | 14 +++++++-------
 kwant/system.py    |  8 ++++++++
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/kwant/builder.py b/kwant/builder.py
index 7d8d3738..abf1cc69 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -3025,3 +3025,16 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite
             i -= cs
             j -= cs
         return super().hamiltonian(i, j, *args, params=params)
+
+
+def is_finite_system(syst):
+    return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem))
+
+
+def is_infinite_system(syst):
+    return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem))
+
+
+def is_system(syst):
+    return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem,
+                             InfiniteSystem, InfiniteVectorizedSystem))
diff --git a/kwant/operator.pyx b/kwant/operator.pyx
index 5133ebb5..8de1cad8 100644
--- a/kwant/operator.pyx
+++ b/kwant/operator.pyx
@@ -24,7 +24,7 @@ from .graph.core cimport EdgeIterator
 from .graph.core import DisabledFeatureError, NodeDoesNotExistError
 from .graph.defs cimport gint
 from .graph.defs import gint_dtype
-from .system import InfiniteSystem, Site
+from .system import is_infinite, Site
 from ._common import UserCodeError, get_parameters, deprecate_args
 
 
@@ -149,7 +149,7 @@ def _get_all_orbs(gint[:, :] where, gint[:, :] site_ranges):
 
 def _get_tot_norbs(syst):
     cdef gint _unused, tot_norbs
-    is_infinite_system = isinstance(syst, InfiniteSystem)
+    is_infinite_system = is_infinite(syst)
     n_sites = syst.cell_size if is_infinite_system else syst.graph.num_nodes
     _get_orbs(np.asarray(syst.site_ranges, dtype=gint_dtype),
               n_sites, &tot_norbs, &_unused)
@@ -165,7 +165,7 @@ def _normalize_site_where(syst, where):
     otherwise it should contain integers.
     """
     if where is None:
-        if isinstance(syst, InfiniteSystem):
+        if is_infinite(syst):
             where = list(range(syst.cell_size))
         else:
             where = list(range(syst.graph.num_nodes))
@@ -173,7 +173,7 @@ def _normalize_site_where(syst, where):
         try:
             where = [syst.id_by_site[s] for s in filter(where, syst.sites)]
         except AttributeError:
-            if isinstance(syst, InfiniteSystem):
+            if is_infinite(syst):
                 where = [s for s in range(syst.cell_size) if where(s)]
             else:
                 where = [s for s in range(syst.graph.num_nodes) if where(s)]
@@ -187,7 +187,7 @@ def _normalize_site_where(syst, where):
 
     where = np.asarray(where, dtype=gint_dtype).reshape(-1, 1)
 
-    if isinstance(syst, InfiniteSystem) and np.any(where >= syst.cell_size):
+    if is_infinite(syst) and np.any(where >= syst.cell_size):
         raise ValueError('Only sites in the fundamental domain may be '
                          'specified using `where`.')
     if np.any(np.logical_or(where < 0, where >= syst.graph.num_nodes)):
@@ -208,7 +208,7 @@ def _normalize_hopping_where(syst, where):
     if where is None:
         # we cannot extract the hoppings in the same order as they are in the
         # graph while simultaneously excluding all inter-cell hoppings
-        if isinstance(syst, InfiniteSystem):
+        if is_infinite(syst):
             raise ValueError('`where` must be provided when calculating '
                              'current in an InfiniteSystem.')
         where = list(syst.graph)
@@ -241,7 +241,7 @@ def _normalize_hopping_where(syst, where):
 
     where = np.asarray(where, dtype=gint_dtype)
 
-    if isinstance(syst, InfiniteSystem) and np.any(where > syst.cell_size):
+    if is_infinite(syst) and np.any(where > syst.cell_size):
         raise ValueError('Only intra-cell hoppings may be specified '
                          'using `where`.')
 
diff --git a/kwant/plotter.py b/kwant/plotter.py
index a9e5ec42..31a0aba6 100644
--- a/kwant/plotter.py
+++ b/kwant/plotter.py
@@ -427,7 +427,7 @@ def sys_leads_sites(sys, num_lead_cells=2):
                               lead.builder.sites() for i in
                               range(num_lead_cells)))
             lead_cells.append(slice(start, len(sites)))
-    elif isinstance(syst, system.FiniteSystem):
+    elif system.is_finite(syst):
         sites = [(i, None, 0) for i in range(syst.graph.num_nodes)]
         for leadnr, lead in enumerate(syst.leads):
             start = len(sites)
@@ -961,7 +961,7 @@ def plot(sys, num_lead_cells=2, unit='nn',
 
     if site_color is None:
         cycle = _color_cycle()
-        if isinstance(syst, (builder.FiniteSystem, builder.InfiniteSystem)):
+        if builder.is_system(syst):
             # Skipping the leads for brevity.
             families = sorted({site.family for site in syst.sites})
             color_mapping = dict(zip(families, cycle))
@@ -1291,7 +1291,7 @@ def map(sys, value, colorbar=True, cmap=None, vmin=None, vmax=None, a=None,
     if callable(value):
         value = [value(site[0]) for site in sites]
     else:
-        if not isinstance(syst, system.FiniteSystem):
+        if not system.is_finite(syst):
             raise ValueError('List of values is only allowed as input '
                              'for finalized systems.')
     value = np.array(value)
@@ -1407,7 +1407,7 @@ def bands(sys, args=(), momenta=65, file=None, show=True, dpi=None,
                            "for bands()")
 
     syst = sys  # for naming consistency inside function bodies
-    _common.ensure_isinstance(syst, system.InfiniteSystem)
+    _common.ensure_isinstance(syst, (system.InfiniteSystem, system.InfiniteVectorizedSystem))
 
     momenta = np.array(momenta)
     if momenta.ndim != 1:
@@ -1483,7 +1483,7 @@ def spectrum(syst, x, y=None, params=None, mask=None, file=None,
     if y is not None and not _p.has3d:
         raise RuntimeError("Installed matplotlib does not support 3d plotting")
 
-    if isinstance(syst, system.FiniteSystem):
+    if system.is_finite(syst):
         def ham(**kwargs):
             return syst.hamiltonian_submatrix(params=kwargs, sparse=False)
     elif callable(syst):
@@ -1751,7 +1751,7 @@ def interpolate_current(syst, current, relwidth=None, abswidth=None, n=9):
         the extents of `field`: ((x0, x1), (y0, y1), ...)
 
     """
-    if not isinstance(syst, builder.FiniteSystem):
+    if not builder.is_finite_system(syst):
         raise TypeError("The system needs to be finalized.")
 
     if len(current) != syst.graph.num_edges:
@@ -1844,7 +1844,7 @@ def interpolate_density(syst, density, relwidth=None, abswidth=None, n=9,
         the extents of ``field``: ((x0, x1), (y0, y1), ...)
 
     """
-    if not isinstance(syst, builder.FiniteSystem):
+    if not builder.is_finite_system(syst):
         raise TypeError("The system needs to be finalized.")
 
     if len(density) != len(syst.sites):
diff --git a/kwant/system.py b/kwant/system.py
index a65f9259..a1c287bb 100644
--- a/kwant/system.py
+++ b/kwant/system.py
@@ -572,6 +572,10 @@ class FiniteVectorizedSystem(VectorizedSystem, FiniteSystemMixin, metaclass=abc.
     pass
 
 
+def is_finite(syst):
+    return isinstance(syst, (FiniteSystem, FiniteVectorizedSystem))
+
+
 class InfiniteSystemMixin(metaclass=abc.ABCMeta):
     """Abstract infinite low-level system.
 
@@ -731,6 +735,10 @@ class InfiniteVectorizedSystem(VectorizedSystem, InfiniteSystemMixin, metaclass=
     inter_cell_hopping = _system.vectorized_inter_cell_hopping
 
 
+def is_infinite(syst):
+    return isinstance(syst, (InfiniteSystem, InfiniteVectorizedSystem))
+
+
 class PrecalculatedLead:
     def __init__(self, modes=None, selfenergy=None):
         """A general lead defined by its self energy.
-- 
GitLab