from . import kernel, potentials, dyn, util

import numpy as np
from numpy import testing
from unittest import TestCase


class H2test(TestCase):
    """Tests for a two-atom ("H", "H") system driven by ``sw2_potential_family``.

    One ``dyn.DynWrapper`` is built per class in ``setUpClass`` and shared by all
    tests. Tests that mutate it bracket their work with ``push()``/``pop()``
    (presumably save/restore of the system state). ``pop()`` sits in a
    ``finally`` so a failing assertion cannot leak state into later tests.
    """

    @classmethod
    def setUpClass(cls) -> None:
        cls.system = dyn.DynWrapper(kernel.ScalarFunctionWrapper(
            # 10x10x10 box with the two atoms separated along x; coordinates
            # look fractional (``.cartesian`` is used below) -- TODO confirm.
            kernel.Cell(np.eye(3) * 10, [(.455, .5, .5), (.545, .5, .5)], ["H", "H"]),
            [potentials.sw2_potential_family(gauge_a=7.049556227, gauge_b=0.6022245584, a=1.8, p=4, q=0, epsilon=0.5, sigma=1)],
            pbc=False,
        ))
        # Expected relaxed H-H distance checked by test_relax (~= 2 ** (1 / 6)).
        cls.d_eq = 1.122462

    def test_jac(self):
        """Analytic gradient ``wrapper.g`` must match a numerical gradient of ``wrapper.f``."""
        x = self.system.x
        j = self.system.wrapper.g(x)
        j_num = util.num_grad(self.system.wrapper.f, x)
        testing.assert_allclose(j, j_num)

    def test_relax(self):
        """Relaxing the dimer must reach energy -1 at bond length ``d_eq``."""
        self.system.push()
        try:
            self.system.relax(snapshots=True)
            # Take ownership of the recorded snapshots and reset the buffer.
            snapshots, self.system.snapshots = self.system.snapshots, []
            for snapshot in snapshots:
                assert "total-energy" in snapshot.meta
                assert "forces" in snapshot.meta
            # The final state must be the last recorded snapshot object.
            assert self.system.state is snapshots[-1]
            c = self.system.state
        finally:
            # Always restore the shared system, even if an assertion failed.
            self.system.pop()
        e = c.meta['total-energy']
        xyz = c.cartesian
        d = np.linalg.norm(xyz[0] - xyz[1])
        testing.assert_allclose(e, -1)
        testing.assert_allclose(d, self.d_eq)

    def test_dyn(self):
        """Integrate for one time unit and check energy conservation at every snapshot."""
        self.system.push()
        try:
            e0 = self.system.wrapper.f(self.system.x)
            # Sanity check: the starting configuration has non-zero energy.
            assert e0 != 0
            self.system.integrate(1, rtol=1e-5, snapshots=True)
            snapshots, self.system.snapshots = self.system.snapshots, []
            for snapshot in snapshots:
                meta = snapshot.meta
                assert "total-energy" in meta
                assert "forces" in meta
                assert 0 <= meta["time"] <= 1
                # "total-energy" plus v**2 / 2 (no mass factor -- presumably
                # unit masses) must stay equal to the initial energy e0.
                testing.assert_allclose(meta["total-energy"] + (meta["velocities"] ** 2 / 2).sum(), e0, rtol=2e-4)
            assert self.system.state is snapshots[-1]
        finally:
            self.system.pop()

    def test_nvt_vs(self, ek=0.05):
        """Run ``nvt_vs`` for 10 epochs and sanity-check the recorded snapshots.

        Args:
            ek: target kinetic energy passed to ``nvt_vs``.
        """
        dt = 0.1
        self.system.push()
        try:
            self.system.nvt_vs(ek=ek, dt=dt, alpha=0.1, n_epochs=10, snapshots=True)
            snapshots, self.system.snapshots = self.system.snapshots, []
            for i, snapshot in enumerate(snapshots):
                assert "total-energy" in snapshot.meta
                assert "forces" in snapshot.meta
                # One snapshot at the end of each epoch of length dt.
                testing.assert_allclose(snapshot.meta["time"], (i + 1) * dt)
        finally:
            self.system.pop()
        # Kinetic-energy trace, kept separate from the target ``ek`` so the
        # parameter is not shadowed.
        ek_trace = np.array([(s.meta["velocities"] ** 2 / 2).sum(axis=-1) for s in snapshots])
        assert np.all(np.isfinite(ek_trace))
        # TODO(review): tighten to ``testing.assert_allclose(ek_trace, ek, rtol=...)``
        # once it is confirmed how closely nvt_vs tracks the target after 10
        # epochs and whether ``sum(axis=-1)`` yields per-atom or total energy.


