Skip to content
Snippets Groups Projects
Commit ba8b6ea8 authored by Joseph Weston's avatar Joseph Weston
Browse files

move 'memoize' decorator to kwant/common.py

this decorator will be needed in kwant/builder.py in the next
commit.
parent a3548f5f
No related branches found
No related tags found
No related merge requests found
...@@ -13,6 +13,7 @@ import inspect ...@@ -13,6 +13,7 @@ import inspect
import warnings import warnings
import importlib import importlib
import functools import functools
import collections
from contextlib import contextmanager from contextlib import contextmanager
__all__ = ['KwantDeprecationWarning', 'UserCodeError'] __all__ = ['KwantDeprecationWarning', 'UserCodeError']
...@@ -191,3 +192,27 @@ class lazy_import: ...@@ -191,3 +192,27 @@ class lazy_import:
package = sys.modules[self.__package] package = sys.modules[self.__package]
setattr(package, self.__module, mod) setattr(package, self.__module, mod)
return getattr(mod, name) return getattr(mod, name)
def _hashable(obj):
return isinstance(obj, collections.abc.Hashable)
def memoize(f):
"""Decorator to memoize a function that works even with unhashable args.
This decorator will even work with functions whose args are not hashable.
The cache key is made up by the hashable arguments and the ids of the
non-hashable args. It is up to the user to make sure that non-hashable
args do not change during the lifetime of the decorator.
This decorator will keep reevaluating functions that return None.
"""
def lookup(*args):
key = tuple(arg if _hashable(arg) else id(arg) for arg in args)
result = cache.get(key)
if result is None:
cache[key] = result = f(*args)
return result
cache = {}
return lookup
...@@ -20,36 +20,12 @@ from . import builder, system, plotter ...@@ -20,36 +20,12 @@ from . import builder, system, plotter
from .linalg import lll from .linalg import lll
from .builder import herm_conj, HermConjOfFunc from .builder import herm_conj, HermConjOfFunc
from .lattice import TranslationalSymmetry from .lattice import TranslationalSymmetry
from ._common import get_parameters from ._common import get_parameters, memoize
__all__ = ['wraparound', 'plot_2d_bands'] __all__ = ['wraparound', 'plot_2d_bands']
def _hashable(obj):
return isinstance(obj, collections.abc.Hashable)
def _memoize(f):
"""Decorator to memoize a function that works even with unhashable args.
This decorator will even work with functions whose args are not hashable.
The cache key is made up by the hashable arguments and the ids of the
non-hashable args. It is up to the user to make sure that non-hashable
args do not change during the lifetime of the decorator.
This decorator will keep reevaluating functions that return None.
"""
def lookup(*args):
key = tuple(arg if _hashable(arg) else id(arg) for arg in args)
result = cache.get(key)
if result is None:
cache[key] = result = f(*args)
return result
cache = {}
return lookup
def _set_signature(func, params): def _set_signature(func, params):
"""Set the signature of 'func'. """Set the signature of 'func'.
...@@ -103,7 +79,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): ...@@ -103,7 +79,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
format. It will be deprecated in the 2.0 release of Kwant. format. It will be deprecated in the 2.0 release of Kwant.
""" """
@_memoize @memoize
def bind_site(val): def bind_site(val):
def f(*args): def f(*args):
a, *args = args a, *args = args
...@@ -113,7 +89,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): ...@@ -113,7 +89,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, get_parameters(val) + momenta) _set_signature(f, get_parameters(val) + momenta)
return f return f
@_memoize @memoize
def bind_hopping_as_site(elem, val): def bind_hopping_as_site(elem, val):
def f(*args): def f(*args):
a, *args = args a, *args = args
...@@ -128,7 +104,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): ...@@ -128,7 +104,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, params + momenta) _set_signature(f, params + momenta)
return f return f
@_memoize @memoize
def bind_hopping(elem, val): def bind_hopping(elem, val):
def f(*args): def f(*args):
a, b, *args = args a, b, *args = args
...@@ -142,7 +118,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'): ...@@ -142,7 +118,7 @@ def wraparound(builder, keep=None, *, coordinate_names='xyz'):
_set_signature(f, params + momenta) _set_signature(f, params + momenta)
return f return f
@_memoize @memoize
def bind_sum(num_sites, *vals): def bind_sum(num_sites, *vals):
"""Construct joint signature for all 'vals'.""" """Construct joint signature for all 'vals'."""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment