From 7c34ff0766f2e3ffa083aad626c7384dc0c0b6c6 Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Wed, 25 Sep 2019 10:56:23 +0200
Subject: [PATCH] consistent colors in fig

---
 figures.ipynb | 86 +++++++++++++++++++++++++++++++++------------------
 1 file changed, 56 insertions(+), 30 deletions(-)

diff --git a/figures.ipynb b/figures.ipynb
index 8187dba..3fa61dd 100644
--- a/figures.ipynb
+++ b/figures.ipynb
@@ -26,7 +26,9 @@
     "import pickle\n",
     "\n",
     "import holoviews.plotting.mpl\n",
-    "import matplotlib; matplotlib.use(\"agg\")\n",
+    "import matplotlib\n",
+    "\n",
+    "matplotlib.use(\"agg\")\n",
     "import matplotlib.pyplot as plt\n",
     "import matplotlib.tri as mtri\n",
     "import numpy as np\n",
@@ -61,7 +63,26 @@
     "}\n",
     "\n",
     "plt.rcParams.update(params)\n",
-    "plt.rc(\"text.latex\", preamble=[r\"\\usepackage{xfrac}\", r\"\\usepackage{siunitx}\"])"
+    "plt.rc(\"text.latex\", preamble=[r\"\\usepackage{xfrac}\", r\"\\usepackage{siunitx}\"])\n",
+    "\n",
+    "\n",
+    "class HistogramNormalize(matplotlib.colors.Normalize):\n",
+    "    def __init__(self, data, vmin=None, vmax=None, mixing_degree=1):\n",
+    "        self.mixing_degree = mixing_degree\n",
+    "        if vmin is not None:\n",
+    "            data = data[data >= vmin]\n",
+    "        if vmax is not None:\n",
+    "            data = data[data <= vmax]\n",
+    "\n",
+    "        self.sorted_data = np.sort(data.flatten())\n",
+    "        matplotlib.colors.Normalize.__init__(self, vmin, vmax)\n",
+    "\n",
+    "    def __call__(self, value, clip=None):\n",
+    "        hist_norm = np.ma.masked_array(\n",
+    "            np.searchsorted(self.sorted_data, value) / len(self.sorted_data)\n",
+    "        )\n",
+    "        linear_norm = super().__call__(value, clip)\n",
+    "        return self.mixing_degree * hist_norm + (1 - self.mixing_degree) * linear_norm"
    ]
   },
   {
@@ -315,16 +336,22 @@
     "        ax.triplot(triang, c=\"w\", lw=0.2, alpha=0.8)\n",
     "\n",
     "    values = np.array(list(learner.data.values()))\n",
-    "    ax.imshow(\n",
-    "        learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data,\n",
+    "    plot_data = learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data\n",
+    "    im = ax.imshow(\n",
+    "        plot_data,\n",
     "        extent=(-0.5, 0.5, -0.5, 0.5),\n",
     "        interpolation=\"none\",\n",
     "    )\n",
     "    ax.set_xticks([])\n",
     "    ax.set_yticks([])\n",
     "\n",
+    "    if i in [2, 3]:\n",
+    "        norm = HistogramNormalize(plot_data, mixing_degree=0.6)\n",
+    "        im.set_norm(norm)\n",
+    "\n",
     "axs[0][0].set_ylabel(r\"$\\textrm{homogeneous}$\")\n",
     "axs[1][0].set_ylabel(r\"$\\textrm{adaptive}$\")\n",
+    "\n",
     "plt.savefig(\"figures/Learner2D.pdf\", bbox_inches=\"tight\", transparent=True)"
    ]
   },
@@ -390,7 +417,7 @@
     "        axs.scatter(xs, ys, alpha=1, zorder=2, c=\"k\")\n",
     "        (x_left, x_right), loss = list(learner.losses.items())[\n",
     "            0\n",
-    "        ]  # it's a ItemSortedDict\n",
+    "        ]  # it's an ItemSortedDict\n",
     "        (y_left, y_right) = [\n",
     "            learner.data[x_left] + offset,\n",
     "            learner.data[x_right] + offset,\n",
@@ -540,7 +567,7 @@
     "for i, ax in enumerate(axs):\n",
     "    ax.axis(\"off\")\n",
     "    ax.set_ylim(-1.5, 1.5)\n",
-    "    label = \"abcde\"[i]\n",
+    "    label = \"abcdefgh\"[i]\n",
     "    ax.text(\n",
     "        0.5,\n",
     "        0.9,\n",
@@ -551,36 +578,35 @@
     "    )\n",
     "\n",
     "\n",
-    "def plot_tri(xs, ax):\n",
+    "def plot_tri(xs, ax, colors):\n",
     "    ys = f(xs)\n",
     "    for i in range(len(xs)):\n",
     "        if i == 0 or i == len(xs) - 1:\n",
     "            continue\n",
-    "        color = f\"C{i}\"\n",
     "        verts = [(xs[i - 1], ys[i - 1]), (xs[i], ys[i]), (xs[i + 1], ys[i + 1])]\n",
-    "        poly = Polygon(verts, facecolor=color, alpha=0.4)\n",
+    "        poly = Polygon(verts, facecolor=colors[xs[i]], alpha=0.4)\n",
     "        ax.add_patch(poly)\n",
-    "        ax.scatter([xs[i]], [ys[i]], c=color, s=6, zorder=11)\n",
+    "        ax.scatter([xs[i]], [ys[i]], c=colors[xs[i]], s=6, zorder=11)\n",
     "        ax.plot(\n",
     "            [xs[i - 1], xs[i + 1]], [ys[i - 1], ys[i + 1]], c=\"k\", ls=\"--\", alpha=0.3\n",
     "        )\n",
     "\n",
     "\n",
-    "plot_tri(xs, axs[1])\n",
+    "learner = adaptive.Learner1D(\n",
+    "    f,\n",
+    "    (xs[0], xs[-1]),\n",
+    "    loss_per_interval=adaptive.learner.learner1D.curvature_loss_function(),\n",
+    ")\n",
+    "learner.tell_many(xs, ys)\n",
     "\n",
-    "for i in [2, 3, 4]:\n",
-    "    ax = axs[i]\n",
-    "    learner = adaptive.Learner1D(\n",
-    "        f,\n",
-    "        (xs[0], xs[-1]),\n",
-    "        loss_per_interval=adaptive.learner.learner1D.curvature_loss_function(),\n",
-    "    )\n",
-    "    learner.tell_many(xs, ys)\n",
-    "    x_new = learner.ask(1)[0][0]\n",
-    "    learner.tell(x_new, f(x_new))\n",
-    "    xs, ys = zip(*sorted(learner.data.items()))\n",
+    "for i, ax in enumerate(axs[1:]):\n",
+    "    if i != 0:\n",
+    "        x_new = learner.ask(1)[0][0]\n",
+    "        learner.tell(x_new, f(x_new))\n",
+    "        xs, ys = zip(*sorted(learner.data.items()))\n",
     "    plot(xs, ax)\n",
-    "    plot_tri(xs, ax)\n",
+    "    colors = {x: f\"C{i}\" for i, x in enumerate(learner.data.keys())}\n",
+    "    plot_tri(xs, ax, colors=colors)\n",
     "\n",
     "plt.savefig(\"figures/line_loss.pdf\", bbox_inches=\"tight\", transparent=True)\n",
     "plt.show()"
@@ -716,13 +742,13 @@
     "        color = f\"C{j}\"\n",
     "        label = \"abc\"[j]\n",
     "        label = f\"$\\mathrm{{({label})}}$ {title}\"\n",
-    "#         ax.loglog(Ns, err_hom, ls=\"--\", c=color)\n",
-    "#         ax.loglog(Ns, err_adaptive, label=label, c=color)\n",
-    "        error = np.array(err_hom) / np.array(err_adaptive)\n",
-    "        if i == 0:\n",
-    "            ax.loglog(Ns[:36], error[:36], c=color, label=label)\n",
-    "        else:\n",
-    "            ax.loglog(Ns, error, c=color, label=label)\n",
+    "        ax.loglog(Ns, err_hom, ls=\"--\", c=color)\n",
+    "        ax.loglog(Ns, err_adaptive, label=label, c=color)\n",
+    "#         error = np.array(err_hom) / np.array(err_adaptive)\n",
+    "#         if i == 0:\n",
+    "#             ax.loglog(Ns[:36], error[:36], c=color, label=label)\n",
+    "#         else:\n",
+    "#             ax.loglog(Ns, error, c=color, label=label)\n",
     "        ax.legend()\n",
     "\n",
     "plt.savefig(\"figures/line_loss_error.pdf\", bbox_inches=\"tight\", transparent=True)\n",
-- 
GitLab