From 290ad8783168e77ce7a8c71f92ebee42358295f9 Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Thu, 16 Feb 2017 20:47:23 +0100
Subject: [PATCH] introduce self.futures

---
 Learner-parallel.ipynb |  8 ++++----
 learner1D.py           | 20 ++++++++++++++++----
 2 files changed, 20 insertions(+), 8 deletions(-)

diff --git a/Learner-parallel.ipynb b/Learner-parallel.ipynb
index b516f0d5..fa74784f 100644
--- a/Learner-parallel.ipynb
+++ b/Learner-parallel.ipynb
@@ -140,7 +140,7 @@
     "learner.initialize(func2, -1, 1)\n",
     "\n",
     "while True:\n",
-    "    if len(client.futures) < num_cores:\n",
+    "    if len(learner.futures) < num_cores:\n",
     "        xs = learner.choose_points(n=1)\n",
     "        learner.map(func, xs)\n",
     "    if len(learner.data) > 100: # bad criterion\n",
@@ -161,7 +161,7 @@
     "learner.initialize(func2, -1, 1)\n",
     "\n",
     "while True:\n",
-    "    if len(client.futures) < num_cores:\n",
+    "    if len(learner.futures) < num_cores:\n",
     "        xs = learner.choose_points(n=1)\n",
     "        learner.map(func, xs)\n",
     "    if len(learner.get_done()) > 150: # bad criterion\n",
@@ -182,13 +182,13 @@
     "learner.initialize(func_wait, -1, 1)\n",
     "\n",
     "while True:\n",
-    "    if len(client.futures) < num_cores:\n",
+    "    if len(learner.futures) < num_cores:\n",
     "        xs = learner.choose_points(n=1)\n",
     "        learner.map(func_wait, xs)\n",
     "    if learner.get_largest_interval() < 0.01 * learner.x_range:\n",
     "        break\n",
     "\n",
-    "print(len(learner.data), len(client.futures))\n",
+    "print(len(learner.data), len(learner.futures))\n",
     "plot(learner)"
    ]
   },
diff --git a/learner1D.py b/learner1D.py
index 5aaf9afd..1d062fe6 100644
--- a/learner1D.py
+++ b/learner1D.py
@@ -68,6 +68,8 @@ class Learner1D(object):
 
         self.num_done = 0
 
+        self.futures = {}
+
     def loss(self, x_left, x_right):
         """Calculate loss in the interval x_left, x_right.
 
@@ -152,10 +154,6 @@ class Learner1D(object):
             self.largest_interval = np.diff(xs).max()
             return self.largest_interval
 
-    def get_done(self):
-        done = {x: y for x, y in self.data.items() if y is not None}
-        return done
-
     def interpolate(self):
         xdata = []
         ydata = []
@@ -199,10 +197,23 @@ class Learner1D(object):
             except KeyError:
                 pass
 
+    def get_done(self):
+        done = {x: y for x, y in self.data.items() if y is not None}
+        return done
+
+    def add_futures(self, xs, ys):
+        """Add concurrent.futures to the self.futures dict."""
+        try:
+            for x, y in zip(xs, ys):
+                self.futures[x] = y
+        except TypeError:
+            self.futures[xs] = ys
+
     def done_callback(self, n, tol):
         @synchronized
         def wrapped(future):
             x, y = future.result()
+            self.futures.pop(x)
             return self.add_data(x, y)
         return wrapped
 
@@ -210,6 +221,7 @@ class Learner1D(object):
         ys = self.client.map(add_arg(func), xs)
         for y in ys:
             y.add_done_callback(self.done_callback(tol, n))
+        self.add_futures(xs, ys)
 
     def initialize(self, func, xmin, xmax):
         self.map(func, [xmin, xmax])
-- 
GitLab