From 24221a9bd6bf9982825f947020a04f56f4b9982f Mon Sep 17 00:00:00 2001
From: Artem <gpulkin@gmail.com>
Date: Tue, 15 Jun 2021 16:08:20 +0200
Subject: [PATCH] kernel: better code through assuming Cell is immutable

---
 miniff/dyn.py         | 10 +++++-----
 miniff/kernel.py      | 35 ++++++++++++++++++++++++-----------
 miniff/test_dyn.py    |  2 +-
 miniff/test_kernel.py |  2 +-
 4 files changed, 31 insertions(+), 18 deletions(-)

diff --git a/miniff/dyn.py b/miniff/dyn.py
index 1dc1268..7ab205e 100644
--- a/miniff/dyn.py
+++ b/miniff/dyn.py
@@ -87,7 +87,7 @@ class DynWrapper:
     @property
     def cartesian(self) -> np.ndarray:
         """Cartesian coordinates."""
-        return self.wrapper.nw.cell.cartesian()
+        return self.wrapper.nw.cell.cartesian
 
     def snapshot(self, cell=None, save=True, v=True):
         """
@@ -114,9 +114,9 @@ class DynWrapper:
         if not custom:
             cell = self.wrapper.get_current_cell()
         if isinstance(cell, np.ndarray):
-            cell = self.wrapper.nw.cell.__class__(self.wrapper.nw.vectors, cell @ self.wrapper.nw._cart2cry_transform_matrix,
-                        self.wrapper.nw.values)  # TODO: fix this
-        c = cell.cartesian().reshape(-1)
+            cell = self.wrapper.nw.cell.__class__(self.wrapper.nw.vectors, cell, self.wrapper.nw.values,
+                                                  c_basis="cartesian")
+        c = cell.cartesian.reshape(-1)
         if "total-energy" not in cell.meta:
             cell.meta["total-energy"] = self.wrapper.f(c)
         if "forces" not in cell.meta:
@@ -170,7 +170,7 @@ class DynWrapper:
             self.wrapper.start_recording()
         result = driver(
             self.wrapper.f,
-            x0=self.wrapper.nw.cell.cartesian().reshape(-1),
+            x0=self.wrapper.nw.cell.cartesian.reshape(-1),
             jac=self.wrapper.g,
             **kwargs
         )
diff --git a/miniff/kernel.py b/miniff/kernel.py
index ce94272..d5d978f 100644
--- a/miniff/kernel.py
+++ b/miniff/kernel.py
@@ -12,6 +12,7 @@ from scipy.spatial import cKDTree
 
 from itertools import product
 from collections import namedtuple, OrderedDict
+from functools import cached_property
 
 
 def encode_species(species, lookup, default=None):
@@ -132,7 +133,7 @@ class SpeciesEncoder:
 
 
 class Cell:
-    def __init__(self, vectors, coordinates, values, meta=None):
+    def __init__(self, vectors, coordinates, values, meta=None, c_basis=None):
         """
         A minimal implementation of a box with points.
 
@@ -146,16 +147,26 @@ class Cell:
             Point specifiers.
         meta : dict
             Optional metadata.
+        c_basis : {'cartesian', 'cell', None}
+            Coordinate basis: cartesian or cell (default).
         """
         vectors = np.asanyarray(vectors)
         coordinates = np.asanyarray(coordinates)
         values = np.asanyarray(values)
+        if c_basis is None:
+            c_basis = 'cell'
 
         inputs = locals()
 
         __assert_dimension_count__(inputs, "vectors", 2, "coordinates", 2, "values", 1)
         __assert_same_dimension__(inputs, "basis size", "coordinates", 1, "vectors", 0)
         self.vectors = vectors
+        if c_basis == 'cell':
+            pass
+        elif c_basis == 'cartesian':
+            coordinates = self.transform_from_cartesian(coordinates)
+        else:
+            raise ValueError(f"Unknown basis: {c_basis}")
         self.coordinates = coordinates
         self.values = values
         if meta is None:
@@ -163,10 +174,14 @@ class Cell:
         else:
             self.meta = dict(meta)
 
-    @property
+    @cached_property
     def size(self):
         return len(self.coordinates)
 
+    @cached_property
+    def vectors_inv(self):
+        return np.linalg.inv(self.vectors)
+
     def copy(self):
         """
         A copy of the box.
@@ -178,6 +193,7 @@ class Cell:
         """
         return Cell(self.vectors.copy(), self.coordinates.copy(), self.values.copy(), self.meta)
 
+    @cached_property
     def cartesian(self):
         """
         Cartesian coordinates of points.
