Commit 28d1f541 authored by Artem Pulkin

ml: support partial energies

parent ddfcdc6f
Pipeline #43890 passed with stage in 2 minutes and 32 seconds
@@ -95,6 +95,7 @@ def collect_meta(field, cells, out, mask=None):
collect_energies = partial(collect_meta, "total-energy")
collect_forces = partial(collect_meta, "forces")
collect_charges = partial(collect_meta, "charges")
collect_partial_energies = partial(collect_meta, "partial-energy")
def prepare_descriptor_data(cells, descriptors, specimen, values, grad=False, dtype=torch.float64):
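For reference, the collectors above are plain `functools.partial` bindings of the first argument, so the new collector is interchangeable with an explicit `collect_meta` call (a minimal sketch of the equivalence):

    from functools import partial

    collect_partial_energies = partial(collect_meta, "partial-energy")

    # These two calls are identical by construction:
    #   collect_partial_energies(cells, out, mask=m)
    #   collect_meta("partial-energy", cells, out, mask=m)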
@@ -329,7 +330,7 @@ class PerCellDataset(NoneTolerantTensorDataset):
class PerPointDataset(NoneTolerantTensorDataset):
def __init__(self, features, mask, features_g=None, charges=None, tag=None):
def __init__(self, features, mask, features_g=None, charges=None, energies_p=None, tag=None):
"""
A dataset with descriptors defined per specimen.
@@ -349,24 +350,26 @@ class PerPointDataset(NoneTolerantTensorDataset):
`[n_samples, n_atoms, n_coords]`.
charges : torch.Tensor
An `[n_samples, n_species]` tensor with charges for all atoms of the same type.
energies_p : torch.Tensor
An `[n_samples, n_species]` tensor with per-atom energy contributions.
tag
An optional tag for this dataset.
"""
inputs = locals()
self.dtype = __assert_same_dtype__(inputs, "features", "features_g", "mask", "charges")
__assert_dimension_count__(inputs, "features", 3, "features_g", 5, "mask", 2, "charges", 3)
self.dtype = __assert_same_dtype__(inputs, "features", "features_g", "mask", "charges", "energies_p")
__assert_dimension_count__(inputs, "features", 3, "features_g", 5, "mask", 2, "charges", 3, "energies_p", 3)
self.n_samples = __assert_same_dimension__(inputs, "n_samples", "features", 0, "features_g", 0, "mask", 0,
"charges", 0)
"charges", 0, "energies_p", 0)
self.n_atoms = __assert_same_dimension__(inputs, "n_atoms", "features_g", 3)
self.n_species = __assert_same_dimension__(inputs, "n_species", "features", 1, "features_g", 1, "mask", 1,
"charges", 1)
"charges", 1, "energies_p", 1)
self.n_features = __assert_same_dimension__(inputs, "n_descriptors", "features", 2, "features_g", 2)
self.n_coords = __assert_same_dimension__(inputs, "n_coords", "features_g", 4)
__assert_same_dimension__(inputs, "[one]", "charges", 2, size=1)
__assert_same_dimension__(inputs, "[one]", "charges", 2, "energies_p", 2, size=1)
self.tag = tag
super().__init__(features, mask, features_g, charges)
super().__init__(features, mask, features_g, charges, energies_p)
@property
def features(self) -> torch.Tensor:
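For orientation, a minimal construction that satisfies the shape and dtype assertions above might look as follows (hypothetical sizes; note `energies_p` follows the `charges` convention of a trailing singleton axis):

    import torch

    n_samples, n_species, n_features = 4, 3, 8
    features = torch.zeros(n_samples, n_species, n_features, dtype=torch.float64)
    mask = torch.ones(n_samples, n_species, dtype=torch.float64)
    energies_p = torch.zeros(n_samples, n_species, 1, dtype=torch.float64)

    ds = PerPointDataset(features, mask, energies_p=energies_p, tag=0)
    assert ds.energies_p.shape == (n_samples, n_species, 1)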
@@ -384,14 +387,19 @@ class PerPointDataset(NoneTolerantTensorDataset):
def charges(self) -> torch.Tensor:
return self.tensors[3]
@property
def energies_p(self) -> torch.Tensor:
return self.tensors[4]
def is_gradient_available(self) -> bool:
"""Determines whether features gradients data is present."""
return self.features_g is not None
@staticmethod
def from_cells(cells, descriptors, specimen, values, grad=False, charge=False, tag=None, dtype=torch.float64):
def from_cells(cells, descriptors, specimen, values, grad=False, charge=False, energies_p=False, tag=None,
dtype=torch.float64):
"""
Prepares a per-point dataset with total features, feature gradients and energy gradients.
Prepares a per-point dataset with total features, feature gradients, energy gradients and partial energies.
Parameters
----------
@@ -407,6 +415,8 @@ class PerPointDataset(NoneTolerantTensorDataset):
Include gradients.
charge : bool
Include atomic charges.
energies_p : bool
Include partial energies.
tag
Optional tag for the dataset.
dtype
@@ -426,6 +436,11 @@ class PerPointDataset(NoneTolerantTensorDataset):
torch.zeros(*mask.shape, dtype=dtype),
mask=values == specimen
).unsqueeze(2) if charge else None,
energies_p=collect_partial_energies(
cells,
torch.zeros(*mask.shape, dtype=dtype),
mask=values == specimen
).unsqueeze(2) if energies_p else None,
tag=tag)
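As a usage sketch of the new flag (argument names taken from this diff; `cells`, `descriptors` and `values` as prepared elsewhere in the module):

    ds = PerPointDataset.from_cells(cells, descriptors, specimen=0, values=values,
                                    energies_p=True, tag=0)
    # ds.energies_p is an [n_samples, n_species, 1] tensor, or None when the flag is off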
def to(self, dtype):
@@ -449,6 +464,7 @@ class PerPointDataset(NoneTolerantTensorDataset):
mask=self.mask.to(dtype) if self.mask is not None else None,
features_g=self.features_g.to(dtype) if self.features_g is not None else None,
charges=self.charges.to(dtype) if self.charges is not None else None,
energies_p=self.energies_p.to(dtype) if self.energies_p is not None else None,
tag=self.tag
)
@@ -905,7 +921,7 @@ class Normalization:
), dim=1)
@staticmethod
def __apply__(dataset, energy_op, energy_g_op, features_op, features_g_op, charges_op, inplace=False):
def __apply__(dataset, energy_op, energy_p_op, energy_g_op, features_op, features_g_op, charges_op, inplace=False):
"""
Applies operations to a dataset.
@@ -915,6 +931,8 @@ class Normalization:
The dataset to apply operations to.
energy_op : Callable
Operation on energies.
energy_p_op : Callable
Operation on partial energies.
energy_g_op : Callable
Operation on energy gradients.
features_op : Callable
@@ -943,6 +961,8 @@ class Normalization:
features = []
features_g = []
charges = []
energies_p = []
for i, point_dataset in enumerate(dataset.per_point_datasets):
features.append(features_op(point_dataset.features, i, inplace=inplace))
@@ -956,14 +976,20 @@ class Normalization:
else:
charges.append(None)
if point_dataset.energies_p is not None:
energies_p.append(energy_p_op(point_dataset.energies_p, i, inplace=inplace))
else:
energies_p.append(None)
if inplace:
return dataset
else:
return Dataset(
PerCellDataset(energy=energy, mask=mask, energy_g=energy_g),
*(
PerPointDataset(features=f, mask=d.mask.clone(), features_g=fg, charges=c, tag=d.tag)
for f, fg, d, c in zip(features, features_g, dataset.per_point_datasets, charges)
PerPointDataset(features=f, mask=d.mask.clone(), features_g=fg, charges=c, energies_p=pe,
tag=d.tag)
for f, fg, d, c, pe in zip(features, features_g, dataset.per_point_datasets, charges, energies_p)
),
)
@@ -983,8 +1009,8 @@ class Normalization:
result : Dataset
The resulting scaled dataset.
"""
return self.__apply__(dataset, self.fw_energy, self.fw_energy_g, self.fw_features, self.fw_features_g,
self.fw_charges, inplace=inplace)
return self.__apply__(dataset, self.fw_energy, self.fw_energy_components, self.fw_energy_g, self.fw_features,
self.fw_features_g, self.fw_charges, inplace=inplace)
def bw(self, dataset, inplace=False):
"""
@@ -1002,8 +1028,8 @@ class Normalization:
result : Dataset
The original dataset.
"""
return self.__apply__(dataset, self.bw_energy, self.bw_energy_g, self.bw_features, self.bw_features_g,
self.bw_charges, inplace=inplace)
return self.__apply__(dataset, self.bw_energy, self.bw_energy_components, self.bw_energy_g, self.bw_features,
self.bw_features_g, self.bw_charges, inplace=inplace)
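A round-trip sketch of the extended plumbing, assuming a fitted `Normalization` instance `norm`: `fw` now routes partial energies through `fw_energy_components` and `bw` inverts it, so a forward-backward pass should reproduce the input, partial energies included:

    scaled = norm.fw(dataset)    # partial energies scaled via fw_energy_components
    restored = norm.bw(scaled)   # ...and restored via bw_energy_components
    assert torch.allclose(restored.per_point_datasets[0].energies_p,
                          dataset.per_point_datasets[0].energies_p)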
@staticmethod
def lsq_energy_offsets(dataset, pad=True):
@@ -1210,7 +1236,7 @@ class Normalization:
def learn_cauldron(cells, descriptors, grad=False, normalize=True, extract_forces=True, extract_charges=False,
norm_kwargs=None):
energies_p=False, norm_kwargs=None):
"""
A function assembling data for learning.
@@ -1230,6 +1256,8 @@ def learn_cauldron(cells, descriptors, grad=False, normalize=True, extract_force
If True, extracts forces from unit cell data.
extract_charges : bool
If True, extracts atomic charges from unit cell data.
energies_p : bool
If True, extract partial energies.
norm_kwargs : dict
Arguments to normalization.
@@ -1250,7 +1278,8 @@ def learn_cauldron(cells, descriptors, grad=False, normalize=True, extract_force
per_point = []
for i_s, s in enumerate(values_key):
per_point.append(PerPointDataset.from_cells(cells, descriptors[s], i_s, values, grad=grad,
charge=extract_charges, tag=i_s))
charge=extract_charges, energies_p=energies_p,
tag=i_s))
dataset = Dataset(per_cell, *per_point)
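End to end, `learn_cauldron` simply forwards the flag into each per-point dataset (a sketch mirroring the tests below):

    dataset = ml.learn_cauldron(cells, descriptors, grad=False, normalize=False,
                                energies_p=True)
    for ppd in dataset.per_point_datasets:
        print(ppd.tag, ppd.energies_p.shape)  # [n_samples, n_species, 1] per specimen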
@@ -1401,7 +1430,7 @@ class SequentialSoleEnergyNN(EnergyNNMixin, torch.nn.Sequential):
)
def fw_cauldron(modules, dataset, grad=False, normalization=None):
def fw_cauldron(modules, dataset, grad=False, energies_p=False, normalization=None):
"""
Propagates modules forward and assembles the total energy and gradients.
@@ -1413,13 +1442,16 @@ def fw_cauldron(modules, dataset, grad=False, normalization=None):
The dataset with descriptors or tensors to assemble the dataset from.
grad : bool
If True, computes gradients wrt descriptors.
energies_p : bool
If True, presents total energy as a sum of per-atom contributions.
normalization : Normalization
Optional normalization to apply (backward).
Returns
-------
energy : Tensor
A `[n_samples, 1]` tensor with total energies.
energy : Tensor, list
A `[n_samples, 1]` tensor with total energies or a list of `[n_samples, n_species, 1]`
tensors with per-atom contributions.
gradients : Tensor, optional
A `[n_samples, n_atoms, 3]` tensor with total energy gradients.
"""
@@ -1428,19 +1460,28 @@ def fw_cauldron(modules, dataset, grad=False, normalization=None):
if len(modules) != len(dataset.per_point_datasets):
raise ValueError(f"The module count does {len(modules):d} does not coincide with "
f"per-point dataset count {len(dataset.per_point_datasets):d}")
energy = torch.zeros_like(dataset.per_cell_dataset.energy)
if energies_p:
energy = []
else:
energy = torch.zeros_like(dataset.per_cell_dataset.energy)
if grad:
gradients = torch.zeros_like(dataset.per_cell_dataset.energy_g)
for m, d in zip(modules, dataset.per_point_datasets):
out = m(d.features, grad=grad)
if grad:
out, out_grad = out
energy += total_energy(out, d.mask)
if energies_p:
energy.append(out)
else:
energy += total_energy(out, d.mask)
if grad:
gradients += energy_gradients(out_grad, d.features_g)
if normalization:
atom_counts = normalization.atom_counts(dataset)
if energies_p:
for i, e in enumerate(energy):
normalization.bw_energy_components(e, i, inplace=True)
normalization.bw_energy(energy, atom_counts, inplace=True)
if grad:
normalization.bw_energy_g(gradients, inplace=True)
......
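With `energies_p=True` the first return value is a list with one tensor per species network; summing the masked contributions recovers the scalar total (a sketch reusing `total_energy` from this module, without normalization or gradients):

    energy_list = fw_cauldron(modules, dataset, energies_p=True)
    total = sum(total_energy(e, d.mask)
                for e, d in zip(energy_list, dataset.per_point_datasets))
    # `total` should coincide with the plain fw_cauldron(modules, dataset) output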
@@ -326,7 +326,7 @@ def parse_lammps_input(f):
loss_result = namedtuple("stats_tuple", ("loss_id", "loss_value", "reference", "prediction", "components"))
def energy_loss(networks, data, criterion, w_energy, w_gradients):
def energy_loss(networks, data, criterion, w_energy, w_gradients, energies_p=False):
"""
Energy loss function.
@@ -342,27 +342,52 @@ def energy_loss(networks, data, criterion, w_energy, w_gradients):
Energy weight in the loss function.
w_gradients : float
Gradients weight in the loss function.
energies_p : bool
If True, compares partial energies.
Returns
-------
result : loss_result
The resulting loss and accompanying information.
"""
result = fw_cauldron(networks, data, grad=w_gradients != 0)
result = fw_cauldron(networks, data, grad=w_gradients != 0, energies_p=energies_p)
if w_gradients != 0:
e, g = result
else:
e = result
g = None
de = criterion(e, data.per_cell_dataset.energy)
energy_loss_result = loss_result(
loss_id="energy",
loss_value=de,
reference=data.per_cell_dataset.energy,
prediction=e,
components=None,
)
if energies_p:
energy_loss_results = []
de = 0
for e_, data_ in zip(e, data.per_point_datasets):
de_ = criterion(e_ * data_.mask.unsqueeze(2), data_.energies_p)
de += de_
energy_loss_results.append(loss_result(
loss_id=f"partial energy {data_.tag}",
loss_value=de_,
reference=data_.energies_p,
prediction=e_,
components=None,
))
energy_loss_result = loss_result(
loss_id="energy",
loss_value=de,
reference=None,
prediction=None,
components=tuple(energy_loss_results),
)
else:
de = criterion(e, data.per_cell_dataset.energy)
energy_loss_result = loss_result(
loss_id="energy",
loss_value=de,
reference=data.per_cell_dataset.energy,
prediction=e,
components=None,
)
if w_gradients == 0:
return energy_loss_result
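When `energies_p` is set, the aggregate loss carries the per-species terms in `components`; a small inspection sketch (MSE as an example criterion):

    res = energy_loss(networks, data, torch.nn.functional.mse_loss,
                      w_energy=1, w_gradients=0, energies_p=True)
    for comp in res.components:
        print(comp.loss_id, float(comp.loss_value))  # "partial energy 0", ...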
@@ -535,7 +560,7 @@ class SimpleClosure:
def simple_energy_closure(networks, dataset=None, criterion=None, optimizer=None, optimizer_kwargs=None,
w_energy=1, w_gradients=0):
w_energy=1, w_gradients=0, energies_p=False):
"""
Energy and forces closure.
@@ -555,6 +580,8 @@ def simple_energy_closure(networks, dataset=None, criterion=None, optimizer=None
Energy weight in the loss function.
w_gradients : float
Gradients weight in the loss function.
energies_p : bool
If True, includes the partial-energy loss term.
Returns
-------
@@ -563,7 +590,7 @@ def simple_energy_closure(networks, dataset=None, criterion=None, optimizer=None
"""
return SimpleClosure(networks, energy_loss, dataset=dataset, criterion=criterion, optimizer=optimizer,
optimizer_kwargs=optimizer_kwargs,
loss_function_kwargs=dict(w_energy=w_energy, w_gradients=w_gradients))
loss_function_kwargs=dict(w_energy=w_energy, w_gradients=w_gradients, energies_p=energies_p))
def simple_charges_closure(networks, dataset=None, criterion=None, optimizer=None, optimizer_kwargs=None):
......
@@ -110,6 +110,8 @@ class LJBoxTest(TestCase):
w.compute_distances(cutoff)
if interaction is not None:
w.meta["total-energy"] = w.total(interaction, ignore_missing_species=True)
w.meta["partial-energy"] = w.eval(interaction, "kernel", ignore_missing_species=True, squeeze=False, resolved=True).sum(axis=0)
assert_allclose(w.meta["total-energy"], w.meta["partial-energy"].sum())
w.meta["forces"] = - w.grad(interaction, ignore_missing_species=True)
result.append(w)
return result
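The `assert_allclose` above pins down the invariant this feature relies on: per-atom (partial) energies must sum to the total energy. A self-contained numpy illustration with a Lennard-Jones pair potential split evenly between partners (hypothetical coordinates, not the project's kernel API):

    import numpy as np

    def lj(r, eps=1.0, sigma=1.0):
        s6 = (sigma / r) ** 6
        return 4 * eps * (s6 ** 2 - s6)

    coords = np.array([[0.0, 0.0, 0.0], [1.1, 0.0, 0.0], [0.0, 1.2, 0.0]])
    n = len(coords)
    partial = np.zeros(n)
    for i in range(n):
        for j in range(n):
            if i != j:
                partial[i] += 0.5 * lj(np.linalg.norm(coords[i] - coords[j]))

    pair_total = sum(lj(np.linalg.norm(coords[i] - coords[j]))
                     for i in range(n) for j in range(i + 1, n))
    np.testing.assert_allclose(partial.sum(), pair_total)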
@@ -207,7 +209,7 @@ class LJBoxTest(TestCase):
m[0].bias.data[:] = (normalization.features_offsets[i].sum() - normalization.energy_offsets[i]) / normalization.energy_scale
return modules
def __test_shapes__(self, dataset, grad, grad_pc=True, charges=False):
def __test_shapes__(self, dataset, grad, grad_pc=True, charges=False, energies_p=False):
assert_equal(dataset.dtype, torch.float64)
assert_equal(len(dataset), self.n_samples)
assert_equal(dataset.per_cell_dataset.is_gradient_available(), grad_pc)
@@ -226,7 +228,6 @@ class LJBoxTest(TestCase):
if grad_pc:
assert_equal(dataset.per_cell_dataset.energy_g.shape, (self.n_samples, self.n_atoms.max(), 3))
assert_equal(dataset.per_cell_dataset.mask.shape, (self.n_samples, self.n_atoms.max()))
n_atoms_max = max(self.n_atoms)
assert_equal(len(dataset.per_point_datasets), len(self.descriptors))
@@ -248,6 +249,8 @@ class LJBoxTest(TestCase):
if grad:
assert_equal(ppd.features_g.shape, [self.n_samples, n_species_max, len(self.descriptors[k]),
n_atoms_max, 3])
if energies_p:
assert_equal(ppd.energies_p.shape, (self.n_samples, n_species_max, 1))
assert_equal(ppd.mask.shape, [self.n_samples, n_species_max])
if charges:
@@ -335,24 +338,24 @@ class LJBoxTest(TestCase):
def __gc__(self):
if "charges" not in self.cells_wrapped[0].meta:
return product((False, True), (False,))
return chain(product((False, True), (False,), (False,)), ((False, False, True),))
else:
return product((False, True), (False, True))
return chain(product((False, True), (False, True), (False,)), ((False, False, True),))
def test_integration(self):
for grad, charges in self.__gc__():
for grad, charges, energies_p in self.__gc__():
dataset = ml.learn_cauldron(self.cells_wrapped, self.descriptors, grad=grad, normalize=False,
extract_charges=charges)
extract_charges=charges, energies_p=energies_p)
self.__test_shapes__(dataset, grad=grad, charges=charges)
self.__test_shapes__(dataset, grad=grad, charges=charges, energies_p=energies_p)
self.__test_ranges_std__(dataset, charges=charges)
self.__test_mask__(dataset)
# self.__test_slicing__(dataset) TODO: slicing fails because some tensors may be None
def test_integration_norm(self):
for grad, charges in self.__gc__():
for grad, charges, energies_p in self.__gc__():
dataset_ref = ml.learn_cauldron(self.cells_wrapped, self.descriptors, grad=grad, normalize=False,
extract_charges=charges)
extract_charges=charges, energies_p=energies_p)
if len(self.descriptors) == 1 and charges:
# Only a single specimen: all charges are zero
assert_equal(dataset_ref.charges, 0)
@@ -360,10 +363,11 @@ class LJBoxTest(TestCase):
ml.learn_cauldron(self.cells_wrapped, self.descriptors, grad=grad, normalize=True,
extract_charges=charges)
dataset, norm_info = ml.learn_cauldron(self.cells_wrapped, self.descriptors, grad=grad, normalize=True,
extract_charges=charges, ignore_normalization_errors=True)
extract_charges=charges, ignore_normalization_errors=True,
energies_p=energies_p)
else:
dataset, norm_info = ml.learn_cauldron(self.cells_wrapped, self.descriptors, grad=grad, normalize=True,
extract_charges=charges)
extract_charges=charges, energies_p=energies_p)
self.__test_shapes__(dataset, grad=grad, charges=charges)
self.__test_ranges_nrm__(dataset, charges=charges)
@@ -461,6 +465,21 @@ class LJBoxTest(TestCase):
assert_allclose(closure().detach(), 0, atol=1e-6)
self.__test_modules_valid__(modules, atol=2e-3)
def test_integration_ml_run_partial(self):
dataset = ml.learn_cauldron(self.cells_wrapped_trivial, self.descriptor_collection_trivial, energies_p=True, normalize=False)
dataset = dataset.to(torch.float32)
modules = self.__prep_unity_modules__(random=True)
closure = ml_util.simple_energy_closure(modules, dataset=dataset, optimizer=torch.optim.LBFGS, energies_p=True)
self.assertLess(1, closure()) # Check that the initial guess is not optimal
for _ in range(2): # Only one or two steps needed to reach convergence
closure.optimizer_step()
assert_allclose(closure().detach(), 0, atol=1e-6)
self.__test_modules_valid__(modules, atol=2e-3)
def test_integration_ml_run_norm(self):
dataset, normalization = ml.learn_cauldron(
self.cells_wrapped_trivial, self.descriptor_collection_trivial, grad=False,
......