Commit 6dd6359c authored by Artem Pulkin's avatar Artem Pulkin
Browse files

workflows: make save similar to load

parent 52f1eccd
Pipeline #85960 passed with stages
in 25 minutes and 46 seconds
......@@ -179,6 +179,9 @@ def pull(a):
return a.detach().cpu().numpy()
file_openers = {"gz": gzip.open}
class Workflow:
def __init__(self, dtype=torch.float64, log=None, seed=None, mpl_backend=None, mpl_save_ext="png", units=None,
units_are_known=False, tag=None):
......@@ -256,7 +259,7 @@ class Workflow:
"""
filename = Path(filename)
if opener == "auto":
opener = {"gz": gzip.open}.get(filename.suffix, open)
opener = file_openers.get(filename.suffix, open)
with opener(filename, 'rt') as f:
result = Cell.load(f)
if isinstance(result, list):
......@@ -358,27 +361,29 @@ class Workflow:
self.cells = self.cells[:subset]
return self.cells
def save_cells(self, destination, cells=None, **kwargs):
def save_cells(self, filename, cells=None, opener="auto", **kwargs):
"""
Save cells into a file.
Parameters
----------
destination : str, file
filename : str, Path
The file to save to.
cells : Iterable, None
Cells to save. Defaults to ``self.cells``.
opener : {"opener", Callable}
File opener.
kwargs
Arguments to serializer.
"""
filename = Path(filename)
if cells is None:
cells = self.cells
self.log.info(f"Saving {len(cells):d} structures to {destination} ...")
if isinstance(destination, str):
with open(destination, 'w') as f:
Cell.save(cells, f, **kwargs)
else:
Cell.save(cells, destination, **kwargs)
self.log.info(f"Saving {len(cells):d} structures to {filename} ...")
if opener == "auto":
opener = file_openers.get(filename.suffix, open)
with opener(filename, 'wt') as f:
Cell.save(cells, f, **kwargs)
@property
def cutoff(self):
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment