Commit b9335bf3 authored by Artem Pulkin's avatar Artem Pulkin
Browse files

workflow: simplify saving and loading objects

parent ccaa4c50
Pipeline #85999 passed with stages
in 24 minutes and 36 seconds
......@@ -302,50 +302,6 @@ class Cell:
)
return result
@classmethod
def load(cls, f):
    """
    Load Cell(s) from stream.

    Parameters
    ----------
    f : file
        File-like object.

    Returns
    -------
    result : list, Cell
        A single Cell if the stream held one serialized dict,
        otherwise a list of Cells.
    """
    data = load(f)
    if isinstance(data, dict):
        # A single serialized cell: return it unwrapped.
        return cls.from_state_dict(data)
    return [cls.from_state_dict(entry) for entry in data]
@staticmethod
def save(cells, f, **kwargs):
    """
    Saves cells.

    Parameters
    ----------
    cells : list, Cell
        Cells to save.
    f : file
        File-like object.
    kwargs
        Arguments forwarded to the serializer.
    """
    if isinstance(cells, (list, tuple, np.ndarray)):
        payload = [cell.state_dict() for cell in cells]
    else:
        payload = cells.state_dict()
    dump(payload, f, **kwargs)
@classmethod
def random(cls, density, atoms, shape=None):
"""
......
......@@ -9,8 +9,7 @@ from scipy.optimize import root_scalar
from pathlib import Path
from .potentials import behler2_descriptor_family, behler_turning_point,\
behler4_descriptor_family, behler5_descriptor_family, potential_from_state_dict, behler5x_descriptor_family,\
ewald_k_descriptor_family
behler4_descriptor_family, behler5_descriptor_family, behler5x_descriptor_family, ewald_k_descriptor_family
from .ml import fw_cauldron, fw_cauldron_charges, Dataset, Normalization, potentials_from_ml_data
from .util import dict_reduce
......@@ -451,41 +450,6 @@ def torch_load(f, **kwargs):
return torch.load(f, **defaults)
def load_potentials(f, deserializer=torch_load):
    """
    Loads a list of potentials from a file.

    Parameters
    ----------
    f : str, Path
        The file name.
    deserializer : Callable
        Deserializer routine.

    Returns
    -------
    result : list
        A list of potentials.
    """
    return [potential_from_state_dict(state) for state in deserializer(f)]
def save_potentials(potentials, f, serializer=torch.save):
    """
    Saves a list of potentials to a file.

    Parameters
    ----------
    potentials : list
        A list of potentials.
    f : str, Path
        The file to save to.
    serializer : Callable
        Serializer routine.
    """
    state_dicts = [potential.state_dict() for potential in potentials]
    serializer(state_dicts, f)
# Container for one loss evaluation: an identifier, the scalar loss value,
# the reference data, the corresponding prediction, and per-component terms.
loss_result = namedtuple("loss_result", ("loss_id", "loss_value", "reference", "prediction", "components"))
......
......@@ -54,13 +54,6 @@ class TestCell(TestCase):
r = kernel.Cell.from_state_dict(self.c.state_dict())
self.__assert_cells_same__(self.c, r)
def test_save_load(self):
    """Round-trip check: a saved cell loads back identical to the original."""
    stream = StringIO()
    kernel.Cell.save(self.c, stream)
    stream.seek(0)
    restored = kernel.Cell.load(stream)
    self.__assert_cells_same__(self.c, restored)
def test_prop(self):
c = self.c
testing.assert_equal(c.size, 2)
......
......@@ -7,9 +7,10 @@ from pathlib import Path
import warnings
from .test_samples import amorphous_bise_json_sample
from .ml_util import load_potentials
from .ml import ml_potential_family, PotentialExtrapolationWarning
from .kernel import Cell
from .util import load_obj_by_ext
from .workflows import load_cell_list, load_potential_list
from . import __version__
......@@ -89,11 +90,11 @@ class MainTest(TestCase):
run:
save: {cells_direct}
""", cells=amorphous_bise_json_sample, potentials_out=True, cells_direct_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_direct_simple(self):
......@@ -112,7 +113,7 @@ class MainTest(TestCase):
run:
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_relax(self):
......@@ -137,7 +138,7 @@ class MainTest(TestCase):
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_relax_parallel(self):
......@@ -163,7 +164,7 @@ class MainTest(TestCase):
parallel: true
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_relax_openmp(self):
......@@ -189,7 +190,7 @@ class MainTest(TestCase):
parallel: openmp
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_pre_relax(self):
......@@ -230,11 +231,11 @@ class MainTest(TestCase):
maxiter: 10
save: {cells_direct}
""", cells=amorphous_bise_json_sample, potentials_out=True, cells_direct_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 1
def test_vc_relax(self):
......@@ -261,7 +262,7 @@ class MainTest(TestCase):
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_relax_snapshots(self):
......@@ -287,7 +288,7 @@ class MainTest(TestCase):
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert 100 > len(cells) >= 30
def test_relax_snapshots_cartesian_delta(self):
......@@ -314,7 +315,7 @@ class MainTest(TestCase):
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 20 # first and last per input
def test_relax_snapshots_fidelity(self):
......@@ -341,7 +342,7 @@ class MainTest(TestCase):
maxiter: 3
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert 100 > len(cells) >= 20 # fidelity is +inf
def test_md(self):
......@@ -369,7 +370,7 @@ class MainTest(TestCase):
x: 2
save: {cells_direct}
""", cells=amorphous_bise_json_sample, cells_direct_out=True)
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_parallel_descriptors(self):
......@@ -386,7 +387,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......@@ -405,7 +406,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......@@ -425,7 +426,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......@@ -447,7 +448,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......@@ -474,7 +475,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......@@ -506,13 +507,13 @@ class MainTest(TestCase):
run:
save: {cells_direct}
""", cells=amorphous_bise_json_sample, potentials_out=True, cells_direct_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
assert len(potentials[0].descriptors) == 38
assert len(potentials[1].descriptors) == 38
cells = Cell.load(out["cells_direct"])
cells = load_cell_list(out["cells_direct"])
assert len(cells) == 10
def test_reg(self):
......@@ -530,7 +531,7 @@ class MainTest(TestCase):
save: True
save_fn: {potential}
""", cells=amorphous_bise_json_sample, potentials_out=True, cells_direct_out=True)
potentials = load_potentials(out["potentials"].name)
potentials = load_potential_list(out["potentials"].name)
assert len(potentials) == 2
for i in potentials:
self.assertIs(i.family, ml_potential_family)
......
......@@ -298,15 +298,3 @@ def loads(*args, **kwargs):
defaults = dict(object_hook=units_object_hook)
defaults.update(kwargs)
return json.loads(*args, **defaults)
class delayed:
    """
    Wraps a value for deferred resolution: string payloads are evaluated
    through ``nu.nu_eval`` on every ``val`` access, anything else is
    returned untouched.
    """

    def __init__(self, x):
        # x: either a numericalunits expression string or a plain value
        self.x = x

    @property
    def val(self):
        """Resolve and return the wrapped value."""
        x = self.x
        if not isinstance(x, str):
            return x
        return nu.nu_eval(x)
......@@ -6,8 +6,12 @@ import logging
import sys
import hashlib
from collections import defaultdict
from functools import partial
from pathlib import Path
import gzip
from .units import delayed
from .units import load, dump
from numericalunits import nu_eval
def num_grad(scalar_f, x, *args, x_name=None, eps=1e-4, **kwargs):
......@@ -417,6 +421,15 @@ class DDict(dict):
self.duplicates = dict(duplicates)
class delayed:
    """
    Lazily evaluated value: stores a zero-argument callable and invokes it
    anew each time ``val`` is read.
    """

    def __init__(self, f):
        # f: callable producing the value on demand
        self.f = f

    @property
    def val(self):
        """Evaluate the wrapped computation and return its result."""
        result = self.f()
        return result
def duplicate_yaml_load(stream):
"""
Loads a YAML file with all duplicates.
......@@ -442,11 +455,109 @@ def duplicate_yaml_load(stream):
construct_mapping)
def construct_nu(_loader, _node):
return delayed(_loader.construct_scalar(_node))
return delayed(partial(nu_eval, _loader.construct_scalar(_node)))
DuplicateLoader.add_constructor("!nu", construct_nu)
return yaml.load(stream, DuplicateLoader)
def open_file_by_ext(f, *args, **kwargs):
    """
    Opens a file, choosing the opener from the file extension.

    Parameters
    ----------
    f : str, Path, file
        The file to open. A str/Path with a ``.gz`` suffix is opened
        through gzip; other paths use the builtin ``open``; an object
        that is already file-like is returned unchanged.
    args
    kwargs
        Other arguments to the opener.

    Returns
    -------
    result : file
        The opened file object (or *f* itself if it was already a file).
    """
    if isinstance(f, str):
        f = Path(f)
    if isinstance(f, Path):
        if f.suffix.lower() == ".gz":
            # Transparent (de)compression for gzipped files.
            return gzip.open(f, *args, **kwargs)
        return open(f, *args, **kwargs)
    # Already a file-like object (documented input): hand it back untouched.
    return f
def _maybe_first_suffix(f) -> str:
if isinstance(f, Path):
try:
return f.suffixes[0].lower()
except IndexError:
pass
else: # file object
try:
return Path(f.name).suffixes[0].lower()
except (AttributeError, IndexError):
pass
def load_obj_by_ext(f):
    """
    Reads an object from file.

    Parameters
    ----------
    f : str, Path, file
        The file to read from. A ``.pt`` first suffix selects torch
        deserialization; anything else goes through the units-aware
        ``load`` (imported from .units).

    Returns
    -------
    object
        The resulting object.
    """
    if isinstance(f, str):
        f = Path(f)
    suffix = _maybe_first_suffix(f)
    if suffix == ".pt":
        if isinstance(f, Path):
            # Open a binary stream and make sure it gets closed (the
            # previous version leaked the handle).
            with open_file_by_ext(f, 'rb') as handle:
                return torch.load(handle)
        return torch.load(f)
    if isinstance(f, Path):
        with open_file_by_ext(f, 'rt') as handle:
            return load(handle)
    return load(f)
def dump_obj_by_ext(obj, f):
    """
    Saves an object to file.

    Parameters
    ----------
    obj
        Object to save.
    f : str, Path, file
        The file to save to. A ``.pt`` first suffix selects torch
        serialization; anything else goes through the units-aware
        ``dump`` (imported from .units).

    Returns
    -------
    object
        Whatever the underlying serializer returns (typically None).
    """
    if isinstance(f, str):
        f = Path(f)
    suffix = _maybe_first_suffix(f)
    if suffix == ".pt":
        if isinstance(f, Path):
            # Open a binary stream and make sure it gets closed (the
            # previous version leaked the handle).
            with open_file_by_ext(f, 'wb') as handle:
                return torch.save(obj, handle)
        return torch.save(obj, f)
    if isinstance(f, Path):
        with open_file_by_ext(f, 'wt') as handle:
            return dump(obj, handle)
    return dump(obj, f)
def _ro(a: np.ndarray) -> np.ndarray:
a.flags.writeable = False
return a
......
from ._util import get_num_threads
from .dyn import for_name as dynamics_for_name, relax as dyn_relax, integrate as dyn_integrate, \
nvt_vs as dyn_nvt_vs
from .ewald import ewald_cutoffs
from .kernel import Cell, compute_images, ScalarFunctionWrapper
from .ml import learn_cauldron, Dataset, Normalization, potentials_from_ml_data, cpu_copy
from .ml_util import default_behler_descriptors, \
default_behler_descriptors_3, default_long_range, behler_nn, simple_energy_closure, \
l2_regularization, composite_regularization, simple_charges_closure
from .potentials import ewald_k_descriptor_family, PotentialRuntimeWarning, potential_from_state_dict, LocalPotential
from .presentation import text_bars, plot_convergence, plot_diagonal
from .units import UnitsDict, check_units_known, new_units_context, init_default_atomic_units
from .util import dict_reduce, split2, DDict, duplicate_yaml_load, load_obj_by_ext, dump_obj_by_ext, delayed
import logging
import warnings
from collections import Counter, namedtuple
from functools import partial, wraps
from pathlib import Path
import gzip
import json
import matplotlib
import numericalunits as nu
......@@ -13,20 +25,6 @@ import torch
from matplotlib import pyplot
from torch.nn import DataParallel
from miniff._util import get_num_threads
from miniff.dyn import for_name as dynamics_for_name, relax as dyn_relax, integrate as dyn_integrate, \
nvt_vs as dyn_nvt_vs
from miniff.ewald import ewald_cutoffs
from miniff.kernel import Cell, compute_images, ScalarFunctionWrapper
from miniff.ml import learn_cauldron, Dataset, Normalization, potentials_from_ml_data, cpu_copy
from miniff.ml_util import default_behler_descriptors, \
default_behler_descriptors_3, default_long_range, behler_nn, simple_energy_closure, \
l2_regularization, composite_regularization, save_potentials, load_potentials, \
simple_charges_closure
from miniff.potentials import ewald_k_descriptor_family, PotentialRuntimeWarning, potential_from_state_dict
from miniff.presentation import text_bars, plot_convergence, plot_diagonal
from miniff.units import UnitsDict, check_units_known, delayed, new_units_context, init_default_atomic_units
from miniff.util import dict_reduce, split2, DDict, duplicate_yaml_load
def requires_fields(*names):
......@@ -101,43 +99,6 @@ def compute_images_ewald(cell, cutoff, ewald_potentials=None, pbc=True, units=No
return compute_images(cell, cutoff, reciprocal_cutoff=k_cut, pbc=pbc)
def load_all_potentials(potentials, log=None):
    """
    Load previously saved potentials or descriptors.

    Parameters
    ----------
    potentials : list
        File names to load from or an explicit list of potentials
        or a list of their parameters.
    log
        Logger.

    Returns
    -------
    potentials : list
        The loaded potentials.
    """
    if log is None:
        log = lambda *args: None
    sources = potentials if isinstance(potentials, list) else [potentials]
    result = []
    for src in sources:
        log(f"Restoring potentials from {src} ...")
        if isinstance(src, (str, Path)):
            # A file on disk: deserialize every state dict it contains.
            log(f" file {src}")
            result.extend(map(potential_from_state_dict, read_object(src)))
        elif isinstance(src, dict):
            # An in-memory state dict.
            log(f" serialized {src}")
            result.append(potential_from_state_dict(src))
        else:
            # Already a potential object: take it as-is.
            log(f" object {src}")
            result.append(src)
    return result
def minimum_loss_save_policy(tag, loss, state):
"""
A saving policy triggering whenever a global minimum in the loss function occurs.
......@@ -180,59 +141,48 @@ def pull(a):
return a.detach().cpu().numpy()
# Maps filename suffixes to custom opener callables; plain open() is the fallback.
file_openers = {".gz": gzip.open}
def _cast_list(ls: list, t: type, convert: callable) -> list:
return list(i if isinstance(i, t) else convert(i) for i in ls)
def open_file(f, mode, opener="auto", **kwargs):
def load_cell_list(f):
"""
Opens a file.
Load cells.
Parameters
----------
f : str, Path, file
The file to open.
mode : str
Open mode.
opener : {"auto", Callable}
File opener.
kwargs
Other arguments to opener.
f : list, str, Path, file
Cells data, a filename or a file to load from.
Returns
-------
result : list
A list of Cells.
"""
if isinstance(f, str):
f = Path(f)
if isinstance(f, Path):
if opener == "auto":
f = file_openers.get(f.suffix.lower(), open)(f, mode, **kwargs)
return f
if not isinstance(f, list):
f = load_obj_by_ext(f)
if not isinstance(f, list):
f = [f]
return _cast_list(f, Cell, Cell.from_state_dict)
def read_object(f, **kwargs):
def load_potential_list(f):
"""
Reads an object from file.
Load potentials.
Parameters
----------
f : str, Path, file
The file to open.
kwargs
Other arguments to opener.
f : list, str, Path, file
Potential data, a filename or a file to load from.
Returns
-------
object
The resulting object.
result : list
A list of potentials.
"""
if isinstance(f, (str, Path)):
f = Path(f)
suffix = f.suffixes[0].lower()
if suffix == ".pt":
return torch.load(open_file(f, 'rb', **kwargs))
return json.load(open_file(f, 'rt', **kwargs))
if not isinstance(f, list