Skip to content
Snippets Groups Projects
Commit 7c938860 authored by Joseph Weston's avatar Joseph Weston
Browse files

make 'fname' a parameter to 'save' and 'load' only

This simplifies the API by making sure that the filenames are
only provided in one place (the calls to save and load).

Closes #122
parent ccba17dd
No related branches found
No related tags found
No related merge requests found
Pipeline #13414 passed
......@@ -317,70 +317,57 @@ class BalancingLearner(BaseLearner):
learners.append(learner)
return cls(learners, cdims=arguments)
def save(self, folder, compress=True):
def save(self, fname, compress=True):
"""Save the data of the child learners into pickle files
in a directory.
Parameters
----------
folder : str
Directory in which the learners's data will be saved.
fname: callable
Given a learner, returns a filename into which to save the data
compress : bool, default True
Compress the data upon saving using `gzip`. When saving
using compression, one must load it with compression too.
Notes
-----
The child learners need to have a 'fname' attribute in order to use
this method.
Example
-------
>>> def combo_fname(val):
... return '__'.join([f'{k}_{v}.p' for k, v in val.items()])
...
... def f(x, a, b): return a * x**2 + b
...
>>> learners = []
>>> for combo in adaptive.utils.named_product(a=[1, 2], b=[1]):
... l = Learner1D(functools.partial(f, combo=combo))
... l.fname = combo_fname(combo) # 'a_1__b_1.p', 'a_2__b_1.p' etc.
... learners.append(l)
... learner = BalancingLearner(learners)
... # Run the learner
... runner = adaptive.Runner(learner)
... # Then save
... learner.save('data_folder') # use 'load' in the same way
>>> def combo_fname(learner):
... val = learner.function.keywords # because functools.partial
... fname = '__'.join([f'{k}_{v}.pickle' for k, v in val])
... return 'data_folder/' + fname
>>>
>>> def f(x, a, b): return a * x**2 + b
>>>
>>> learners = [Learner1D(functools.partial(f, **combo), (-1, 1))
... for combo in adaptive.utils.named_product(a=[1, 2], b=[1]]
>>>
>>> learner = BalancingLearner(learners)
>>> # Run the learner
>>> runner = adaptive.Runner(learner)
>>> # Then save
>>> learner.save(combo_fname) # use 'load' in the same way
"""
if len(self.learners) != len(set(l.fname for l in self.learners)):
raise RuntimeError("The 'learner.fname's are not all unique.")
for l in self.learners:
l.save(os.path.join(folder, l.fname), compress=compress)
l.save(fname(l), compress=compress)
def load(self, folder, compress=True):
def load(self, fname, compress=True):
"""Load the data of the child learners from pickle files
in a directory.
Parameters
----------
folder : str
Directory from which the learners's data will be loaded.
fname: callable
Given a learner, returns a filename into which to save the data
compress : bool, default True
If the data is compressed when saved, one must load it
with compression too.
Notes
-----
The child learners need to have a 'fname' attribute in order to use
this method.
Example
-------
See the example in the `BalancingLearner.save` doc-string.
"""
for l in self.learners:
l.load(os.path.join(folder, l.fname), compress=compress)
l.load(fname(l), compress=compress)
def _get_data(self):
return [l._get_data() for l in learner.learners]
......
......@@ -107,48 +107,31 @@ class BaseLearner(metaclass=abc.ABCMeta):
"""
self._set_data(other._get_data())
def save(self, fname=None, compress=True):
def save(self, fname, compress=True):
"""Save the data of the learner into a pickle file.
Parameters
----------
fname : str, optional
The filename of the learner's pickle data file. If None use
the 'fname' attribute, like 'learner.fname = "example.p".
fname : str
The filename into which to save the learner's data.
compress : bool, default True
Compress the data upon saving using 'gzip'. When saving
using compression, one must load it with compression too.
Notes
-----
There are **two ways** of naming the files:
1. Using the ``fname`` argument in ``learner.save(fname='example.p')``
2. Setting the ``fname`` attribute, like
``learner.fname = "data/example.p"`` and then ``learner.save()``.
"""
fname = fname or self.fname
data = self._get_data()
save(fname, data, compress)
def load(self, fname=None, compress=True):
def load(self, fname, compress=True):
"""Load the data of a learner from a pickle file.
Parameters
----------
fname : str, optional
The filename of the saved learner's pickled data file.
If None use the 'fname' attribute, like
'learner.fname = "example.p".
fname : str
The filename from which to load the learner's data.
compress : bool, default True
If the data is compressed when saved, one must load it
with compression too.
Notes
-----
See the notes in the `save` doc-string.
"""
fname = fname or self.fname
with suppress(FileNotFoundError, EOFError):
data = load(fname, compress)
self._set_data(data)
......@@ -158,19 +141,3 @@ class BaseLearner(metaclass=abc.ABCMeta):
def __setstate__(self, state):
self.__dict__ = state
@property
def fname(self):
"""Filename for the learner when it is saved (or loaded) using
`~adaptive.BaseLearner.save` (or `~adaptive.BaseLearner.load` ).
"""
# This is a property because then it will be availible in the DataSaver
try:
return self._fname
except AttributeError:
raise AttributeError("Set 'learner.fname' or use the 'fname'"
" argument when using 'learner.save' or 'learner.load'.")
@fname.setter
def fname(self, fname):
self._fname = fname
......@@ -53,13 +53,13 @@ class DataSaver:
self.learner._set_data(learner_data)
@copy_docstring_from(BaseLearner.save)
def save(self, fname=None, compress=True):
def save(self, fname, compress=True):
# We copy this method because the 'DataSaver' is not a
# subclass of the 'BaseLearner'.
BaseLearner.save(self, fname, compress)
@copy_docstring_from(BaseLearner.load)
def load(self, fname=None, compress=True):
def load(self, fname, compress=True):
# We copy this method because the 'DataSaver' is not a
# subclass of the 'BaseLearner'.
BaseLearner.load(self, fname, compress)
......
......@@ -430,15 +430,15 @@ def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
learner = BalancingLearner([learner_type(f, **learner_kwargs)])
control = BalancingLearner([learner_type(f, **learner_kwargs)])
# set fnames
learner.learners[0].fname = 'test'
control.learners[0].fname = 'test'
simple(learner, lambda l: l.learners[0].npoints > 100)
folder = tempfile.mkdtemp()
def fname(learner):
return folder + 'test'
try:
learner.save(folder=folder)
control.load(folder=folder)
learner.save(fname)
control.load(fname)
if learner_type is not Learner1D:
# Because different scales result in differnt losses
np.testing.assert_almost_equal(learner.loss(), control.loss())
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment