Skip to content
Snippets Groups Projects
Commit 5ba8c44b authored by Artem Pulkin's avatar Artem Pulkin
Browse files

kernel: add support for cell_gradients

parent 55c99d81
No related branches found
No related tags found
No related merge requests found
......@@ -11,6 +11,7 @@ from scipy.optimize import minimize
from scipy.spatial import cKDTree
from itertools import product
from collections import namedtuple, OrderedDict
def encode_species(species, lookup, default=None):
......@@ -441,22 +442,35 @@ class NeighborWrapper(SpeciesEncoder):
self._cart2cry_transform_matrix = np.linalg.inv(self.cell.vectors)
self.cartesian_row = cell.cartesian()
def set_cell_cartesian(self, cartesian, normalize=True):
def set_coordinates(self, coordinates=None, vectors=None, normalize=True):
"""
Sets new cartesian coordinates.
Sets new cartesian coordinates and cell vectors.
Parameters
----------
cartesian
Atomic coordinates.
coordinates : np.ndarray
New atomic coordinates.
vectors : np.ndarray
New cell vectors.
normalize : bool
Normalizes the cell if True.
If True, normalizes coordinates.
"""
self.cell.coordinates = np.array(cartesian @ self._cart2cry_transform_matrix)
self.cartesian_row = cartesian
if normalize:
self.cell.coordinates %= 1
self.cartesian_row = self.cell.coordinates @ self.cell.vectors
if coordinates is None and vectors is None:
raise ValueError("Coordinates and vectors are both None: nothing to update")
self.sparse_pair_distances = self.cartesian_row = self.cartesian_col = None
if vectors is not None:
vectors = np.array(vectors)
self.reciprocal_grid = None
self.cell.vectors = vectors
self._cart2cry_transform_matrix = np.linalg.inv(self.cell.vectors)
if coordinates is not None:
coordinates = np.array(coordinates)
coordinates_cry = coordinates @ self._cart2cry_transform_matrix
if normalize:
coordinates_cry %= 1
coordinates = coordinates_cry @ self.cell.vectors
self.cartesian_row = coordinates
self.cell.coordinates = coordinates_cry
def compute_distances(self, cutoff=None):
"""
......@@ -472,8 +486,8 @@ class NeighborWrapper(SpeciesEncoder):
if cutoff is None:
raise ValueError("No cutoff specified")
# Create a super-cell with neighbors
self.shift_id = np.repeat(self.shift_vectors[:, np.newaxis, :], len(self.cartesian_row)).reshape(-1, 3)
self.cartesian_col = (self.cartesian_row[np.newaxis, :, :] + self.shift_vectors[:, np.newaxis, :] @ self.vectors).reshape(-1, 3)
self.shift_id = np.repeat(self.shift_vectors[:, np.newaxis, :], len(self.cartesian_row), 1).reshape(-1, 3)
self.cartesian_col = (self.cartesian_row[np.newaxis, :, :] + (self.shift_vectors[:, np.newaxis, :] @ self.vectors)).reshape(-1, 3)
# Collect close neighbors
t_row = cKDTree(self.cartesian_row)
......@@ -481,7 +495,7 @@ class NeighborWrapper(SpeciesEncoder):
spd = t_row.sparse_distance_matrix(t_col, cutoff, output_type="coo_matrix")
# Get rid of the-diagonal
# Get rid of the diagonal
mask = spd.row + len(self.shift_vectors) // 2 * len(self.cartesian_row) != spd.col
self.sparse_pair_distances = coo_matrix((spd.data[mask], (spd.row[mask], spd.col[mask])), shape=spd.shape).tocsr()
......@@ -730,6 +744,27 @@ class NeighborWrapper(SpeciesEncoder):
"""
return self.total(potentials, kname=kname, **kwargs)
def grad_cell(self, potentials, kname="kernel_cell_gradient", **kwargs):
"""
Total energy cell gradients.
Parameters
----------
potentials : list, LocalPotential
A list of potentials or a single potential.
kname : str, None
Function to evaluate: 'kernel', 'kernel_gradient' or whatever
other kernel function set for all potentials in the list.
kwargs
Other arguments to `total`.
Returns
-------
gradients : np.ndarray
Total energy gradients.
"""
return self.total(potentials, kname=kname, **kwargs)
def relax(self, potentials, rtn_history=False, normalize=True, inplace=False, prefer_parallel=None,
driver=minimize, **kwargs):
"""
......@@ -781,8 +816,12 @@ class SnapshotHistory(list):
self.append(cell)
sf_parameters = namedtuple("sf_parameters", ("coordinates", "vectors"))
class ScalarFunctionWrapper:
def __init__(self, nw, potentials, normalize=True, prefer_parallel=None, cell_logger=None):
def __init__(self, nw, potentials, include_coordinates=True, include_vectors=False, normalize=True,
prefer_parallel=None, cell_logger=None):
"""
A wrapper providing interfaces to total energy and gradient evaluation.
......@@ -792,6 +831,10 @@ class ScalarFunctionWrapper:
The (initial) structure. It will be overwritten.
potentials : list, LocalPotential
Potentials defining the total energy value.
include_coordinates : bool
If True, includes coordinates into parameters.
include_vectors : bool
If True, includes vectors into parameters.
normalize : bool
Normalizes the cell at each step if True.
prefer_parallel : bool
......@@ -803,6 +846,7 @@ class ScalarFunctionWrapper:
raise ValueError("No NeighborWrapper distance information available: please run `nw.compute_distances()`")
self.nw = nw
self.potentials = potentials
self.include = OrderedDict([("coordinates", include_coordinates), ("vectors", include_vectors)])
self.normalize = normalize
self.prefer_parallel = prefer_parallel
self.cell_logger = cell_logger
......@@ -811,9 +855,32 @@ class ScalarFunctionWrapper:
self._last = {}
self._history = []
def p2c(self, parameters: np.ndarray) -> sf_parameters:
"""Unpacks parameters."""
result = []
for k, v in self.include.items():
if v:
ref = getattr(self.nw, k)
result.append(parameters[:ref.size].reshape(*ref.shape))
parameters = parameters[ref.size:]
else:
result.append(None)
return sf_parameters(*result)
def c2p(self, coordinates: sf_parameters) -> np.ndarray:
"""Packs parameters."""
result = []
for k, v in self.include.items():
if v:
result.append(getattr(coordinates, k).reshape(-1))
return np.concatenate(result)
def push(self):
"""Saves the state."""
self._history.append(self.nw.cell.cartesian())
self._history.append(self.c2p(sf_parameters(
coordinates=self.nw.cell.cartesian(),
vectors=self.nw.vectors,
)))
def pop(self):
"""Restores the previous state."""
......@@ -852,8 +919,8 @@ class ScalarFunctionWrapper:
True if no operation was performed.
"""
if self._last_parameters is None or abs(parameters - self._last_parameters).max() > 0:
cartesian = parameters.reshape(-1, 3)
self.nw.set_cell_cartesian(cartesian, normalize=self.normalize)
unpacked = self.p2c(parameters)
self.nw.set_coordinates(**unpacked._asdict(), normalize=self.normalize)
self.nw.compute_distances()
self._last_parameters = parameters
self._last = {}
......
......@@ -306,6 +306,14 @@ class _KernelCellGrad(_KernelGrad):
return self.f.raw(r_indptr, r_indices, _r_data, cartesian_row, _cartesian_col, shift_vectors, parameters,
species_row=species_row, species_mask=species_mask)
def __call__(self, r_indptr, r_indices, r_data, cartesian_row, cartesian_col, shift_vectors, species_row,
species_mask, out, **parameters):
all_kwargs = {**self.numgrad_kwargs, **parameters}
if "eps" not in all_kwargs and "a" in all_kwargs:
all_kwargs["eps"] = 1e-4 * all_kwargs["a"]
out += num_grad(self._target, np.zeros((3, 3), dtype=float), r_indptr, r_indices, r_data, cartesian_row,
cartesian_col, shift_vectors, species_row, species_mask, **all_kwargs)
def kernel_on_site(r_indptr, r_indices, r_data, cartesian_row, cartesian_col, shift_vectors, v0, species_row,
species_mask, out):
......
......@@ -6,7 +6,6 @@ from scipy.integrate import quad
import numericalunits as nu
from unittest import TestCase
from pathlib import Path
from functools import partial
from io import StringIO
......@@ -107,6 +106,7 @@ class TestNW(TestCase):
testing.assert_equal(self.nw.cutoff, a * 1.4)
testing.assert_equal(self.nw.species, ["a", "b"])
testing.assert_equal(self.nw.spec_encoded_row, [0, 1])
testing.assert_equal(self.nw.shift_id[:4], [(-1, -1, -1), (-1, -1, -1), (-1, -1, 0), (-1, -1, 0)])
self.assertEqual(self.nw.spec_encoded_row.dtype, np.int32)
def test_nw_fields_pairs(self):
......@@ -170,6 +170,15 @@ class TestNW(TestCase):
def test_dummy_gradient_sp(self):
testing.assert_allclose(self.nw.grad(self.potentials_dummy_sp), 0, atol=1e-12)
def test_dummy_cell_gradient_sp(self):
a = self.bond_length
t = self.potentials_dummy_p[1].parameters["f"].keywords["a"]
testing.assert_allclose(self.nw.grad_cell(self.potentials_dummy_sp), [
(2 * a * t, 2 * a * t * 3. ** .5, 0),
(2 * a * t, - 2 * a * t * 3. ** .5, 0),
(0, 0, 0)
], atol=1e-10)
def test_dummy_energy_sp_distorted(self):
testing.assert_allclose(self.nw_distorted.total(self.potentials_dummy_sp),
self.potentials_dummy_p[1].parameters["f"].keywords["a"] * 2 * (self.bond_length_small ** 2 + 2 * self.bond_length_large ** 2) +
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment