diff --git a/figures.ipynb b/figures.ipynb index 8187dba0f45838e964cfcd51d825f3d819038047..3fa61dd08eb55d438572eeae74a30378672978bd 100644 --- a/figures.ipynb +++ b/figures.ipynb @@ -26,7 +26,9 @@ "import pickle\n", "\n", "import holoviews.plotting.mpl\n", - "import matplotlib; matplotlib.use(\"agg\")\n", + "import matplotlib\n", + "\n", + "matplotlib.use(\"agg\")\n", "import matplotlib.pyplot as plt\n", "import matplotlib.tri as mtri\n", "import numpy as np\n", @@ -61,7 +63,26 @@ "}\n", "\n", "plt.rcParams.update(params)\n", - "plt.rc(\"text.latex\", preamble=[r\"\\usepackage{xfrac}\", r\"\\usepackage{siunitx}\"])" + "plt.rc(\"text.latex\", preamble=[r\"\\usepackage{xfrac}\", r\"\\usepackage{siunitx}\"])\n", + "\n", + "\n", + "class HistogramNormalize(matplotlib.colors.Normalize):\n", + " def __init__(self, data, vmin=None, vmax=None, mixing_degree=1):\n", + " self.mixing_degree = mixing_degree\n", + " if vmin is not None:\n", + " data = data[data >= vmin]\n", + " if vmax is not None:\n", + " data = data[data <= vmax]\n", + "\n", + " self.sorted_data = np.sort(data.flatten())\n", + " matplotlib.colors.Normalize.__init__(self, vmin, vmax)\n", + "\n", + " def __call__(self, value, clip=None):\n", + " hist_norm = np.ma.masked_array(\n", + " np.searchsorted(self.sorted_data, value) / len(self.sorted_data)\n", + " )\n", + " linear_norm = super().__call__(value, clip)\n", + " return self.mixing_degree * hist_norm + (1 - self.mixing_degree) * linear_norm" ] }, { @@ -315,16 +336,22 @@ " ax.triplot(triang, c=\"w\", lw=0.2, alpha=0.8)\n", "\n", " values = np.array(list(learner.data.values()))\n", - " ax.imshow(\n", - " learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data,\n", + " plot_data = learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data\n", + " im = ax.imshow(\n", + " plot_data,\n", " extent=(-0.5, 0.5, -0.5, 0.5),\n", " interpolation=\"none\",\n", " )\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", "\n", + " if i in [2, 3]:\n", + " norm = HistogramNormalize(plot_data, mixing_degree=0.6)\n", + " im.set_norm(norm)\n", + "\n", "axs[0][0].set_ylabel(r\"$\\textrm{homogeneous}$\")\n", "axs[1][0].set_ylabel(r\"$\\textrm{adaptive}$\")\n", + "\n", "plt.savefig(\"figures/Learner2D.pdf\", bbox_inches=\"tight\", transparent=True)" ] }, @@ -390,7 +417,7 @@ " axs.scatter(xs, ys, alpha=1, zorder=2, c=\"k\")\n", " (x_left, x_right), loss = list(learner.losses.items())[\n", " 0\n", - " ] # it's a ItemSortedDict\n", + " ] # it's an ItemSortedDict\n", " (y_left, y_right) = [\n", " learner.data[x_left] + offset,\n", " learner.data[x_right] + offset,\n", @@ -540,7 +567,7 @@ "for i, ax in enumerate(axs):\n", " ax.axis(\"off\")\n", " ax.set_ylim(-1.5, 1.5)\n", - " label = \"abcde\"[i]\n", + " label = \"abcdefgh\"[i]\n", " ax.text(\n", " 0.5,\n", " 0.9,\n", @@ -551,36 +578,35 @@ " )\n", "\n", "\n", - "def plot_tri(xs, ax):\n", + "def plot_tri(xs, ax, colors):\n", " ys = f(xs)\n", " for i in range(len(xs)):\n", " if i == 0 or i == len(xs) - 1:\n", " continue\n", - " color = f\"C{i}\"\n", " verts = [(xs[i - 1], ys[i - 1]), (xs[i], ys[i]), (xs[i + 1], ys[i + 1])]\n", - " poly = Polygon(verts, facecolor=color, alpha=0.4)\n", + " poly = Polygon(verts, facecolor=colors[xs[i]], alpha=0.4)\n", " ax.add_patch(poly)\n", - " ax.scatter([xs[i]], [ys[i]], c=color, s=6, zorder=11)\n", + " ax.scatter([xs[i]], [ys[i]], c=colors[xs[i]], s=6, zorder=11)\n", " ax.plot(\n", " [xs[i - 1], xs[i + 1]], [ys[i - 1], ys[i + 1]], c=\"k\", ls=\"--\", alpha=0.3\n", " )\n", "\n", "\n", - "plot_tri(xs, axs[1])\n", + "learner = adaptive.Learner1D(\n", + " f,\n", + " (xs[0], xs[-1]),\n", + " loss_per_interval=adaptive.learner.learner1D.curvature_loss_function(),\n", + ")\n", + "learner.tell_many(xs, ys)\n", "\n", - "for i in [2, 3, 4]:\n", - " ax = axs[i]\n", - " learner = adaptive.Learner1D(\n", - " f,\n", - " (xs[0], xs[-1]),\n", - " loss_per_interval=adaptive.learner.learner1D.curvature_loss_function(),\n", - " )\n", - " learner.tell_many(xs, ys)\n", - " x_new = learner.ask(1)[0][0]\n", - " learner.tell(x_new, f(x_new))\n", - " xs, ys = zip(*sorted(learner.data.items()))\n", + "for i, ax in enumerate(axs[1:]):\n", + " if i != 0:\n", + " x_new = learner.ask(1)[0][0]\n", + " learner.tell(x_new, f(x_new))\n", + " xs, ys = zip(*sorted(learner.data.items()))\n", " plot(xs, ax)\n", - " plot_tri(xs, ax)\n", + " colors = {x: f\"C{i}\" for i, x in enumerate(learner.data.keys())}\n", + " plot_tri(xs, ax, colors=colors)\n", "\n", "plt.savefig(\"figures/line_loss.pdf\", bbox_inches=\"tight\", transparent=True)\n", "plt.show()" @@ -716,13 +742,13 @@ " color = f\"C{j}\"\n", " label = \"abc\"[j]\n", " label = f\"$\\mathrm{{({label})}}$ {title}\"\n", - "# ax.loglog(Ns, err_hom, ls=\"--\", c=color)\n", - "# ax.loglog(Ns, err_adaptive, label=label, c=color)\n", - " error = np.array(err_hom) / np.array(err_adaptive)\n", - " if i == 0:\n", - " ax.loglog(Ns[:36], error[:36], c=color, label=label)\n", - " else:\n", - " ax.loglog(Ns, error, c=color, label=label)\n", + " ax.loglog(Ns, err_hom, ls=\"--\", c=color)\n", + " ax.loglog(Ns, err_adaptive, label=label, c=color)\n", + "# error = np.array(err_hom) / np.array(err_adaptive)\n", + "# if i == 0:\n", + "# ax.loglog(Ns[:36], error[:36], c=color, label=label)\n", + "# else:\n", + "# ax.loglog(Ns, error, c=color, label=label)\n", " ax.legend()\n", "\n", "plt.savefig(\"figures/line_loss_error.pdf\", bbox_inches=\"tight\", transparent=True)\n",