Skip to content
Snippets Groups Projects
Commit 329f447b authored by Bas Nijholt's avatar Bas Nijholt
Browse files

add an option to use a list of filenames when saving a BalancingLearner

parent fcdb0532
No related branches found
No related tags found
1 merge request!133make 'fname' a parameter to 'save' and 'load' only
Pipeline #13576 passed with warnings
# -*- coding: utf-8 -*-
from collections import defaultdict
from collections import defaultdict, Iterable
from contextlib import suppress
from functools import partial
from operator import itemgetter
......@@ -323,8 +323,9 @@ class BalancingLearner(BaseLearner):
Parameters
----------
fname: callable
Given a learner, returns a filename into which to save the data
fname: callable or sequence of strings
Given a learner, returns a filename into which to save the data.
Or a list (or iterable) with filenames.
compress : bool, default True
Compress the data upon saving using `gzip`. When saving
using compression, one must load it with compression too.
......@@ -347,8 +348,12 @@ class BalancingLearner(BaseLearner):
>>> # Then save
>>> learner.save(combo_fname) # use 'load' in the same way
"""
for l in self.learners:
l.save(fname(l), compress=compress)
if isinstance(fname, Iterable):
for l, _fname in zip(fname, self.learners):
l.save(_fname, compress=compress)
else:
for l in self.learners:
l.save(fname(l), compress=compress)
def load(self, fname, compress=True):
"""Load the data of the child learners from pickle files
......@@ -356,8 +361,9 @@ class BalancingLearner(BaseLearner):
Parameters
----------
fname: callable
Given a learner, returns a filename into which to save the data
fname: callable or sequence of strings
Given a learner, returns a filename from which to load the data.
Or a list (or iterable) with filenames.
compress : bool, default True
If the data is compressed when saved, one must load it
with compression too.
......@@ -366,8 +372,12 @@ class BalancingLearner(BaseLearner):
-------
See the example in the `BalancingLearner.save` doc-string.
"""
for l in self.learners:
l.load(fname(l), compress=compress)
if isinstance(fname, Iterable):
for l, _fname in zip(fname, self.learners):
l.load(_fname, compress=compress)
else:
for l in self.learners:
l.load(fname(l), compress=compress)
def _get_data(self):
return [l._get_data() for l in learner.learners]
......
......@@ -33,13 +33,10 @@ Saving and loading learners
Every learner has a `~adaptive.BaseLearner.save` and `~adaptive.BaseLearner.load`
method that can be used to save and load **only** the data of a learner.
There are **two ways** of naming the files: 1. Using the ``fname``
argument in ``learner.save(fname=...)`` 2. Setting the ``fname``
attribute, like ``learner.fname = 'data/example.p`` and then
``learner.save()``
Use the ``fname`` argument in ``learner.save(fname=...)``.
The second way *must be used* when saving the ``learner``\s of a
`~adaptive.BalancingLearner`.
Or, when using a `~adaptive.BalancingLearner` one can use either a callable
that takes the child learner and returns a filename **or** a list of filenames.
By default the resulting pickle files are compressed, to turn this off
use ``learner.save(fname=..., compress=False)``
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment