From 447e27c20125fcb9d132d1d41e96b6f1d5205d1c Mon Sep 17 00:00:00 2001
From: Joseph Weston <joseph.weston08@gmail.com>
Date: Sat, 21 May 2016 10:02:01 +0200
Subject: [PATCH] raise UserCodeError with message when Hamiltonian value
 functions fail

fixes #36
---
 doc/source/reference/kwant.rst |  1 +
 kwant/__init__.py              |  5 +++--
 kwant/_common.py               | 15 +++++++++++++-
 kwant/builder.py               | 28 ++++++++++++++++++++-----
 kwant/tests/test_builder.py    | 38 +++++++++++++++++++++++++++++++++-
 5 files changed, 78 insertions(+), 9 deletions(-)

diff --git a/doc/source/reference/kwant.rst b/doc/source/reference/kwant.rst
index 5dc52a97..61a3d526 100644
--- a/doc/source/reference/kwant.rst
+++ b/doc/source/reference/kwant.rst
@@ -15,6 +15,7 @@ The version of Kwant is available under the name ``__version__``.
    :toctree: generated/
 
    KwantDeprecationWarning
+   UserCodeError
 
 .. currentmodule:: kwant.builder
 
diff --git a/kwant/__init__.py b/kwant/__init__.py
index 74e80f9f..389d0146 100644
--- a/kwant/__init__.py
+++ b/kwant/__init__.py
@@ -23,8 +23,9 @@ except ImportError:
     else:
         raise
 
-from ._common import KwantDeprecationWarning
-__all__.append('KwantDeprecationWarning')
+from ._common import KwantDeprecationWarning, UserCodeError
+
+__all__.extend(['KwantDeprecationWarning', 'UserCodeError'])
 
 from ._common import version as __version__
 
diff --git a/kwant/_common.py b/kwant/_common.py
index b67d1326..23e5084d 100644
--- a/kwant/_common.py
+++ b/kwant/_common.py
@@ -9,7 +9,7 @@
 import subprocess
 import os
 
-__all__ = ['version', 'KwantDeprecationWarning']
+__all__ = ['version', 'KwantDeprecationWarning', 'UserCodeError']
 
 package_root = os.path.dirname(os.path.realpath(__file__))
 distr_root = os.path.dirname(package_root)
@@ -91,6 +91,19 @@ class KwantDeprecationWarning(Warning):
     pass
 
 
+class UserCodeError(Exception):
+    """Class for errors that occur in user-provided code.
+
+    Usually users will define value functions that Kwant calls in order to
+    evaluate the Hamiltonian.  If one of these function raises an exception
+    then it is caught and this error is raised in its place. This makes it
+    clear that the error is from the user's code (and not a bug in Kwant) and
+    also makes it possible for any libraries that wrap Kwant to detect when a
+    user's function causes an error.
+    """
+    pass
+
+
 def ensure_isinstance(obj, typ, msg=None):
     if isinstance(obj, typ):
         return
diff --git a/kwant/builder.py b/kwant/builder.py
index b1d89bb5..4092a006 100644
--- a/kwant/builder.py
+++ b/kwant/builder.py
@@ -15,7 +15,7 @@ import operator
 from itertools import islice, chain
 import tinyarray as ta
 import numpy as np
-from . import system, graph, KwantDeprecationWarning
+from . import system, graph, KwantDeprecationWarning, UserCodeError
 from ._common import ensure_isinstance
 
 
@@ -1377,6 +1377,12 @@ class Builder:
 
 ################ Finalized systems
 