class SiTest(TestCase):
    """Tests for a two-atom "si" cell with ``sw2`` + ``sw3`` potentials.

    The wrapper is built with ``include_vectors=True``, so relaxation also acts
    on the cell vectors (``test_relax`` checks stress and volume). One shared
    ``DynWrapper`` is built in ``setUpClass``. Mutating tests bracket their work
    with ``push()``/``pop()``, with ``pop()`` in a ``finally`` block.
    """

    # Equilibrium number density; for 2 atoms the relaxed cell volume is
    # therefore expected to be 2 / density.
    density = 0.46

    @classmethod
    def setUpClass(cls) -> None:
        # Starting lattice parameter. The cell volume is 2 * a**3, so the
        # equilibrium value would be (1 / density) ** (1 / 3); a = 1 presumably
        # starts the relaxation away from equilibrium on purpose.
        a = 1
        cell = kernel.Cell(
            # fcc-type vectors with basis at 0 and 1/4: presumably a diamond
            # structure primitive cell -- TODO confirm coordinates are fractional.
            np.array([[0, a, a], [a, 0, a], [a, a, 0]]),
            [[0, 0, 0], [.25, .25, .25]],
            ["si", "si"],
        )
        cls.system = dyn.DynWrapper(kernel.ScalarFunctionWrapper(
            cell,
            [
                potentials.sw2_potential_family(gauge_a=7.049556227, gauge_b=0.6022245584, a=1.8, p=4, q=0, epsilon=0.5, sigma=1),
                # cos_theta0 = -1/3 is the cosine of the tetrahedral angle.
                potentials.sw3_potential_family(l=21, gamma=1.2, cos_theta0=-1. / 3, a=1.8, epsilon=.5, sigma=1),
            ],
            include_vectors=True,
        ))

    def test_jac(self):
        """Analytic gradient ``wrapper.g`` must match a numerical gradient of ``wrapper.f``."""
        x = self.system.x
        j = self.system.wrapper.g(x)
        j_num = util.num_grad(self.system.wrapper.f, x)
        testing.assert_allclose(j, j_num, atol=1e-6)

    def test_relax(self):
        """Relax atoms and cell; forces and stress must vanish at the expected structure."""
        self.system.push()
        try:
            # Two optimizers run back to back (L-BFGS-B, then CG).
            self.system.relax(method="L-BFGS-B")
            self.system.relax(method="CG")
            c = self.system.state
            testing.assert_allclose(c.meta['stress'], 0, atol=1e-5)
            testing.assert_allclose(c.meta['forces'], 0, atol=1e-5)
            testing.assert_allclose(c.meta['total-energy'], -4, atol=1e-3)
            # Each component of the coordinate offset between the two atoms must
            # have magnitude 0.25 or 0.75 (presumably allowing for wrapping).
            testing.assert_allclose(np.abs(np.abs(c.coordinates[1] - c.coordinates[0]) - .5), .25, atol=1e-5)
            testing.assert_allclose(c.volume, 2. / self.density, atol=1e-2)
        finally:
            # Always restore the shared system, even if an assertion failed.
            self.system.pop()

    def test_dyn(self):
        """Integrate for one time unit; final energy must match the initial energy."""
        self.system.push()
        try:
            e0 = self.system.wrapper.f(self.system.x)
            self.system.integrate(1, rtol=1e-5)
            c = self.system.state
            # "total-energy" plus v**2 / 2 (no mass factor -- presumably unit
            # masses) must equal the initial energy e0.
            testing.assert_allclose(c.meta["total-energy"] + (c.meta["velocities"] ** 2 / 2).sum(), e0, rtol=2e-4)
        finally:
            self.system.pop()