From 9f42b003c018dabbfd6ca072245c6a46ee2484a6 Mon Sep 17 00:00:00 2001 From: Bas Nijholt <basnijholt@gmail.com> Date: Mon, 16 Sep 2019 17:25:38 +0200 Subject: [PATCH] add algo description figure --- figures.ipynb | 147 +++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 146 insertions(+), 1 deletion(-) diff --git a/figures.ipynb b/figures.ipynb index 0ec8695..733b3d7 100644 --- a/figures.ipynb +++ b/figures.ipynb @@ -14,7 +14,7 @@ "import matplotlib.pyplot as plt\n", "\n", "%matplotlib inline\n", - "%config InlineBackend.figure_format = 'svg'\n", + "# %config InlineBackend.figure_format = 'svg'\n", "\n", "golden_mean = (np.sqrt(5) - 1) / 2 # Aesthetic ratio\n", "fig_width_pt = 246.0 # Columnwidth\n", @@ -322,6 +322,151 @@ "plt.plot(np.cumsum(times))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Algo explaination" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig, axs = plt.subplots(1, 1, figsize=(fig_width, 3 * fig_height))\n", + "\n", + "def f(x, offset=0.123):\n", + " a = 0.2\n", + " return a ** 2 / (a ** 2 + (x - offset) ** 2)\n", + "\n", + "learner = adaptive.Learner1D(\n", + " f, (0, 2),\n", + " loss_per_interval=adaptive.learner.learner1D.curvature_loss_function()\n", + ")\n", + "learner._recompute_losses_factor = 0.1\n", + "xs_dense = np.linspace(*learner.bounds, 400)\n", + "ys_dense = f(xs_dense)\n", + "step = 0.4\n", + "\n", + "for i in range(11):\n", + " offset = -i * step\n", + "\n", + " x = learner.ask(1)[0][0]\n", + " y = f(x)\n", + " learner.tell(x, y)\n", + " xs, ys = map(np.array, zip(*sorted(learner.data.items())))\n", + " ys = ys + offset\n", + " if i >= 1:\n", + " axs.plot(xs_dense, ys_dense + offset, c=\"k\", alpha=0.3, zorder=0)\n", + " axs.plot(xs, ys, zorder=1, c=\"k\")\n", + " axs.scatter(xs, ys, alpha=1, zorder=2, c=\"k\")\n", + " (x_left, x_right), loss = list(learner.losses.items())[0] # it's a ItemSortedDict\n", + " (y_left, y_right) = [\n", + " learner.data[x_left] + offset,\n", + " learner.data[x_right] + offset,\n", + " ]\n", + " axs.scatter([x_left, x_right], [y_left, y_right], c=\"r\", s=10, zorder=3)\n", + " x_mid = np.mean((x_left, x_right))\n", + " y_mid = np.interp(x_mid, (x_left, x_right), (y_left, y_right))\n", + " axs.scatter(x_mid, y_mid, zorder=4, marker=\"x\", c=\"green\")\n", + "\n", + "axs.text(\n", + " -0.1,\n", + " 0.5,\n", + " (r\"$\\mathrm{time}$\" + \"\\n\" + \"$\\longleftarrow$\"),\n", + " transform=axs.transAxes,\n", + " horizontalalignment=\"center\",\n", + " verticalalignment=\"center\",\n", + " rotation=90,\n", + " fontsize=18,\n", + ")\n", + "\n", + "\n", + "# legend\n", + "\n", + "import matplotlib.patches as mpatches\n", + "import matplotlib.lines as mlines\n", + "class LargestInterval:\n", + " pass\n", + "\n", + "\n", + "class Interval:\n", + " pass\n", + "\n", + "\n", + "class Function:\n", + " pass\n", + "\n", + "\n", + "class IntervalHandler:\n", + " def __init__(self, with_inner=True, length=20, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.with_inner = with_inner\n", + " self.length = length\n", + "\n", + " def legend_artist(self, legend, orig_handle, fontsize, handlebox):\n", + " x0, y0 = handlebox.xdescent, handlebox.ydescent\n", + " offsets = [0, self.length]\n", + " line = mlines.Line2D((0, offsets[-1]), (0, 0), zorder=0, c=\"k\")\n", + " handlebox.add_artist(line)\n", + "\n", + " for offset in offsets:\n", + " circle1 = mpatches.Circle(\n", + " [x0 + offset, y0], 4, facecolor=\"k\", lw=3, zorder=1\n", + " )\n", + " handlebox.add_artist(circle1)\n", + " if self.with_inner:\n", + " circle2 = mpatches.Circle(\n", + " [x0 + offset, y0], 3, facecolor=\"red\", lw=3, zorder=1\n", + " )\n", + " handlebox.add_artist(circle2)\n", + "\n", + "\n", + "class FunctionHandler:\n", + " def __init__(self, xs, ys, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + " self.xs = xs / xs.ptp() * 20\n", + " self.ys = ys - ys.mean()\n", + "\n", + " def legend_artist(self, legend, orig_handle, fontsize, handlebox):\n", + " x0, y0 = handlebox.xdescent, handlebox.ydescent\n", + "\n", + " line = mlines.Line2D(self.xs, self.ys * 10, zorder=0, c=\"k\", alpha=0.3)\n", + "\n", + " handlebox.add_artist(line)\n", + "\n", + "\n", + "plt.legend(\n", + " [\n", + " Function(),\n", + " mlines.Line2D([], [], marker=\"o\", lw=0, c=\"k\"),\n", + " LargestInterval(),\n", + " Interval(),\n", + " mlines.Line2D([], [], marker=\"x\", lw=0, c=\"green\"),\n", + " ],\n", + " [\n", + " \"original function\",\n", + " \"known point\",\n", + " \"interval\",\n", + " \"largest interval\",\n", + " \"next candidate point\",\n", + " ],\n", + " handler_map={\n", + " LargestInterval: IntervalHandler(False),\n", + " Interval: IntervalHandler(True),\n", + " Function: FunctionHandler(xs, ys),\n", + " },\n", + " bbox_to_anchor=(0.25, 0.9, 1.0, 0.0),\n", + " ncol=1,\n", + ")\n", + "\n", + "axs.axis(\"off\")\n", + "plt.savefig(\"figures/algo.pdf\", bbox_inches=\"tight\", transparent=True)\n", + "plt.show()" + ] + }, { "cell_type": "code", "execution_count": null, -- GitLab