Skip to content
Snippets Groups Projects
green.pyx 5.11 KiB
Newer Older
import kwant
import numpy as np

cimport cython

cimport cython.operator as co
cimport libcpp.vector
cimport libcpp.map

# C math routines used by the integrand code in this module.
# NOTE: sincos() is a GNU extension rather than standard C.  It returns
# nothing: both results are written through the pointer arguments, so it is
# declared ``void`` here to match the real prototype.
cdef extern from "math.h":
    void sincos(double x, double *sin, double *cos)
    double exp(double x)

# GSL's generic "function of one real variable" descriptor.
cdef extern from "gsl/gsl_math.h":
    ctypedef struct gsl_function:
        # Callback evaluated by GSL; receives the abscissa and ``params``.
        double (* function) (double x, void* params)
        # Opaque pointer handed back to ``function`` on every call.
        void* params

# Declarations for the subset of GSL's numerical integration API used here.
cdef extern from "gsl/gsl_integration.h":

    # Opaque workspace type; only ever handled through pointers.
    ctypedef struct gsl_integration_workspace

    # Allocate a workspace with room for ``n`` subintervals.
    gsl_integration_workspace* gsl_integration_workspace_alloc(size_t n)

    # Release a workspace obtained from gsl_integration_workspace_alloc.
    void  gsl_integration_workspace_free(gsl_integration_workspace* w)

    # Adaptive integration of ``f`` over [a, b]; the estimate and its error
    # are written to ``result`` and ``abserr``.  The callers in this file
    # treat a non-zero return value as failure.
    int  gsl_integration_qags(gsl_function* f, double a, double b, double epsabs, double epsrel, size_t limit, gsl_integration_workspace* workspace, double* result, double* abserr)


# Module-level state shared between green(), complex_quad() and the GSL
# callbacks.  The callbacks receive no useful ``params`` pointer, so every
# input of the integrand travels through these globals.

# GSL workspace; allocated and freed by green(), used by complex_quad().
cdef gsl_integration_workspace* workspace

# Cached per-energy data filled in by func().
cdef struct Entry:
    # Wave-function values flattened row by row: element ``i * norbs + k``
    # is row ``i`` (one row per fermis entry), column ``k``.
    libcpp.vector.vector[double complex] psis
    # One occupation value per row of ``psis``.
    libcpp.vector.vector[double] fermis

# Cache of Entry objects keyed by energy; cleared at the start of green().
cdef libcpp.map.map[double, Entry] cache

# System object passed to kwant.wave_function().
cdef object sys
# n, m: column indices currently being integrated (set by the loops in
# green()); nleads: len(sys.leads); norbs: number of columns per row of
# psis; index: selects which of the two integrands func() evaluates.
cdef int n, m, nleads, norbs, index
# t: current time of the phase factor in func(); fermi_energy: reference
# energy added to each lead's potential in fermi_of_lead().
cdef double t, fermi_energy
# Per-lead potentials and temperatures, indexed by lead number.
cdef double [:] pots, temps


cdef double real_f(double e, void* ignore):
    # GSL integrand callback: real part of func() at energy ``e``.
    # The ``params`` pointer supplied by GSL is not used.
    cdef double complex value = func(e)
    return value.real


cdef double imag_f(double e, void* ignore):
    # GSL integrand callback: imaginary part of func() at energy ``e``.
    # The ``params`` pointer supplied by GSL is not used.
    cdef double complex value = func(e)
    return value.imag


cdef double complex complex_quad(double a, double b,
                                 double epsabs, double epsrel,
                                 int limit) except *:
    # Integrate the complex-valued module-level integrand func() over [a, b].
    #
    # GSL only integrates real functions, so the real and imaginary parts
    # are integrated separately with QAGS (via the real_f / imag_f
    # callbacks) and recombined.
    #
    # a, b            integration bounds
    # epsabs, epsrel  absolute / relative tolerances passed to QAGS
    # limit           maximum number of subintervals; must not exceed the
    #                 size the module-level ``workspace`` was allocated with
    #
    # Raises RuntimeError when QAGS reports a non-zero status.  The
    # ``except *`` clause is what lets that exception reach the caller:
    # without it, older Cython versions print and discard exceptions raised
    # in a cdef function that returns a C value.
    #
    # NOTE(review): unless GSL's error handler has been replaced elsewhere,
    # GSL's default handler aborts the process before a non-zero status is
    # returned - confirm how the handler is configured.
    cdef double real_res, imag_res, error
    cdef int status

    # Named ``integrand`` so it does not shadow the module-level func().
    cdef gsl_function integrand
    integrand.params = NULL

    integrand.function = real_f
    status = gsl_integration_qags(&integrand, a, b, epsabs, epsrel, limit,
                                  workspace, &real_res, &error)
    if status != 0:
        raise RuntimeError(
            "gsl_integration_qags failed for the real part "
            "(GSL status %d)" % status)

    integrand.function = imag_f
    status = gsl_integration_qags(&integrand, a, b, epsabs, epsrel, limit,
                                  workspace, &imag_res, &error)
    if status != 0:
        raise RuntimeError(
            "gsl_integration_qags failed for the imaginary part "
            "(GSL status %d)" % status)

    return complex(real_res, imag_res)


cdef double fermi_of_lead(double E, int lead):
    # Occupation of lead number ``lead`` at energy ``E``.
    # The lead's chemical potential is its entry in ``pots`` plus the global
    # ``fermi_energy``.  A temperature of exactly zero gives a step function
    # (1 below the chemical potential, 0 otherwise); any other temperature
    # gives the Fermi-Dirac distribution.
    cdef double temperature = temps[lead]
    cdef double mu = pots[lead] + fermi_energy

    if temperature != 0:
        return 1. / (exp((E - mu) / temperature) + 1)
    return float(mu - E > 0)


# Complex integrand evaluated at energy E.
#
# Sums psi[n] * conj(psi[m]) over all cached rows, weighting each row by
# (1 - f) when ``index`` is non-zero or by -f when it is zero (f being the
# row's cached occupation), and multiplies the sum by exp(-1j * E * t).
# All of n, m, index, t, norbs, nleads and sys are module-level globals set
# by green().
#
# The wave functions and occupations depend only on E, so they are computed
# once per energy and kept in the module-level ``cache``.
#
# NOTE(review): this cdef function has no exception clause, so how a Python
# exception raised in here (e.g. by kwant or the assert) is handled depends
# on the Cython version - confirm.
cdef double complex func(double E):
    cdef int i, j, nmodes
    cdef double complex [:, :] psis_view
    cdef double [:] fermis_view
    cdef double complex g
    cdef double sin, cos, f
    cdef Entry* cached
    cdef libcpp.map.map[double, Entry].iterator cached_iter

    cached_iter = cache.find(E)
    if cached_iter == cache.end():
        # Cache miss: compute the wave functions for every lead at E.
        wf = kwant.wave_function(sys, E)
        psis = [wf(i) for i in xrange(nleads)]
        # Build, for each lead, an array holding that lead's occupation once
        # per row of its wave-function array.
        fermis = []
        for lead, psi in enumerate(psis):
            fermi = np.empty(len(psi), float)
            fermis.append(fermi)
            fermi.fill(fermi_of_lead(E, lead))
        # Stack the per-lead arrays so rows of all leads are contiguous.
        psis_view = np.asarray(np.concatenate(psis), complex)
        fermis_view = np.concatenate(fermis)

        # Create a default-constructed entry at energy E.
        cached = &cache[E]

        # The flat indexing below relies on exactly norbs columns per row.
        assert psis_view.shape[1] == norbs
        # Copy the 2-D array into the flat vector, row by row.
        cached[0].psis.reserve(psis_view.shape[0] * psis_view.shape[1])
        for i in xrange(psis_view.shape[0]):
            for j in xrange(psis_view.shape[1]):
                cached[0].psis.push_back(psis_view[i, j])

        cached[0].fermis.reserve(fermis_view.shape[0])
        for i in xrange(fermis_view.shape[0]):
            cached[0].fermis.push_back(fermis_view[i])
    else:
        # Cache hit: point at the stored entry.
        cached = &co.dereference(cached_iter).second

    nmodes = cached[0].fermis.size()
    g = 0
    if index:
        # Non-zero index: accumulate with weight (1 - f).
        for i in xrange(nmodes):
            f = cached[0].fermis[i] # workaround for a Cython bug
            g += (cached[0].psis[i * norbs + n] *
                  cached[0].psis[i * norbs + m].conjugate() * (1 - f))
    else:
        # Zero index: accumulate with weight -f.
        for i in xrange(nmodes):
            f = cached[0].fermis[i] # workaround for a Cython bug
            g -= (cached[0].psis[i * norbs + n] *
                  cached[0].psis[i * norbs + m].conjugate() * f)
    # Multiply by exp(-1j * E * t) = cos(-E t) + 1j sin(-E t).
    sincos(-E * t, &sin, &cos)
    g *= (cos + 1j * sin)
    return g


def green(sys_, norbs_, interval, pots_, temps_, fermi_energy_, times,
          epsabs=1e-6, epsrel=1e-6, limit=1000):
    """Integrate the two energy integrands for every orbital pair and time.

    For each pair ``(n, m)`` of orbital indices and each time ``t`` in
    `times`, the integrand implemented by ``func`` is integrated over
    `interval` with GSL's QAGS routine, once with the occupation weight
    ``-f`` (stored in ``F``) and once with ``1 - f`` (stored in ``G``).
    Both results are finally multiplied by ``1j / (2 * pi)``.

    Parameters
    ----------
    sys_ : object
        System handed to ``kwant.wave_function``; must have a ``leads``
        attribute.
    norbs_ : int
        Number of orbitals, i.e. the number of columns of the wave-function
        arrays.
    interval : sequence
        ``interval[0]`` and ``interval[1]`` are the integration bounds.
    pots_, temps_ : array-like of float
        Per-lead potentials and temperatures; need at least one entry per
        lead.
    fermi_energy_ : float
        Energy added to each lead's potential to obtain its chemical
        potential.
    times : sequence of float
        Times at which the integrals are evaluated.
    epsabs, epsrel : float
        Absolute and relative tolerances passed to QAGS.
    limit : int
        Maximum number of subintervals used by QAGS (also the workspace
        size).

    Returns
    -------
    F, G : complex ndarrays of shape ``(norbs, norbs, len(times))``

    Raises
    ------
    ValueError
        If `limit` is not positive or `pots_` / `temps_` have fewer entries
        than there are leads.
    RuntimeError
        If the GSL workspace cannot be allocated or an integration fails.
    """
    global cache, sys, n, m, norbs, nleads, index, t, fermi_energy, \
        pots, temps, workspace

    # A zero-sized workspace is rejected inside GSL; fail cleanly instead.
    if limit <= 0:
        raise ValueError("limit must be a positive integer")

    cache.clear()
    sys = sys_
    norbs = norbs_
    nleads = len(sys.leads)
    fermi_energy = fermi_energy_
    # Coerce to float64 so that integer input is accepted by the typed
    # memoryviews.
    pots = np.asarray(pots_, dtype=float)
    temps = np.asarray(temps_, dtype=float)
    # The integrand indexes these by lead number inside C callbacks, where
    # an out-of-range index cannot be reported properly; check up front.
    if pots.shape[0] < nleads or temps.shape[0] < nleads:
        raise ValueError("pots_ and temps_ need one entry per lead")

    # Allocate outside the ``try`` so that a failed allocation never
    # reaches the cleanup code with a NULL workspace.
    workspace = gsl_integration_workspace_alloc(limit)
    if workspace == NULL:
        raise RuntimeError("could not allocate the GSL integration workspace")

    try:
        F = np.empty((norbs, norbs, len(times)), complex)
        G = np.empty((norbs, norbs, len(times)), complex)
        # n, m, t and index are module-level globals read by the integrand.
        for n in xrange(norbs):
            for m in xrange(norbs):
                for j, t in enumerate(times):
                    index = 0
                    F[n, m, j] = complex_quad(interval[0], interval[1],
                                              epsabs, epsrel, limit)
                    index = 1
                    G[n, m, j] = complex_quad(interval[0], interval[1],
                                              epsabs, epsrel, limit)

        F *= 1j / (2*np.pi)
        G *= 1j / (2*np.pi)
    finally:
        gsl_integration_workspace_free(workspace)
        # Do not keep a dangling pointer or stale cached data around.
        workspace = NULL
        cache.clear()
    return F, G