From 447e27c20125fcb9d132d1d41e96b6f1d5205d1c Mon Sep 17 00:00:00 2001 From: Joseph Weston <joseph.weston08@gmail.com> Date: Sat, 21 May 2016 10:02:01 +0200 Subject: [PATCH] raise UserCodeError with message when Hamiltonian value functions fail fixes #36 --- doc/source/reference/kwant.rst | 1 + kwant/__init__.py | 5 +++-- kwant/_common.py | 15 +++++++++++++- kwant/builder.py | 28 ++++++++++++++++++++----- kwant/tests/test_builder.py | 38 +++++++++++++++++++++++++++++++++- 5 files changed, 78 insertions(+), 9 deletions(-) diff --git a/doc/source/reference/kwant.rst b/doc/source/reference/kwant.rst index 5dc52a97..61a3d526 100644 --- a/doc/source/reference/kwant.rst +++ b/doc/source/reference/kwant.rst @@ -15,6 +15,7 @@ The version of Kwant is available under the name ``__version__``. :toctree: generated/ KwantDeprecationWarning + UserCodeError .. currentmodule:: kwant.builder diff --git a/kwant/__init__.py b/kwant/__init__.py index 74e80f9f..389d0146 100644 --- a/kwant/__init__.py +++ b/kwant/__init__.py @@ -23,8 +23,9 @@ except ImportError: else: raise -from ._common import KwantDeprecationWarning -__all__.append('KwantDeprecationWarning') +from ._common import KwantDeprecationWarning, UserCodeError + +__all__.extend(['KwantDeprecationWarning', 'UserCodeError']) from ._common import version as __version__ diff --git a/kwant/_common.py b/kwant/_common.py index b67d1326..23e5084d 100644 --- a/kwant/_common.py +++ b/kwant/_common.py @@ -9,7 +9,7 @@ import subprocess import os -__all__ = ['version', 'KwantDeprecationWarning'] +__all__ = ['version', 'KwantDeprecationWarning', 'UserCodeError'] package_root = os.path.dirname(os.path.realpath(__file__)) distr_root = os.path.dirname(package_root) @@ -91,6 +91,19 @@ class KwantDeprecationWarning(Warning): pass +class UserCodeError(Exception): + """Class for errors that occur in user-provided code. + + Usually users will define value functions that Kwant calls in order to + evaluate the Hamiltonian. If one of these function raises an exception + then it is caught and this error is raised in its place. This makes it + clear that the error is from the user's code (and not a bug in Kwant) and + also makes it possible for any libraries that wrap Kwant to detect when a + user's function causes an error. + """ + pass + + def ensure_isinstance(obj, typ, msg=None): if isinstance(obj, typ): return diff --git a/kwant/builder.py b/kwant/builder.py index b1d89bb5..4092a006 100644 --- a/kwant/builder.py +++ b/kwant/builder.py @@ -15,7 +15,7 @@ import operator from itertools import islice, chain import tinyarray as ta import numpy as np -from . import system, graph, KwantDeprecationWarning +from . import system, graph, KwantDeprecationWarning, UserCodeError from ._common import ensure_isinstance @@ -1377,6 +1377,12 @@ class Builder: ################ Finalized systems +def _raise_user_error(exc, func): + msg = ('Error occurred in user-supplied value function "{0}".\n' + 'See the upper part of the above backtrace for more information.') + raise UserCodeError(msg.format(func.__name__)) from exc + + class FiniteSystem(system.FiniteSystem): """ Finalized `Builder` with leads. @@ -1387,7 +1393,10 @@ class FiniteSystem(system.FiniteSystem): if i == j: value = self.onsite_hamiltonians[i] if callable(value): - value = value(self.sites[i], *args) + try: + value = value(self.sites[i], *args) + except Exception as exc: + _raise_user_error(exc, value) else: edge_id = self.graph.first_edge_id(i, j) value = self.hoppings[edge_id] @@ -1398,7 +1407,10 @@ class FiniteSystem(system.FiniteSystem): value = self.hoppings[edge_id] if callable(value): sites = self.sites - value = value(sites[i], sites[j], *args) + try: + value = value(sites[i], sites[j], *args) + except Exception as exc: + _raise_user_error(exc, value) if conj: value = herm_conj(value) return value @@ -1421,7 +1433,10 @@ class InfiniteSystem(system.InfiniteSystem): i -= self.cell_size value = self.onsite_hamiltonians[i] if callable(value): - value = value(self.symmetry.to_fd(self.sites[i]), *args) + try: + value = value(self.symmetry.to_fd(self.sites[i]), *args) + except Exception as exc: + _raise_user_error(exc, value) else: edge_id = self.graph.first_edge_id(i, j) value = self.hoppings[edge_id] @@ -1435,7 +1450,10 @@ class InfiniteSystem(system.InfiniteSystem): site_i = sites[i] site_j = sites[j] site_i, site_j = self.symmetry.to_fd(site_i, site_j) - value = value(site_i, site_j, *args) + try: + value = value(site_i, site_j, *args) + except Exception as exc: + _raise_user_error(exc, value) if conj: value = herm_conj(value) return value diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py index 4b9ba493..b705e9a8 100644 --- a/kwant/tests/test_builder.py +++ b/kwant/tests/test_builder.py @@ -8,7 +8,7 @@ import warnings from random import Random -from nose.tools import assert_raises +from nose.tools import assert_raises, assert_true from numpy.testing import assert_equal, assert_almost_equal import tinyarray as ta import kwant @@ -433,6 +433,42 @@ def test_hamiltonian_evaluation(): assert_equal(fsyst.hamiltonian(t, h), syst[tsite, hsite](tsite, hsite)) + # test when user-function raises errors + def onsite_raises(site): + raise ValueError() + + def hopping_raises(a, b): + raise ValueError('error message') + + def test_raising(fsyst, hop): + a, b = hop + # exceptions are converted to kwant.UserCodeError and we add our message + with assert_raises(kwant.UserCodeError) as ctx: + fsyst.hamiltonian(a, a) + msg = 'Error occurred in user-supplied value function "onsite_raises"' + assert_true(msg in str(ctx.exception)) + + for hop in [(a, b), (b, a)]: + with assert_raises(kwant.UserCodeError) as ctx: + fsyst.hamiltonian(*hop) + msg = 'Error occurred in user-supplied value function "hopping_raises"' + assert_true(msg in str(ctx.exception)) + + # test with finite system + new_hop = (fam(-1, 0), fam(0, 0)) + syst[new_hop[0]] = onsite_raises + syst[new_hop] = hopping_raises + fsyst = syst.finalized() + hop = tuple(map(fsyst.sites.index, new_hop)) + test_raising(fsyst, hop) + + # test with infinite system + inf_syst = kwant.Builder(VerySimpleSymmetry(2)) + inf_syst += syst + inf_fsyst = inf_syst.finalized() + hop = tuple(map(inf_fsyst.sites.index, new_hop)) + test_raising(inf_fsyst, hop) + def test_dangling(): def make_system(): -- GitLab