Commit b3694fd8 authored by Artem Pulkin's avatar Artem Pulkin
Browse files

kernel: more fidelity API

parent 895c7809
Pipeline #85423 passed with stages
in 19 minutes and 7 seconds
......@@ -616,6 +616,38 @@ def grad_cell(images, potentials, kname="kernel_cell_gradient", **kwargs):
return total(images, potentials, kname=kname, **kwargs)
def fidelity(images, potentials, kname="fidelity", resolving=True, reduction=np.min, missing_kernel=float("inf"),
**kwargs):
"""
Potential fidelity.
Parameters
----------
images : CellImages
Cell and its images.
potentials : list, LocalPotential
A list of potentials or a single potential.
kname : str, None
Function to evaluate: 'kernel', 'kernel_gradient' or whatever
other kernel function set for all potentials in the list.
resolving : bool
If True, runs species-resolving kernels.
reduction : Callable
The reduction for individual potentials.
missing_kernel : float
A float to use in place of kernels that do not support fidelity.
kwargs
Other arguments to `eval`.
Returns
-------
energy : float
The total energy value.
"""
return total(images, potentials, kname=kname, resolving=resolving, reduction=reduction,
missing_kernel=missing_kernel, **kwargs)
def common_cutoff(potentials):
"""
The maximal (common) cutoff of many potentials.
......@@ -777,6 +809,7 @@ class CellImages:
total = total
grad = grad
grad_cell = grad_cell
fidelity = fidelity
def _parameters_cache(f):
......@@ -898,9 +931,7 @@ class ScalarFunctionWrapper:
# compute fidelity
try:
out = np.full((len(self.potentials), images.cell.size), float("+inf"))
images.cell.meta["fidelity"] = total(images, self.potentials, "fidelity", resolving=True,
prefer_parallel=self.prefer_parallel, missing_kernel=float("inf"),
reduction=np.min, out=out)
images.cell.meta["fidelity"] = fidelity(images, self.potentials, out=out)
except ValueError as e:
if e.args == ("All potentials miss kernel 'fidelity'",):
pass
......
......@@ -343,7 +343,7 @@ class MainTest(TestCase):
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
assert len(cells) == 0 # fidelity is zero everywhere
assert 100 > len(cells) >= 20 # fidelity is +inf
def test_md(self):
out = self.__run_yaml__("""
......
......@@ -591,7 +591,8 @@ class LJBoxTest(TestCase):
modules, self.descriptor_collection_trivial, descriptor_fidelity_histograms=hist)
for i, image in enumerate(self.cells_wrapped_trivial):
fidelity = image.eval(ml_potentials, "fidelity", ignore_missing_species=True)
fidelity = image.eval(ml_potentials, "fidelity", ignore_missing_species=True,
out=np.full((len(ml_potentials), image.cell.size), float("+inf")))
for _i, (_p, _f) in enumerate(zip(ml_potentials, fidelity)):
_d = image.eval(_p.descriptors, "kernel", ignore_missing_species=True)[:, image.cell.values == _p.tag]
testing.assert_array_less(_d, 10)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment