Commit 05af5bd3 authored by Artem Pulkin's avatar Artem Pulkin
Browse files

workflows, kernel: always compute fidelity and filter by fidelity

parent 8590b74a
Pipeline #85415 passed with stages
in 19 minutes and 37 seconds
......@@ -802,7 +802,7 @@ def _parameters_cache(f):
class ScalarFunctionWrapper:
def __init__(self, sample, potentials, include_coordinates=True, include_vectors=False, normalize=None,
prefer_parallel=None, cell_logger=None, track_potential_fidelity=False, **kwargs):
prefer_parallel=None, cell_logger=None, **kwargs):
"""
A wrapper providing interfaces to total energy and gradient computation.
......@@ -822,9 +822,6 @@ class ScalarFunctionWrapper:
A flag to prefer parallel potential computations.
cell_logger : Callable
A function accumulating intermediate cell objects.
track_potential_fidelity : bool
If True, computes potential fidelity and stores it
as a part of the history recorded.
kwargs
Additional arguments to ``compute_images``.
"""
......@@ -838,7 +835,6 @@ class ScalarFunctionWrapper:
self.listeners = []
if cell_logger is not None:
self.listeners.append(cell_logger)
self.track_potential_fidelity = track_potential_fidelity
self.compute_images_kwargs = kwargs
self.potentials = potentials
......@@ -848,12 +844,6 @@ class ScalarFunctionWrapper:
@potentials.setter
def potentials(self, potentials):
if self.track_potential_fidelity: # early alert
for i in potentials:
try:
i.get_kernel_by_name("fidelity")
except KeyError:
raise ValueError("One or more potentials do not include the 'fidelity' kernel")
self.compute_images_kwargs["cutoff"] = common_cutoff(potentials)
self._potentials = potentials
self.eval.invalidate_cache()
......@@ -903,9 +893,15 @@ class ScalarFunctionWrapper:
else:
gv = None
if self.track_potential_fidelity:
# compute fidelity
try:
images.cell.meta["fidelity"] = total(images, self.potentials, "fidelity", resolving=True,
prefer_parallel=self.prefer_parallel)
prefer_parallel=self.prefer_parallel, missing_kernel=float("nan"))
except ValueError as e:
if e.args == ("All potentials miss kernel 'fidelity'",):
pass
else:
raise
self._notify(images)
return images, f, gc, gv
......
......@@ -417,6 +417,35 @@ class LocalPotential:
self.tag = tag
self.additional_inputs = list(additional_inputs) if additional_inputs is not None else None
def get_kernel_key(self, kname, resolving=True, prefer_parallel=None, size=None):
    """
    Picks a kernel key based on the input.

    Parameters
    ----------
    kname : str
        Kernel name.
    resolving : bool
        Kernel parameter: whether a resolving kernel is requested.
    prefer_parallel : bool
        Kernel parameter: whether a parallel kernel is preferred.
        If None, decided from ``size`` and module-level defaults.
    size : int
        Expected input size.

    Returns
    -------
    result : kernel_kind
        The kernel key. The kernel may or may not be defined.
    """
    if prefer_parallel is None:
        # Small inputs default to the serial kernel; otherwise fall back
        # to the module-wide preference.
        if size is not None and size <= _serial_below:
            prefer_parallel = False
        else:
            prefer_parallel = _prefer_parallel
    if prefer_parallel:
        # Use the parallel kernel only when it is actually registered.
        k = kernel_kind(name=kname, parallel=True, resolving=resolving)
        if k in self.kernels:
            return k
    # Serial key is returned unconditionally; per the contract above it
    # may or may not be present in ``self.kernels``.
    return kernel_kind(name=kname, parallel=False, resolving=resolving)
