diff --git a/kwant/operator.pyx b/kwant/operator.pyx index 8421ba56df78d8718cf421e71ab0c09a6a1f4744..a10624d1ccf3d7d6a2a175e16d6e00e3126c3bad 100644 --- a/kwant/operator.pyx +++ b/kwant/operator.pyx @@ -13,6 +13,7 @@ import cython from operator import itemgetter import functools as ft import collections +import numbers import numpy as np import tinyarray as ta @@ -21,6 +22,7 @@ from scipy.sparse import coo_matrix from libc cimport math from .graph.core cimport EdgeIterator +from .graph.core import DisabledFeatureError, NodeDoesNotExistError from .graph.defs cimport gint from .graph.defs import gint_dtype from .system import InfiniteSystem @@ -159,34 +161,42 @@ def _normalize_site_where(syst, where): """Normalize the format of `where` when `where` contains sites. If `where` is None, then all sites in the system are returned. - If it is a general iterator then it is expanded into an array. If `syst` - is a finalized Builder then `where` should contain `Site` objects, + If it is a general sequence then it is expanded into an array. If `syst` + is a finalized Builder then `where` may contain `Site` objects, otherwise it should contain integers. """ if where is None: - size = (syst.cell_size - if isinstance(syst, InfiniteSystem) else syst.graph.num_nodes) - _where = list(range(size)) + if isinstance(syst, InfiniteSystem): + where = list(range(syst.cell_size)) + else: + where = list(range(syst.graph.num_nodes)) elif callable(where): try: - _where = [syst.id_by_site[a] for a in filter(where, syst.sites)] + where = [syst.id_by_site[s] for s in filter(where, syst.sites)] except AttributeError: - _where = list(filter(where, range(syst.graph.num_nodes))) + if isinstance(syst, InfiniteSystem): + where = [s for s in range(syst.cell_size) if where(s)] + else: + where = [s for s in range(syst.graph.num_nodes) if where(s)] else: - try: - _where = list(syst.id_by_site[s] for s in where) - except AttributeError: - _where = list(where) - if any(w < 0 or w >= syst.graph.num_nodes for w in _where): - raise ValueError('`where` contains sites that are not in the ' - 'system.') + # Cannot check for builder.Site due to circular imports + if not isinstance(where[0], numbers.Integral): + try: + where = [syst.id_by_site[s] for s in where] + except AttributeError: + raise TypeError("'where' contains Sites, but the system is not " + "a finalized Builder.") + + where = np.asarray(where, dtype=gint_dtype).reshape(-1, 1) - if isinstance(syst, InfiniteSystem): - if any(w >= syst.cell_size for w in _where): - raise ValueError('Only sites in the fundamental domain may be ' - 'specified using `where`.') + if isinstance(syst, InfiniteSystem) and np.any(where >= syst.cell_size): + raise ValueError('Only sites in the fundamental domain may be ' + 'specified using `where`.') + if np.any(np.logical_or(where < 0, where >= syst.graph.num_nodes)): + raise ValueError('`where` contains sites that are not in the ' + 'system.') - return np.asarray(_where, dtype=gint_dtype).reshape(len(_where), 1) + return where def _normalize_hopping_where(syst, where): @@ -194,7 +204,7 @@ def _normalize_hopping_where(syst, where): If `where` is None, then all hoppings in the system are returned. If it is a general iterator then it is expanded into an array. If `syst` is - a finalized Builder then `where` should contain pairs of `Site` objects, + a finalized Builder then `where` may contain pairs of `Site` objects, otherwise it should contain pairs of integers. """ if where is None: @@ -203,33 +213,42 @@ def _normalize_hopping_where(syst, where): if isinstance(syst, InfiniteSystem): raise ValueError('`where` must be provided when calculating ' 'current in an InfiniteSystem.') - _where = list(syst.graph) + where = list(syst.graph) elif callable(where): if hasattr(syst, "sites"): - def idx_where(hop): + def idxwhere(hop): a, b = hop return where(syst.sites[a], syst.sites[b]) - _where = list(filter(idx_where, syst.graph)) + where = list(filter(idxwhere, syst.graph)) else: - _where = list(filter(lambda h: where(*h), syst.graph)) + where = list(filter(lambda h: where(*h), syst.graph)) else: + # Cannot check for builder.Site due to circular imports + if not isinstance(where[0][0], numbers.Integral): + try: + where = list((syst.id_by_site[a], syst.id_by_site[b]) + for a, b in where) + except AttributeError: + raise TypeError("'where' contains Sites, but the system is not " + "a finalized Builder.") + # NOTE: if we ever have operators that contain elements that are + # not in the system graph, then we should modify this check try: - _where = list((syst.id_by_site[a], syst.id_by_site[b]) - for a, b in where) - except AttributeError: - _where = list(where) - # NOTE: if we ever have operators that contain elements that are - # not in the system graph, then we should modify this check + error = ValueError('`where` contains hoppings that are not ' + 'in the system.') if any(not syst.graph.has_edge(*w) for w in where): - raise ValueError('`where` contains hoppings that are not in the ' - 'system.') + raise error + # If where contains: negative integers, or integers > # of sites + except (NodeDoesNotExistError, DisabledFeatureError): + raise error + + where = np.asarray(where, dtype=gint_dtype) - if isinstance(syst, InfiniteSystem): - if any(a > syst.cell_size or b > syst.cell_size for a, b in _where): - raise ValueError('Only intra-cell hoppings may be specified ' - 'using `where`.') + if isinstance(syst, InfiniteSystem) and np.any(where > syst.cell_size): + raise ValueError('Only intra-cell hoppings may be specified ' + 'using `where`.') - return np.asarray(_where, dtype=gint_dtype) + return where ## These two classes are here to avoid using closures, as these will diff --git a/kwant/tests/test_operator.py b/kwant/tests/test_operator.py index f6e5271ba658e683ac5a2e02abf3c3ba311de61f..6bb05593f2237ae1599d0ea24d124c9d4ca0efd0 100644 --- a/kwant/tests/test_operator.py +++ b/kwant/tests/test_operator.py @@ -128,12 +128,35 @@ def test_operator_construction(): fwhere = tuple(fsyst.id_by_site[s] for s in where) A = ops.Density(fsyst, where=where) assert np.all(np.asarray(A.where).reshape(-1) == fwhere) + # Test for passing integers as 'where' + A = ops.Density(fsyst, where=fwhere) + assert np.all(np.asarray(A.where).reshape(-1) == fwhere) + # Test passing invalid sites + with raises(ValueError): + ops.Density(fsyst, where=[lat(100)]) + with raises(ValueError): + ops.Density(fsyst, where=[-1]) + with raises(ValueError): + ops.Density(fsyst, where=[10000]) where = [(lat(2, 2), lat(1, 2)), (lat(0, 0), lat(0, 1))] fwhere = np.asarray([(fsyst.id_by_site[a], fsyst.id_by_site[b]) for a, b in where]) A = ops.Current(fsyst, where=where) assert np.all(np.asarray(A.where) == fwhere) + # Test for passing integers as 'where' + A = ops.Current(fsyst, where=fwhere) + assert np.all(np.asarray(A.where) == fwhere) + # Test passing invalid hoppings + with raises(ValueError): + ops.Current(fsyst, where=[(lat(2, 2), lat(0, 0))]) + with raises(ValueError): + ops.Current(fsyst, where=[(-1, 1)]) + with raises(ValueError): + ops.Current(fsyst, where=[(len(fsyst.sites), 1)]) + with raises(ValueError): + ops.Current(fsyst, where=[(fsyst.id_by_site[lat(2, 2)], + fsyst.id_by_site[lat(0, 0)])]) # test construction with `where` given by a function tag_list = [(1, 0), (1, 1), (1, 2)] @@ -304,7 +327,7 @@ def test_opservables_spin(): down, up = kwant.wave_function(fsyst, energy=1., params=params)(0) x_hoppings = kwant.builder.HoppingKind((1,), lat) - spin_current_z = ops.Current(fsyst, sigmaz, where=x_hoppings(syst)) + spin_current_z = ops.Current(fsyst, sigmaz, where=list(x_hoppings(syst))) _test(spin_current_z, up, params=params, per_el_val=1) _test(spin_current_z, down, params=params, per_el_val=-1) @@ -366,12 +389,14 @@ def test_opservables_gauged(): (Us[i], sigmaz, Us[i].conjugate().transpose())) x_hoppings = kwant.builder.HoppingKind((1,), lat) - spin_current_gauge = ops.Current(fsyst, M_a, where=x_hoppings(syst)) + spin_current_gauge = ops.Current(fsyst, M_a, + where=list(x_hoppings(syst))) _test(spin_current_gauge, up, per_el_val=1) _test(spin_current_gauge, down, per_el_val=-1) # check the reverse is also true minus_x_hoppings = kwant.builder.HoppingKind((-1,), lat) - spin_current_gauge = ops.Current(fsyst, M_a, where=minus_x_hoppings(syst)) + spin_current_gauge = ops.Current(fsyst, M_a, + where=list(minus_x_hoppings(syst))) _test(spin_current_gauge, up, per_el_val=-1) _test(spin_current_gauge, down, per_el_val=1)