{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "\n",
    "matplotlib.use(\"agg\")\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'svg'\n",
    "\n",
    "golden_mean = (np.sqrt(5) - 1) / 2  # Aesthetic ratio\n",
    "fig_width_pt = 246.0  # Columnwidth\n",
    "inches_per_pt = 1 / 72.27  # Convert pt to inches\n",
    "fig_width = fig_width_pt * inches_per_pt\n",
    "fig_height = fig_width * golden_mean  # height in inches\n",
    "fig_size = [fig_width, fig_height]\n",
    "\n",
    "params = {\n",
    "    \"backend\": \"ps\",\n",
    "    \"axes.labelsize\": 13,\n",
    "    \"font.size\": 13,\n",
    "    \"legend.fontsize\": 10,\n",
    "    \"xtick.labelsize\": 10,\n",
    "    \"ytick.labelsize\": 10,\n",
    "    \"text.usetex\": True,\n",
    "    \"figure.figsize\": fig_size,\n",
    "    \"font.family\": \"serif\",\n",
    "    \"font.serif\": \"Computer Modern Roman\",\n",
    "    \"legend.frameon\": True,\n",
    "    \"savefig.dpi\": 300,\n",
    "}\n",
    "\n",
    "plt.rcParams.update(params)\n",
    "plt.rc(\"text.latex\", preamble=[r\"\\usepackage{xfrac}\", r\"\\usepackage{siunitx}\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fig 1."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(1)\n",
    "xs = np.array([0.1, 0.3, 0.35, 0.45])\n",
    "f = lambda x: x**3\n",
    "ys = f(xs)\n",
    "means = lambda x: np.convolve(x, np.ones(2) / 2, mode=\"valid\")\n",
    "xs_means = means(xs)\n",
    "ys_means = means(ys)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=fig_size)\n",
    "ax.scatter(xs, ys, c=\"k\")\n",
    "ax.plot(xs, ys, c=\"k\")\n",
    "# ax.scatter()\n",
    "ax.annotate(\n",
    "    s=r\"$L_{1,2} = \\sqrt{\\Delta x^2 + \\Delta y^2}$\",\n",
    "    xy=(np.mean([xs[0], xs[1]]), np.mean([ys[0], ys[1]])),\n",
    "    xytext=(xs[0]+0.05, ys[0] - 0.05),\n",
    "    arrowprops=dict(arrowstyle=\"->\"),\n",
    "    ha=\"center\",\n",
    "    zorder=10,\n",
    ")\n",
    "\n",
    "for i, (x, y) in enumerate(zip(xs, ys)):\n",
    "    sign = [1, -1][i % 2]\n",
    "    ax.annotate(\n",
    "        s=fr\"$x_{i+1}, y_{i+1}$\",\n",
    "        xy=(x, y),\n",
    "        xytext=(x + 0.01, y + sign * 0.04),\n",
    "        arrowprops=dict(arrowstyle=\"->\"),\n",
    "        ha=\"center\",\n",
    "    )\n",
    "    \n",
    "ax.scatter(xs, ys, c=\"green\", s=5, zorder=5, label=\"existing data\")\n",
    "losses = np.hypot(xs[1:] - xs[:-1], ys[1:] - ys[:-1])\n",
    "ax.scatter(xs_means, ys_means, c=\"red\", s=300*losses, zorder=8, label=\"candidate points\")\n",
    "xs_dense = np.linspace(xs[0], xs[-1], 400)\n",
    "ax.plot(xs_dense, f(xs_dense), alpha=0.3, zorder=7, label=\"function\")\n",
    "\n",
    "ax.legend()\n",
    "ax.axis(\"off\")\n",
    "plt.savefig(\"figures/loss_1D.pdf\", bbox_inches=\"tight\", transparent=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fig 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import adaptive\n",
    "\n",
    "def f(x, offset=0.123):\n",
    "    a = 0.02\n",
    "    return x + a**2 / (a**2 + (x - offset)**2)\n",
    "\n",
    "def g(x):\n",
    "    return np.tanh(x*40)\n",
    "\n",
    "def h(x):\n",
    "    return np.sin(100*x) * np.exp(-x**2 / 0.1**2)\n",
    "\n",
    "funcs = [dict(function=f, bounds=(-1, 1), title=\"peak\"), dict(function=g, bounds=(-1, 1), title=\"tanh\"), dict(function=h, bounds=(-0.3, 0.3), title=\"wave packet\")]\n",
    "fig, axs = plt.subplots(2, len(funcs), figsize=(fig_width, 1.5*fig_height))\n",
    "n_points = 50\n",
    "for i, ax in enumerate(axs.T.flatten()):\n",
    "    ax.xaxis.set_ticks([])\n",
    "    ax.yaxis.set_ticks([])\n",
    "    if i % 2 == 0:\n",
    "        d = funcs[i // 2]\n",
    "        # homogeneous\n",
    "        xs = np.linspace(*d['bounds'], n_points)\n",
    "        ys = d['function'](xs)\n",
    "        ax.set_title(rf\"\\textrm{{{d['title']}}}\")\n",
    "    else:\n",
    "        d = funcs[(i - 1) // 2]\n",
    "        loss = adaptive.learner.learner1D.curvature_loss_function()\n",
    "        learner = adaptive.Learner1D(d['function'], bounds=d['bounds'], loss_per_interval=loss)\n",
    "        adaptive.runner.simple(learner, goal=lambda l: l.npoints >= n_points)\n",
    "        # adaptive\n",
    "        xs, ys = zip(*sorted(learner.data.items()))\n",
    "    xs_dense = np.linspace(*d['bounds'], 1000)\n",
    "    ax.plot(xs_dense, d['function'](xs_dense), c='red', alpha=0.3, lw=0.5)\n",
    "    ax.scatter(xs, ys, s=0.5, c='k')\n",
    "    \n",
    "axs[0][0].set_ylabel(r'$\\textrm{homogeneous}$')\n",
    "axs[1][0].set_ylabel(r'$\\textrm{adaptive}$')\n",
    "plt.savefig(\"figures/adaptive_vs_grid.pdf\", bbox_inches=\"tight\", transparent=True)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}