+def _raise_user_error(exc, func):
+    msg = ('Error occurred in user-supplied value function "{0}".\n'
+           'See the upper part of the above backtrace for more information.')
+    raise UserCodeError(msg.format(func.__name__)) from exc
+
+
 class FiniteSystem(system.FiniteSystem):
     """
     Finalized `Builder` with leads.
@@ -1387,7 +1393,10 @@ class FiniteSystem(system.FiniteSystem):
         if i == j:
             value = self.onsite_hamiltonians[i]
             if callable(value):
-                value = value(self.sites[i], *args)
+                try:
+                    value = value(self.sites[i], *args)
+                except Exception as exc:
+                    _raise_user_error(exc, value)
         else:
             edge_id = self.graph.first_edge_id(i, j)
             value = self.hoppings[edge_id]
@@ -1398,7 +1407,10 @@ class FiniteSystem(system.FiniteSystem):
                 value = self.hoppings[edge_id]
             if callable(value):
                 sites = self.sites
-                value = value(sites[i], sites[j], *args)
+                try:
+                    value = value(sites[i], sites[j], *args)
+                except Exception as exc:
+                    _raise_user_error(exc, value)
             if conj:
                 value = herm_conj(value)
         return value
@@ -1421,7 +1433,10 @@ class InfiniteSystem(system.InfiniteSystem):
                 i -= self.cell_size
             value = self.onsite_hamiltonians[i]
             if callable(value):
-                value = value(self.symmetry.to_fd(self.sites[i]), *args)
+                try:
+                    value = value(self.symmetry.to_fd(self.sites[i]), *args)
+                except Exception as exc:
+                    _raise_user_error(exc, value)
         else:
             edge_id = self.graph.first_edge_id(i, j)
             value = self.hoppings[edge_id]
@@ -1435,7 +1450,10 @@ class InfiniteSystem(system.InfiniteSystem):
                 site_i = sites[i]
                 site_j = sites[j]
                 site_i, site_j = self.symmetry.to_fd(site_i, site_j)
-                value = value(site_i, site_j, *args)
+                try:
+                    value = value(site_i, site_j, *args)
+                except Exception as exc:
+                    _raise_user_error(exc, value)
             if conj:
                 value = herm_conj(value)
         return value
diff --git a/kwant/tests/test_builder.py b/kwant/tests/test_builder.py
index 4b9ba493..b705e9a8 100644
--- a/kwant/tests/test_builder.py
+++ b/kwant/tests/test_builder.py
@@ -8,7 +8,7 @@
 
 import warnings
 from random import Random
-from nose.tools import assert_raises
+from nose.tools import assert_raises, assert_true
 from numpy.testing import assert_equal, assert_almost_equal
 import tinyarray as ta
 import kwant
@@ -433,6 +433,42 @@ def test_hamiltonian_evaluation():
         assert_equal(fsyst.hamiltonian(t, h),
                      syst[tsite, hsite](tsite, hsite))
 
+    # test when user-function raises errors
+    def onsite_raises(site):
+        raise ValueError()
+
+    def hopping_raises(a, b):
+        raise ValueError('error message')
+
+    def test_raising(fsyst, hop):
+        a, b = hop
+        # exceptions are converted to kwant.UserCodeError and we add our message
+        with assert_raises(kwant.UserCodeError) as ctx:
+            fsyst.hamiltonian(a, a)
+        msg = 'Error occurred in user-supplied value function "onsite_raises"'
+        assert_true(msg in str(ctx.exception))
+
+        for hop in [(a, b), (b, a)]:
+            with assert_raises(kwant.UserCodeError) as ctx:
+                fsyst.hamiltonian(*hop)
+            msg = 'Error occurred in user-supplied value function "hopping_raises"'
+            assert_true(msg in str(ctx.exception))
+
+    # test with finite system
+    new_hop = (fam(-1, 0), fam(0, 0))
+    syst[new_hop[0]] = onsite_raises
+    syst[new_hop] = hopping_raises
+    fsyst = syst.finalized()
+    hop = tuple(map(fsyst.sites.index, new_hop))
+    test_raising(fsyst, hop)
+
+    # test with infinite system
+    inf_syst = kwant.Builder(VerySimpleSymmetry(2))
+    inf_syst += syst
+    inf_fsyst = inf_syst.finalized()
+    hop = tuple(map(inf_fsyst.sites.index, new_hop))
+    test_raising(inf_fsyst, hop)
+
 
 def test_dangling():
     def make_system():
-- 
GitLab