def get_kernel_by_name(self, kname, resolving=True, prefer_parallel=None, rtn_key=False, size=None):
"""
Picks a potential kernel for the given requirements.
......@@ -441,21 +470,7 @@ class LocalPotential:
key
The key.
"""
if prefer_parallel is None:
if size is not None and size <= _serial_below:
prefer_parallel = False
else:
prefer_parallel = _prefer_parallel
if prefer_parallel:
k = kernel_kind(name=kname, parallel=True, resolving=resolving)
try:
if rtn_key:
return self.kernels[k], k
else:
return self.kernels[k]
except KeyError:
pass
k = kernel_kind(name=kname, parallel=False, resolving=resolving)
k = self.get_kernel_key(kname, resolving=resolving, prefer_parallel=prefer_parallel, size=size)
if rtn_key:
return self.kernels[k], k
else:
......@@ -1542,7 +1557,8 @@ def _pre_compute_r_quantities(distances, potentials):
def eval_potentials(encoded_potentials, kname, sparse_pair_distances, cartesian_row, cartesian_col, shift_vectors,
spec_encoded_row, pre_compute_r=False, additional_inputs=None, cutoff=None, out=None, **kwargs):
spec_encoded_row, pre_compute_r=False, additional_inputs=None, cutoff=None, out=None,
missing_kernel=None, **kwargs):
"""
Calculates potentials: values, gradients and more.
......@@ -1573,6 +1589,9 @@ def eval_potentials(encoded_potentials, kname, sparse_pair_distances, cartesian_
The output buffer `[n_potentials, n_atoms]` for
kname == "kernel" and `[n_potentials, n_atoms, n_atoms, 3]`
for kname == "kernel_gradient".
missing_kernel : float
If specified, will not raise missing kernel exception but instead will
fill the output with the value specified.
kwargs
Other common arguments to kernel functions.
......@@ -1598,14 +1617,38 @@ def eval_potentials(encoded_potentials, kname, sparse_pair_distances, cartesian_
kwargs_get = {k: kwargs[k] for k in ("resolving", "prefer_parallel") if k in kwargs}
if missing_kernel is not None:
encoded_potentials = list(
i
if i.get_kernel_key(kname, **kwargs_get) in i.kernels
else missing_kernel
for i in encoded_potentials
)
# Check all kernels produce the output of the same shape
shapes = set(i.get_kernel_by_name(kname, **kwargs_get).out_shape for i in encoded_potentials)
shapes = set(
i.get_kernel_by_name(kname, **kwargs_get).out_shape
for i in encoded_potentials
if i is not missing_kernel
)
if len(shapes) == 0:
if missing_kernel is not None:
if out is not None:
# nothing to do here: fill the buffer and return
out[:] += missing_kernel
return out
else:
if len(encoded_potentials) != 0:
raise ValueError(f"All potentials miss kernel '{kname}'")
raise ValueError("Empty potentials specified")
if len(shapes) != 1:
raise ValueError(f"The shape of the output across kernels is not the same: {shapes}")
if out is None:
out = np.zeros((len(encoded_potentials),) + PotentialKernel.compute_shape(shapes.pop(), len(cartesian_row)), dtype=float)
for i, potential in enumerate(encoded_potentials):
if missing_kernel is not None and potential is missing_kernel:
out[i] = missing_kernel
continue
if not isinstance(potential, LocalPotential):
raise ValueError(f'Not a LocalPotential: {repr(potential)}')
if not isinstance(potential.tag, np.ndarray) and potential.tag is not None:
......
......@@ -318,6 +318,33 @@ class MainTest(TestCase):
cells = Cell.load(out["cells_direct"])
assert len(cells) == 20 # first and last per input
# Runs a short relaxation workflow with ``snapshots_f`` (fidelity filtering)
# enabled. The harmonic-repulsion potential provides no fidelity above the
# threshold, so all snapshot cells are filtered out (see the final assert).
# NOTE(review): the indentation of the YAML literal below was lost in this
# diff capture — restore the original nesting before running.
def test_relax_snapshots_fidelity(self):
out = self.__run_yaml__("""
test-relax:
init:
units: default
prepare:
random_cells:
n: 10
density: !nu 0.0222 / Å ** 3
atoms:
bi: 26
se: 39
potentials:
- tag: harmonic repulsion
parameters:
a: !nu 2*Å
epsilon: !nu 5*eV
run:
snapshots: true
snapshots_f: true
options:
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
assert len(cells) == 0  # fidelity is zero everywhere
def test_md(self):
out = self.__run_yaml__("""
test-relax:
......
......@@ -563,7 +563,7 @@ class LJBoxTest(TestCase):
for i_k, k in enumerate(values_key)
]
)
wrapper = kernel.ScalarFunctionWrapper(self.cells[0], ml_potentials, track_potential_fidelity=True)
wrapper = kernel.ScalarFunctionWrapper(self.cells[0], ml_potentials)
with dyn.track_snapshots(wrapper) as result:
wrapper.f(self.cells[0].coordinates, self.cells[0].vectors)
assert len(result) == 1
......
......@@ -1556,7 +1556,9 @@ class log_dynamics:
fidelity = cell.meta.get('fidelity', None)
if fidelity is not None:
i = np.argmin(fidelity)
self.log.info(f" min(fidelity) {fidelity[i]} for atom {i}")
self.log.info(f" min(fidelity): {fidelity[i]} @{i}")
self.log.info(f" max(fidelity): {max(fidelity)} sum(fidelity): {sum(fidelity)} "
f"mean(fidelity): {np.mean(fidelity)}")
def __enter__(self):
self.last_energies = []
......@@ -1584,6 +1586,32 @@ class count_evaluations:
self.wrapper.listeners.remove(self)
def thin_out_by_fidelity(cells, fidelity, key="fidelity"):
    """
    Filter a list of cells by requiring the fidelity.

    Parameters
    ----------
    cells : list
        A list of cells.
    fidelity : float
        The minimal fidelity required.
    key : str
        The lookup key for fidelity.

    Returns
    -------
    result : list
        A list of resulting cells.
    """
    # Warn when no cell carries the fidelity key at all: every entry then
    # falls back to NaN and the comparison below drops all of them.
    if not any(key in cell.meta for cell in cells):
        warnings.warn(f"Neither cell reports {key}: the returned cells are empty")
    missing = [float("nan")]
    # NaN >= threshold is False, so cells without the key are rejected.
    return [cell for cell in cells if min(cell.meta.get(key, missing)) >= fidelity]
def thin_out_by_cartesian_delta_thr(cells, thr, enforce_last=True):
"""
Filter a list of cells by accepting a minimal cartesian step
......@@ -1735,7 +1763,7 @@ class SDWorkflow(Workflow):
@staticmethod
def worker(dynamics_data, driver, prefer_parallel=None, starting_point=False, logger=None, logger_units=None,
potentials=None, snapshots_d=None, **kwargs):
potentials=None, snapshots_d=None, snapshots_f=None, **kwargs):
job_id, dynamics = dynamics_data
with count_evaluations(dynamics.wrapper) as counter:
if prefer_parallel is not None:
......@@ -1770,6 +1798,8 @@ class SDWorkflow(Workflow):
filename=w.filename,
lineno=w.lineno,
)
if snapshots_f is not None:
result = thin_out_by_fidelity(result, snapshots_f)
if snapshots_d is not None:
result = thin_out_by_cartesian_delta_thr(result, snapshots_d)
return job_id, result, {"evaluations": counter.n}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment