Commit 447e27c2 authored by Joseph Weston's avatar Joseph Weston
Browse files

raise UserCodeError with message when Hamiltonian value functions fail

fixes #36
parent ab98335e
......@@ -15,6 +15,7 @@ The version of Kwant is available under the name ``__version__``.
:toctree: generated/
KwantDeprecationWarning
UserCodeError
.. currentmodule:: kwant.builder
......
......@@ -23,8 +23,9 @@ except ImportError:
else:
raise
from ._common import KwantDeprecationWarning
__all__.append('KwantDeprecationWarning')
from ._common import KwantDeprecationWarning, UserCodeError
__all__.extend(['KwantDeprecationWarning', 'UserCodeError'])
from ._common import version as __version__
......
......@@ -9,7 +9,7 @@
import subprocess
import os
__all__ = ['version', 'KwantDeprecationWarning']
__all__ = ['version', 'KwantDeprecationWarning', 'UserCodeError']
package_root = os.path.dirname(os.path.realpath(__file__))
distr_root = os.path.dirname(package_root)
......@@ -91,6 +91,19 @@ class KwantDeprecationWarning(Warning):
pass
class UserCodeError(Exception):
"""Class for errors that occur in user-provided code.
Usually users will define value functions that Kwant calls in order to
evaluate the Hamiltonian. If one of these function raises an exception
then it is caught and this error is raised in its place. This makes it
clear that the error is from the user's code (and not a bug in Kwant) and
also makes it possible for any libraries that wrap Kwant to detect when a
user's function causes an error.
"""
pass
def ensure_isinstance(obj, typ, msg=None):
if isinstance(obj, typ):
return
......
......@@ -15,7 +15,7 @@ import operator
from itertools import islice, chain
import tinyarray as ta
import numpy as np
from . import system, graph, KwantDeprecationWarning
from . import system, graph, KwantDeprecationWarning, UserCodeError
from ._common import ensure_isinstance
......@@ -1377,6 +1377,12 @@ class Builder:
################ Finalized systems
def _raise_user_error(exc, func):
msg = ('Error occurred in user-supplied value function "{0}".\n'
'See the upper part of the above backtrace for more information.')
raise UserCodeError(msg.format(func.__name__)) from exc
class FiniteSystem(system.FiniteSystem):
"""
Finalized `Builder` with leads.
......@@ -1387,7 +1393,10 @@ class FiniteSystem(system.FiniteSystem):
if i == j:
value = self.onsite_hamiltonians[i]
if callable(value):
value = value(self.sites[i], *args)
try:
value = value(self.sites[i], *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
......@@ -1398,7 +1407,10 @@ class FiniteSystem(system.FiniteSystem):
value = self.hoppings[edge_id]
if callable(value):
sites = self.sites
value = value(sites[i], sites[j], *args)
try:
value = value(sites[i], sites[j], *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
......@@ -1421,7 +1433,10 @@ class InfiniteSystem(system.InfiniteSystem):
i -= self.cell_size
value = self.onsite_hamiltonians[i]
if callable(value):
value = value(self.symmetry.to_fd(self.sites[i]), *args)
try:
value = value(self.symmetry.to_fd(self.sites[i]), *args)
except Exception as exc:
_raise_user_error(exc, value)
else:
edge_id = self.graph.first_edge_id(i, j)
value = self.hoppings[edge_id]
......@@ -1435,7 +1450,10 @@ class InfiniteSystem(system.InfiniteSystem):
site_i = sites[i]
site_j = sites[j]
site_i, site_j = self.symmetry.to_fd(site_i, site_j)
value = value(site_i, site_j, *args)
try:
value = value(site_i, site_j, *args)
except Exception as exc:
_raise_user_error(exc, value)
if conj:
value = herm_conj(value)
return value
......
......@@ -8,7 +8,7 @@
import warnings
from random import Random
from nose.tools import assert_raises
from nose.tools import assert_raises, assert_true
from numpy.testing import assert_equal, assert_almost_equal
import tinyarray as ta
import kwant
......@@ -433,6 +433,42 @@ def test_hamiltonian_evaluation():
assert_equal(fsyst.hamiltonian(t, h),
syst[tsite, hsite](tsite, hsite))
# test when user-function raises errors
def onsite_raises(site):
raise ValueError()
def hopping_raises(a, b):
raise ValueError('error message')
def test_raising(fsyst, hop):
a, b = hop
# exceptions are converted to kwant.UserCodeError and we add our message
with assert_raises(kwant.UserCodeError) as ctx:
fsyst.hamiltonian(a, a)
msg = 'Error occurred in user-supplied value function "onsite_raises"'
assert_true(msg in str(ctx.exception))
for hop in [(a, b), (b, a)]:
with assert_raises(kwant.UserCodeError) as ctx:
fsyst.hamiltonian(*hop)
msg = 'Error occurred in user-supplied value function "hopping_raises"'
assert_true(msg in str(ctx.exception))
# test with finite system
new_hop = (fam(-1, 0), fam(0, 0))
syst[new_hop[0]] = onsite_raises
syst[new_hop] = hopping_raises
fsyst = syst.finalized()
hop = tuple(map(fsyst.sites.index, new_hop))
test_raising(fsyst, hop)
# test with infinite system
inf_syst = kwant.Builder(VerySimpleSymmetry(2))
inf_syst += syst
inf_fsyst = inf_syst.finalized()
hop = tuple(map(inf_fsyst.sites.index, new_hop))
test_raising(inf_fsyst, hop)
def test_dangling():
def make_system():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment