From 3691b86459e439fd7a21890b89c89be67c72e8e5 Mon Sep 17 00:00:00 2001
From: Bas Nijholt <basnijholt@gmail.com>
Date: Fri, 13 Sep 2019 11:26:23 +0200
Subject: [PATCH] add 2D figure

---
 figures.ipynb | 195 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 178 insertions(+), 17 deletions(-)

diff --git a/figures.ipynb b/figures.ipynb
index 2bd53ee..93d0a78 100644
--- a/figures.ipynb
+++ b/figures.ipynb
@@ -185,46 +185,62 @@
     "\n",
     "\n",
     "def f(xy, offset=0.123):\n",
-    "    a = 0.1\n",
+    "    a = 0.2\n",
     "    x, y = xy\n",
-    "    return x * y + a ** 2 / (a ** 2 + (x - offset) ** 2 + (y - offset) ** 2)\n",
+    "    return x + np.exp(-(x ** 2 + y ** 2 - 0.75 ** 2) ** 2 / a ** 4)\n",
+    "\n",
     "\n",
     "@functools.lru_cache()\n",
     "def g_setup(fname):\n",
     "    data = adaptive.utils.load(fname)\n",
     "    points = np.array(list(data.keys()))\n",
     "    values = np.array(list(data.values()), dtype=float)\n",
-    "    bounds = [(points[:, 0].min(), points[:, 0].max()), (points[:, 1].min(), points[:, 1].max())]\n",
+    "    bounds = [\n",
+    "        (points[:, 0].min(), points[:, 0].max()),\n",
+    "        (points[:, 1].min(), points[:, 1].max()),\n",
+    "    ]\n",
     "    ll, ur = np.reshape(bounds, (2, 2)).T\n",
     "    inds = np.all(np.logical_and(ll <= points, points <= ur), axis=1)\n",
     "    points, values = points[inds], values[inds].reshape(-1, 1)\n",
     "    return interpolate.LinearNDInterpolator(points, values), bounds\n",
     "\n",
+    "\n",
     "def g(xy, fname):\n",
     "    ip, _ = g_setup(fname)\n",
-    "    return ip(xy)\n",
+    "    return np.round(ip(xy))\n",
+    "\n",
+    "\n",
+    "def density(x, eps=0):\n",
+    "    e = [0.8, 0.2]\n",
+    "    delta = [0.5, 0.5, 0.5]\n",
+    "    c = 3\n",
+    "    omega = [0.02, 0.05]\n",
+    "\n",
+    "    H = np.array(\n",
+    "        [\n",
+    "            [e[0] + 1j * omega[0], delta[0], delta[1]],\n",
+    "            [delta[0], e[1] + c * x + 1j * omega[1], delta[1]],\n",
+    "            [delta[1], delta[2], e[1] - c * x + 1j * omega[1]],\n",
+    "        ]\n",
+    "    )\n",
+    "    H += np.eye(3) * eps\n",
+    "    return np.trace(np.linalg.inv(H)).imag\n",
     "\n",
     "\n",
     "def h(xy):\n",
     "    x, y = xy\n",
-    "    return np.sin(100 * x * y) * np.exp(-x ** 2 / 0.1 ** 2 - y ** 2 / 0.4 ** 2)\n",
+    "    return density(x, y) + y\n",
     "\n",
     "\n",
     "funcs = [\n",
-    "    dict(function=f, bounds=[(-1, 1), (-1, 1)], title=\"peak\", npoints=50,),\n",
+    "    dict(function=f, bounds=[(-1, 1), (-1, 1)], npoints=33),\n",
     "    dict(\n",
     "        function=g,\n",
     "        bounds=g_setup(\"phase_diagram.pickle\")[1],\n",
-    "        title=\"tanh\",\n",
-    "        npoints=140,\n",
+    "        npoints=100,\n",
     "        fname=\"phase_diagram.pickle\",\n",
     "    ),\n",
-    "    dict(\n",
-    "        function=h,\n",
-    "        bounds=[(-0.3, 0.3), (-0.3, 0.3)],\n",
-    "        title=\"wave packet\",\n",
-    "        npoints=50,\n",
-    "    ),\n",
+    "    dict(function=h, bounds=[(-1, 1), (-3, 3)], npoints=50),\n",
     "]\n",
     "fig, axs = plt.subplots(2, len(funcs), figsize=(fig_width, 1.5 * fig_height))\n",
     "\n",
@@ -233,6 +249,15 @@
     "with_tri = False\n",
     "\n",
     "for i, ax in enumerate(axs.T.flatten()):\n",
+    "    label = \"abcdef\"[i]\n",
+    "    ax.text(\n",
+    "        0.5,\n",
+    "        1.05,\n",
+    "        f\"$\\mathrm{{({label})}}$\",\n",
+    "        transform=ax.transAxes,\n",
+    "        horizontalalignment=\"center\",\n",
+    "        verticalalignment=\"bottom\",\n",
+    "    )\n",
     "    ax.xaxis.set_ticks([])\n",
     "    ax.yaxis.set_ticks([])\n",
     "    kind = \"homogeneous\" if i % 2 == 0 else \"adaptive\"\n",
@@ -245,16 +270,20 @@
     "        f = functools.partial(f, fname=fname)\n",
     "\n",
     "    if kind == \"homogeneous\":\n",
-    "        ax.set_title(rf\"\\textrm{{{d['title']}}}\")\n",
     "        xs, ys = [np.linspace(*bound, npoints) for bound in bounds]\n",
     "        data = {xy: f(xy) for xy in itertools.product(xs, ys)}\n",
     "        learner = adaptive.Learner2D(f, bounds=bounds)\n",
     "        learner.data = data\n",
+    "        d[\"learner_hom\"] = learner\n",
     "    elif kind == \"adaptive\":\n",
     "        learner = adaptive.Learner2D(f, bounds=bounds)\n",
     "        if fname is not None:\n",
     "            learner.load(fname)\n",
+    "        learner.data = {\n",
+    "            k: v for i, (k, v) in enumerate(learner.data.items()) if i <= npoints ** 2\n",
+    "        }\n",
     "        adaptive.runner.simple(learner, goal=lambda l: l.npoints >= npoints ** 2)\n",
+    "        d[\"learner\"] = learner\n",
     "\n",
     "    if with_tri:\n",
     "        tri = learner.ip().tri\n",
@@ -262,7 +291,11 @@
     "        ax.triplot(triang, c=\"w\", lw=0.2, alpha=0.8)\n",
     "\n",
     "    values = np.array(list(learner.data.values()))\n",
-    "    ax.imshow(learner.plot().Image.I.data, extent=(-0.5, 0.5, -0.5, 0.5))\n",
+    "    ax.imshow(\n",
+    "        learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data,\n",
+    "        extent=(-0.5, 0.5, -0.5, 0.5),\n",
+    "        interpolation=\"none\",\n",
+    "    )\n",
     "    ax.set_xticks([])\n",
     "    ax.set_yticks([])\n",
     "\n",
@@ -276,7 +309,135 @@
    "execution_count": null,
    "metadata": {},
    "outputs": [],
-   "source": []
+   "source": [
+    "from scipy import interpolate\n",
+    "import functools\n",
+    "import itertools\n",
+    "import adaptive\n",
+    "import holoviews.plotting.mpl\n",
+    "import matplotlib.tri as mtri\n",
+    "\n",
+    "\n",
+    "def f(xy, offset=0.123):\n",
+    "    a = 0.2\n",
+    "    x, y = xy\n",
+    "    return x + np.exp(-(x ** 2 + y ** 2 - 0.75 ** 2) ** 2 / a ** 4)\n",
+    "\n",
+    "\n",
+    "@functools.lru_cache()\n",
+    "def g_setup(fname):\n",
+    "    data = adaptive.utils.load(fname)\n",
+    "    points = np.array(list(data.keys()))\n",
+    "    values = np.array(list(data.values()), dtype=float)\n",
+    "    bounds = [\n",
+    "        (points[:, 0].min(), points[:, 0].max()),\n",
+    "        (points[:, 1].min(), points[:, 1].max()),\n",
+    "    ]\n",
+    "    ll, ur = np.reshape(bounds, (2, 2)).T\n",
+    "    inds = np.all(np.logical_and(ll <= points, points <= ur), axis=1)\n",
+    "    points, values = points[inds], values[inds].reshape(-1, 1)\n",
+    "    return interpolate.LinearNDInterpolator(points, values), bounds\n",
+    "\n",
+    "\n",
+    "def g(xy, fname):\n",
+    "    ip, _ = g_setup(fname)\n",
+    "    return np.round(ip(xy))\n",
+    "\n",
+    "\n",
+    "def density(x, eps=0):\n",
+    "    e = [0.8, 0.2]\n",
+    "    delta = [0.5, 0.5, 0.5]\n",
+    "    c = 3\n",
+    "    omega = [0.02, 0.05]\n",
+    "\n",
+    "    H = np.array(\n",
+    "        [\n",
+    "            [e[0] + 1j * omega[0], delta[0], delta[1]],\n",
+    "            [delta[0], e[1] + c * x + 1j * omega[1], delta[1]],\n",
+    "            [delta[1], delta[2], e[1] - c * x + 1j * omega[1]],\n",
+    "        ]\n",
+    "    )\n",
+    "    H += np.eye(3) * eps\n",
+    "    return np.trace(np.linalg.inv(H)).imag\n",
+    "\n",
+    "\n",
+    "def h(xy):\n",
+    "    x, y = xy\n",
+    "    return density(x, y) + y\n",
+    "\n",
+    "\n",
+    "funcs = [\n",
+    "    dict(function=f, bounds=[(-1, 1), (-1, 1)], npoints=33),\n",
+    "    dict(\n",
+    "        function=g,\n",
+    "        bounds=g_setup(\"phase_diagram.pickle\")[1],\n",
+    "        npoints=100,\n",
+    "        fname=\"phase_diagram.pickle\",\n",
+    "    ),\n",
+    "    dict(function=h, bounds=[(-1, 1), (-3, 3)], npoints=50),\n",
+    "]\n",
+    "fig, axs = plt.subplots(len(funcs), 2, figsize=(fig_width, 2 * fig_height))\n",
+    "\n",
+    "plt.subplots_adjust(hspace=0.1, wspace=0.1)\n",
+    "\n",
+    "with_tri = False\n",
+    "\n",
+    "for i, ax in enumerate(axs.flatten()):\n",
+    "    label = \"abcdef\"[i]\n",
+    "    ax.text(\n",
+    "        -0.03,\n",
+    "        0.98,\n",
+    "        f\"$\\mathrm{{({label})}}$\",\n",
+    "        transform=ax.transAxes,\n",
+    "        horizontalalignment=\"right\",\n",
+    "        verticalalignment=\"top\",\n",
+    "    )\n",
+    "    ax.xaxis.set_ticks([])\n",
+    "    ax.yaxis.set_ticks([])\n",
+    "    kind = \"homogeneous\" if i % 2 == 0 else \"adaptive\"\n",
+    "    d = funcs[i // 2] if kind == \"homogeneous\" else funcs[(i - 1) // 2]\n",
+    "    bounds = d[\"bounds\"]\n",
+    "    npoints = d[\"npoints\"]\n",
+    "    f = d[\"function\"]\n",
+    "    fname = d.get(\"fname\")\n",
+    "    if fname is not None:\n",
+    "        f = functools.partial(f, fname=fname)\n",
+    "\n",
+    "    if kind == \"homogeneous\":\n",
+    "        xs, ys = [np.linspace(*bound, npoints) for bound in bounds]\n",
+    "        data = {xy: f(xy) for xy in itertools.product(xs, ys)}\n",
+    "        learner = adaptive.Learner2D(f, bounds=bounds)\n",
+    "        learner.data = data\n",
+    "        d[\"learner_hom\"] = learner\n",
+    "    elif kind == \"adaptive\":\n",
+    "        learner = adaptive.Learner2D(f, bounds=bounds)\n",
+    "        if fname is not None:\n",
+    "            learner.load(fname)\n",
+    "        learner.data = {\n",
+    "            k: v for i, (k, v) in enumerate(learner.data.items()) if i <= npoints ** 2\n",
+    "        }\n",
+    "        adaptive.runner.simple(learner, goal=lambda l: l.npoints >= npoints ** 2)\n",
+    "        d[\"learner\"] = learner\n",
+    "\n",
+    "    if with_tri:\n",
+    "        tri = learner.ip().tri\n",
+    "        triang = mtri.Triangulation(*tri.points.T, triangles=tri.vertices)\n",
+    "        ax.triplot(triang, c=\"w\", lw=0.2, alpha=0.8)\n",
+    "\n",
+    "    values = np.array(list(learner.data.values()))\n",
+    "    ax.imshow(\n",
+    "        learner.plot(npoints if kind == \"homogeneous\" else None).Image.I.data,\n",
+    "        extent=(-0.5, 0.5, -0.5, 0.5),\n",
+    "        interpolation=\"none\",\n",
+    "    )\n",
+    "    ax.set_xticks([])\n",
+    "    ax.set_yticks([])\n",
+    "\n",
+    "axs[0][0].set_title(r\"$\\textrm{homogeneous}$\")\n",
+    "axs[0][1].set_title(r\"$\\textrm{adaptive}$\")\n",
+    "\n",
+    "plt.savefig(\"figures/adaptive_2D.pdf\", bbox_inches=\"tight\", transparent=True)"
+   ]
   }
  ],
  "metadata": {
-- 
GitLab