@@ -203,7 +219,7 @@ class Cell:
         result : ndarray
             The transformed coordinates.
         """
-        return coordinates @ np.linalg.inv(self.vectors)
+        return coordinates @ self.vectors_inv
 
     def distances(self, cutoff=None, other=None):
         """
@@ -221,11 +237,11 @@ class Cell:
         result : np.ndarray, csr_matrix
             The resulting distance matrix.
         """
-        this = self.cartesian()
+        this = self.cartesian
         if other is None:
             other = this
         elif isinstance(other, Cell):
-            other = other.cartesian()
+            other = other.cartesian
 
         if cutoff is None:
             return np.linalg.norm(this[:, np.newaxis] - other[np.newaxis, :], axis=-1)
@@ -234,7 +250,7 @@ class Cell:
             other = cKDTree(other)
             return this.sparse_distance_matrix(other, max_distance=cutoff, )
 
-    @property
+    @cached_property
     def volume(self):
         return abs(np.linalg.det(self.vectors))
 
@@ -366,7 +382,6 @@ class NeighborWrapper(SpeciesEncoder):
         self.shift_vectors = self.shift_id = self.cell = self.cutoff =\
             self.sparse_pair_distances = self.species = self.spec_encoded_row = self.cartesian_row =\
             self.cartesian_col = self.species_lookup = self.reciprocal_cutoff = self.reciprocal_grid = None
-        self._cart2cry_transform_matrix = None
         self.set_cell(cell, normalize=normalize)
         if cutoff is not None or x is not None:
             self.set_cutoff(cutoff=cutoff, x=x, pbc=pbc)
@@ -439,8 +454,7 @@ class NeighborWrapper(SpeciesEncoder):
             self.cell.coordinates %= 1
         self.cutoff = self.sparse_pair_distances = self.species = self.spec_encoded_row = self.cartesian_row =\
             self.cartesian_col = self.reciprocal_cutoff = self.reciprocal_grid = None
-        self._cart2cry_transform_matrix = np.linalg.inv(self.cell.vectors)
-        self.cartesian_row = cell.cartesian()
+        self.cartesian_row = cell.cartesian
 
     def set_coordinates(self, coordinates=None, vectors=None, normalize=True):
         """
@@ -462,10 +476,9 @@ class NeighborWrapper(SpeciesEncoder):
             vectors = np.array(vectors)
             self.reciprocal_grid = None
             self.cell.vectors = vectors
-            self._cart2cry_transform_matrix = np.linalg.inv(self.cell.vectors)
         if coordinates is not None:
             coordinates = np.array(coordinates)
-            coordinates_cry = coordinates @ self._cart2cry_transform_matrix
+            coordinates_cry = self.cell.transform_from_cartesian(coordinates)
             if normalize:
                 coordinates_cry %= 1
                 coordinates = coordinates_cry @ self.cell.vectors
diff --git a/miniff/test_dyn.py b/miniff/test_dyn.py
index 18e5ff8..17ef458 100644
--- a/miniff/test_dyn.py
+++ b/miniff/test_dyn.py
@@ -20,7 +20,7 @@ class H2test(TestCase):
         c = self.system.snapshot(save=False)
         self.system.pop()
         e = c.meta['total-energy']
-        d = c.cartesian()
+        d = c.cartesian
         d = np.linalg.norm(d[0] - d[1])
         testing.assert_allclose(e, -1)
         testing.assert_allclose(d, self.d_eq)
diff --git a/miniff/test_kernel.py b/miniff/test_kernel.py
index 03579d6..b8e5d83 100644
--- a/miniff/test_kernel.py
+++ b/miniff/test_kernel.py
@@ -194,7 +194,7 @@ class TestNW(TestCase):
             nw.compute_distances(self.nw_distorted.cutoff)
             assert nw.sparse_pair_distances.nnz == 6
             return nw.total(self.potentials_dummy_sp)
-        testing.assert_allclose(self.nw_distorted.grad(self.potentials_dummy_sp), util.num_grad(f, self.nw_distorted.cell.cartesian()), atol=1e-10)
+        testing.assert_allclose(self.nw_distorted.grad(self.potentials_dummy_sp), util.num_grad(f, self.nw_distorted.cell.cartesian), atol=1e-10)
 
     def test_relax(self):
         relaxed = self.nw_distorted.relax(self.potentials_p)
-- 
GitLab