Skip to content
Snippets Groups Projects

make 'fname' a parameter to 'save' and 'load' only

Merged Joseph Weston requested to merge feature/fname into master
2 files
+ 22
15
Compare changes
  • Side-by-side
  • Inline
Files
2
# -*- coding: utf-8 -*-
from collections import defaultdict
from collections import defaultdict, Iterable
from contextlib import suppress
from functools import partial
from operator import itemgetter
@@ -323,8 +323,9 @@ class BalancingLearner(BaseLearner):
Parameters
----------
fname: callable
Given a learner, returns a filename into which to save the data
fname: callable or sequence of strings
Given a learner, returns a filename into which to save the data.
Or a list (or iterable) with filenames.
compress : bool, default True
Compress the data upon saving using `gzip`. When saving
using compression, one must load it with compression too.
@@ -347,8 +348,12 @@ class BalancingLearner(BaseLearner):
>>> # Then save
>>> learner.save(combo_fname) # use 'load' in the same way
"""
for l in self.learners:
l.save(fname(l), compress=compress)
if isinstance(fname, Iterable):
for l, _fname in zip(fname, self.learners):
l.save(_fname, compress=compress)
else:
for l in self.learners:
l.save(fname(l), compress=compress)
def load(self, fname, compress=True):
"""Load the data of the child learners from pickle files
@@ -356,8 +361,9 @@ class BalancingLearner(BaseLearner):
Parameters
----------
fname: callable
Given a learner, returns a filename into which to save the data
fname: callable or sequence of strings
Given a learner, returns a filename from which to load the data.
Or a list (or iterable) with filenames.
compress : bool, default True
If the data is compressed when saved, one must load it
with compression too.
@@ -366,8 +372,12 @@ class BalancingLearner(BaseLearner):
-------
See the example in the `BalancingLearner.save` doc-string.
"""
for l in self.learners:
l.load(fname(l), compress=compress)
if isinstance(fname, Iterable):
for l, _fname in zip(fname, self.learners):
l.load(_fname, compress=compress)
else:
for l in self.learners:
l.load(fname(l), compress=compress)
def _get_data(self):
return [l._get_data() for l in learner.learners]
Loading