Commit 895c7809 authored by Artem Pulkin's avatar Artem Pulkin
Browse files

kernel, ml: make fidelity reduced by minimum

parent febb1112
......@@ -529,7 +529,7 @@ def eval(images, potentials, kname, squeeze=True, ignore_missing_species=False,
return out
def total(images, potentials, kname="kernel", squeeze=False, resolving=False, **kwargs):
def total(images, potentials, kname="kernel", squeeze=False, resolving=False, reduction=np.sum, **kwargs):
"""
Total energy as a sum of all possible potential terms.
......@@ -550,6 +550,8 @@ def total(images, potentials, kname="kernel", squeeze=False, resolving=False, **
is passed.
resolving : bool
If True, runs species-resolving kernels.
reduction : Callable
The reduction for individual potentials.
kwargs
Other arguments to `eval`.
......@@ -558,7 +560,7 @@ def total(images, potentials, kname="kernel", squeeze=False, resolving=False, **
energy : float
The total energy value.
"""
return eval(images, potentials, kname, squeeze=squeeze, resolving=resolving, **kwargs).sum(axis=0)
return reduction(eval(images, potentials, kname, squeeze=squeeze, resolving=resolving, **kwargs), axis=0)
def grad(images, potentials, kname="kernel_gradient", **kwargs):
......@@ -895,8 +897,10 @@ class ScalarFunctionWrapper:
# compute fidelity
try:
out = np.full((len(self.potentials), images.cell.size), float("+inf"))
images.cell.meta["fidelity"] = total(images, self.potentials, "fidelity", resolving=True,
prefer_parallel=self.prefer_parallel, missing_kernel=float("inf"))
prefer_parallel=self.prefer_parallel, missing_kernel=float("inf"),
reduction=np.min, out=out)
except ValueError as e:
if e.args == ("All potentials miss kernel 'fidelity'",):
pass
......
......@@ -2175,7 +2175,8 @@ def kernel_u_nn(kind, is_parallel, r_indptr, r_indices, r_data, cartesian_row, c
species_mask, kind, is_parallel, additional_inputs=descriptor_additional_inputs) if do_grad else None
if do_fidelity:
out[out_mask, ...] += fidelity
# reduce by taking min in place of sum
out[out_mask, ...] = np.minimum(out[out_mask, ...], fidelity)
else:
out[out_mask, ...] += nn_forward_middleware(descriptor_values, nn, descriptor_gradient_values)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment