From 290ad8783168e77ce7a8c71f92ebee42358295f9 Mon Sep 17 00:00:00 2001 From: Bas Nijholt <basnijholt@gmail.com> Date: Thu, 16 Feb 2017 20:47:23 +0100 Subject: [PATCH] introduce self.futures --- Learner-parallel.ipynb | 8 ++++---- learner1D.py | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/Learner-parallel.ipynb b/Learner-parallel.ipynb index b516f0d5..fa74784f 100644 --- a/Learner-parallel.ipynb +++ b/Learner-parallel.ipynb @@ -140,7 +140,7 @@ "learner.initialize(func2, -1, 1)\n", "\n", "while True:\n", - " if len(client.futures) < num_cores:\n", + " if len(learner.futures) < num_cores:\n", " xs = learner.choose_points(n=1)\n", " learner.map(func, xs)\n", " if len(learner.data) > 100: # bad criterion\n", @@ -161,7 +161,7 @@ "learner.initialize(func2, -1, 1)\n", "\n", "while True:\n", - " if len(client.futures) < num_cores:\n", + " if len(learner.futures) < num_cores:\n", " xs = learner.choose_points(n=1)\n", " learner.map(func, xs)\n", " if len(learner.get_done()) > 150: # bad criterion\n", @@ -182,13 +182,13 @@ "learner.initialize(func_wait, -1, 1)\n", "\n", "while True:\n", - " if len(client.futures) < num_cores:\n", + " if len(learner.futures) < num_cores:\n", " xs = learner.choose_points(n=1)\n", " learner.map(func_wait, xs)\n", " if learner.get_largest_interval() < 0.01 * learner.x_range:\n", " break\n", "\n", - "print(len(learner.data), len(client.futures))\n", + "print(len(learner.data), len(learner.futures))\n", "plot(learner)" ] }, diff --git a/learner1D.py b/learner1D.py index 5aaf9afd..1d062fe6 100644 --- a/learner1D.py +++ b/learner1D.py @@ -68,6 +68,8 @@ class Learner1D(object): self.num_done = 0 + self.futures = {} + def loss(self, x_left, x_right): """Calculate loss in the interval x_left, x_right. @@ -152,10 +154,6 @@ class Learner1D(object): self.largest_interval = np.diff(xs).max() return self.largest_interval - def get_done(self): - done = {x: y for x, y in self.data.items() if y is not None} - return done - def interpolate(self): xdata = [] ydata = [] @@ -199,10 +197,23 @@ class Learner1D(object): except KeyError: pass + def get_done(self): + done = {x: y for x, y in self.data.items() if y is not None} + return done + + def add_futures(self, xs, ys): + """Add concurrent.futures to the self.futures dict.""" + try: + for x, y in zip(xs, ys): + self.futures[x] = y + except TypeError: + self.futures[xs] = ys + def done_callback(self, n, tol): @synchronized def wrapped(future): x, y = future.result() + self.futures.pop(x) return self.add_data(x, y) return wrapped @@ -210,6 +221,7 @@ class Learner1D(object): ys = self.client.map(add_arg(func), xs) for y in ys: y.add_done_callback(self.done_callback(tol, n)) + self.add_futures(xs, ys) def initialize(self, func, xmin, xmax): self.map(func, [xmin, xmax]) -- GitLab