From 1f25d59cdc239090056a8efeb4d05cfab4a3f7b1 Mon Sep 17 00:00:00 2001
From: Michael Wimmer <wimmer@lorentz.leidenuniv.nl>
Date: Thu, 5 Apr 2012 00:58:14 +0200
Subject: [PATCH] implement MUMPS wrapper (only complex for now)

---
 kwant/linalg/_mumps.pyx           | 208 ++++++++++++
 kwant/linalg/cmumps.pxd           |  49 +++
 kwant/linalg/cmumps.py            |   3 +
 kwant/linalg/fortran_helpers.py   | 111 +++++++
 kwant/linalg/mumps.py             | 519 ++++++++++++++++++++++++++++++
 kwant/linalg/tests/_test_utils.py |  76 +++++
 kwant/linalg/tests/test_linalg.py |  79 +----
 kwant/linalg/tests/test_mumps.py  |  76 +++++
 setup.py                          |   7 +-
 9 files changed, 1049 insertions(+), 79 deletions(-)
 create mode 100644 kwant/linalg/_mumps.pyx
 create mode 100644 kwant/linalg/cmumps.pxd
 create mode 100644 kwant/linalg/cmumps.py
 create mode 100644 kwant/linalg/fortran_helpers.py
 create mode 100644 kwant/linalg/mumps.py
 create mode 100644 kwant/linalg/tests/_test_utils.py
 create mode 100644 kwant/linalg/tests/test_mumps.py

diff --git a/kwant/linalg/_mumps.pyx b/kwant/linalg/_mumps.pyx
new file mode 100644
index 0000000..9826056
--- /dev/null
+++ b/kwant/linalg/_mumps.pyx
@@ -0,0 +1,208 @@
+cimport numpy as np
+import numpy as np
+cimport cmumps
+import cmumps
+from fortran_helpers import assert_fortran_matvec, assert_fortran_mat
+
+int_dtype = cmumps.int_dtype
+
+# Proxy classes for Python access to the control and info parameters of MUMPS
+
+cdef class mumps_int_array:
+     cdef cmumps.MUMPS_INT *array
+
+     def __init__(self):
+         self.array = NULL
+
+     def __getitem__(self, key):
+         return self.array[key-1]
+
+     def __setitem__(self, key, value):
+         self.array[key-1] = value
+
+
+# workaround for the fact that cython cannot pass pointers to an __init__()
+cdef make_mumps_int_array(cmumps.MUMPS_INT *array):
+     wrapper = mumps_int_array()
+     wrapper.array = array
+     return wrapper
+
+
+cdef class smumps_real_array:
+     cdef cmumps.SMUMPS_REAL *array
+
+     def __init__(self):
+         self.array = NULL
+
+     def __getitem__(self, key):
+         return self.array[key-1]
+
+     def __setitem__(self, key, value):
+         self.array[key-1] = value
+
+
+cdef make_smumps_real_array(cmumps.SMUMPS_REAL *array):
+     wrapper = smumps_real_array()
+     wrapper.array = array
+     return wrapper
+
+
+cdef class dmumps_real_array:
+     cdef cmumps.DMUMPS_REAL *array
+
+     def __init__(self):
+         self.array = NULL
+
+     def __getitem__(self, key):
+         return self.array[key-1]
+
+     def __setitem__(self, key, value):
+         self.array[key-1] = value
+
+
+cdef make_dmumps_real_array(cmumps.DMUMPS_REAL *array):
+     wrapper = dmumps_real_array()
+     wrapper.array = array
+     return wrapper
+
+
+cdef class cmumps_real_array:
+     cdef cmumps.CMUMPS_REAL *array
+
+     def __init__(self):
+         self.array = NULL
+
+     def __getitem__(self, key):
+         return self.array[key-1]
+
+     def __setitem__(self, key, value):
+         self.array[key-1] = value
+
+
+cdef make_cmumps_real_array(cmumps.CMUMPS_REAL *array):
+     wrapper = cmumps_real_array()
+     wrapper.array = array
+     return wrapper
+
+
+cdef class zmumps_real_array:
+     cdef cmumps.ZMUMPS_REAL *array
+
+     def __init__(self):
+         self.array = NULL
+
+     def __getitem__(self, key):
+         return self.array[key-1]
+
+     def __setitem__(self, key, value):
+         self.array[key-1] = value
+
+
+cdef make_zmumps_real_array(cmumps.ZMUMPS_REAL *array):
+     wrapper = zmumps_real_array()
+     wrapper.array = array
+     return wrapper
+
+#############################################################
+
+cdef class zmumps:
+    cdef cmumps.ZMUMPS_STRUC_C params
+
+    cdef public mumps_int_array icntl
+    cdef public zmumps_real_array cntl
+    cdef public mumps_int_array info
+    cdef public mumps_int_array infog
+    cdef public zmumps_real_array rinfo
+    cdef public zmumps_real_array rinfog
+
+    def __init__(self, verbose=False, sym=0):
+        self.params.job = -1
+        self.params.sym = sym
+        self.params.par = 1
+        self.params.comm_fortran = -987654
+
+        cmumps.zmumps_c(&self.params)
+
+        self.icntl = make_mumps_int_array(self.params.icntl)
+        self.cntl = make_zmumps_real_array(self.params.cntl)
+        self.info = make_mumps_int_array(self.params.info)
+        self.infog = make_mumps_int_array(self.params.infog)
+        self.rinfo = make_zmumps_real_array(self.params.rinfo)
+        self.rinfog = make_zmumps_real_array(self.params.rinfog)
+
+        # no diagnostic output (MUMPS is very verbose normally)
+        if not verbose:
+            self.icntl[1] = 0
+            self.icntl[3] = 0
+
+    def __dealloc__(self):
+        self.params.job = -2
+        cmumps.zmumps_c(&self.params)
+
+    def call(self):
+        cmumps.zmumps_c(&self.params)
+
+    def _set_job(self, value):
+        self.params.job = value
+
+    def _get_job(self):
+        return self.params.job
+
+    job = property(_get_job, _set_job)
+
+    @property
+    def sym(self):
+        return self.params.sym
+
+    def set_assembled_matrix(self,
+                             cmumps.MUMPS_INT N,
+                             np.ndarray[cmumps.MUMPS_INT, ndim=1] i,
+                             np.ndarray[cmumps.MUMPS_INT, ndim=1] j,
+                             np.ndarray[np.complex128_t, ndim=1] a):
+        self.params.n = N
+        self.params.nz = a.shape[0]
+        self.params.irn = <cmumps.MUMPS_INT *>i.data
+        self.params.jcn = <cmumps.MUMPS_INT *>j.data
+        self.params.a = <cmumps.ZMUMPS_COMPLEX *>a.data
+
+    def set_dense_rhs(self, np.ndarray rhs):
+
+        assert_fortran_matvec(rhs)
+        if rhs.dtype != np.complex128:
+            raise ValueError("numpy array must be of dtype complex128!")
+
+        if rhs.ndim == 1:
+            self.params.nrhs = 1
+        else:
+            self.params.nrhs = rhs.shape[1]
+        self.params.lrhs = rhs.shape[0]
+        self.params.rhs = <cmumps.ZMUMPS_COMPLEX *>rhs.data
+
+    def set_sparse_rhs(self,
+                       np.ndarray[cmumps.MUMPS_INT, ndim=1] col_ptr,
+                       np.ndarray[cmumps.MUMPS_INT, ndim=1] row_ind,
+                       np.ndarray[np.complex128_t, ndim=1] data):
+
+        if row_ind.shape[0] != data.shape[0]:
+            raise ValueError("Number of entries in row index and value "
+                             "array differ!")
+
+        self.params.nz_rhs = data.shape[0]
+        self.params.nrhs = col_ptr.shape[0] - 1
+        self.params.rhs_sparse = <cmumps.ZMUMPS_COMPLEX *>data.data
+        self.params.irhs_sparse = <cmumps.MUMPS_INT *>row_ind.data
+        self.params.irhs_ptr = <cmumps.MUMPS_INT *>col_ptr.data
+
+    def set_schur(self,
+                  np.ndarray[np.complex128_t, ndim=2, mode='c'] schur,
+                  np.ndarray[cmumps.MUMPS_INT, ndim=1] schur_vars):
+
+        if schur.shape[0] != schur.shape[1]:
+            raise ValueError("Schur matrix must be squared!")
+        if schur.shape[0] != schur_vars.shape[0]:
+            raise ValueError("Number of Schur variables must agree "
+                             "with Schur complement size!")
+
+        self.params.size_schur = schur.shape[0]
+        self.params.schur = <cmumps.ZMUMPS_COMPLEX *>schur.data
+        self.params.listvar_schur = <cmumps.MUMPS_INT *>schur_vars.data
diff --git a/kwant/linalg/cmumps.pxd b/kwant/linalg/cmumps.pxd
new file mode 100644
index 0000000..f4030c4
--- /dev/null
+++ b/kwant/linalg/cmumps.pxd
@@ -0,0 +1,49 @@
+cdef extern from "mumps_c_types.h":
+     ctypedef int MUMPS_INT
+
+     ctypedef float SMUMPS_REAL
+     ctypedef float SMUMPS_COMPLEX
+
+     ctypedef double DMUMPS_REAL
+     ctypedef double DMUMPS_COMPLEX
+
+     ctypedef struct mumps_complex:
+         pass
+
+     ctypedef struct mumps_double_complex:
+         pass
+
+     ctypedef float CMUMPS_REAL
+     ctypedef mumps_complex CMUMPS_COMPLEX
+
+     ctypedef double ZMUMPS_REAL
+     ctypedef mumps_double_complex ZMUMPS_COMPLEX
+
+
+cdef extern from "zmumps_c.h":
+     ctypedef struct ZMUMPS_STRUC_C:
+         MUMPS_INT sym, par, job
+         MUMPS_INT comm_fortran
+         MUMPS_INT icntl[40]
+         ZMUMPS_REAL cntl[15]
+
+         MUMPS_INT n
+
+         MUMPS_INT nz
+         MUMPS_INT *irn, *jcn
+         ZMUMPS_COMPLEX *a
+
+         MUMPS_INT nrhs, lrhs
+         ZMUMPS_COMPLEX *rhs
+
+         MUMPS_INT info[40], infog[40]
+         ZMUMPS_REAL rinfo[40], rinfog[40]
+
+         MUMPS_INT nz_rhs
+         ZMUMPS_COMPLEX *rhs_sparse
+         MUMPS_INT *irhs_sparse, *irhs_ptr
+
+         MUMPS_INT size_schur, *listvar_schur
+         ZMUMPS_COMPLEX *schur
+
+     cdef void zmumps_c(ZMUMPS_STRUC_C *)
diff --git a/kwant/linalg/cmumps.py b/kwant/linalg/cmumps.py
new file mode 100644
index 0000000..4754483
--- /dev/null
+++ b/kwant/linalg/cmumps.py
@@ -0,0 +1,3 @@
+import numpy as np
+
+int_dtype = np.int32
diff --git a/kwant/linalg/fortran_helpers.py b/kwant/linalg/fortran_helpers.py
new file mode 100644
index 0000000..2cc65ef
--- /dev/null
+++ b/kwant/linalg/fortran_helpers.py
@@ -0,0 +1,111 @@
+import numpy as np
+
+
+def prepare_for_fortran(overwrite, *args):
+    """Convert arrays to Fortran format.
+
+    This function takes a number of array objects in `args` and converts them
+    to a format that can be directly passed to a Fortran function (Fortran
+    contiguous numpy array). If the arrays have different data type, they
+    converted arrays are cast to a common compatible data type (one of numpy's
+    `float32`, `float64`, `complex64`, `complex128` data types).
+
+    If `overwrite` is ``False``, an numpy array that would already be in the
+    correct format (Fortran contiguous, right data type) is neverthelessed
+    copied. (Hence, overwrite = True does not imply that acting on the
+    converted array in the return values will overwrite the original array in
+    all cases -- it does only so if the original array was already in the
+    correct format. The conversions require copying. In fact, that's the same
+    behavior as in scipy, it's just not explicitly stated there)
+
+    If an argument is ``None``, it is just passed through and not used to
+    determine the proper data type.
+
+    `prepare_for_lapack` returns a character indicating the proper
+    data type in LAPACK style ('s', 'd', 'c', 'z') and a list of
+    properly converted arrays.
+    """
+
+    # Make sure we have numpy arrays
+    mats = [None]*len(args)
+    for i in xrange(len(args)):
+        if args[i] is not None:
+            arr = np.asanyarray(args[i])
+            if not np.issubdtype(arr.dtype, np.number):
+                raise ValueError("Argument cannot be interpreted "
+                                 "as a numeric array")
+
+            mats[i] = (arr, arr is not args[i] or overwrite)
+        else:
+            mats[i] = (None, True)
+
+    # First figure out common dtype
+    # Note: The return type of common_type is guaranteed to be a floating point
+    #       kind.
+    dtype = np.common_type(*[arr for arr, ovwrt in mats if arr is not None])
+
+    if dtype == np.float32:
+        lapacktype = 's'
+    elif dtype == np.float64:
+        lapacktype = 'd'
+    elif dtype == np.complex64:
+        lapacktype = 'c'
+    elif dtype == np.complex128:
+        lapacktype = 'z'
+    else:
+        raise AssertionError("Unexpected data type from common_type")
+
+    ret = [ lapacktype ]
+    for npmat, ovwrt in mats:
+        # Now make sure that the array is contiguous, and copy if necessary.
+        if npmat is not None:
+            if npmat.ndim == 2:
+                if not npmat.flags["F_CONTIGUOUS"]:
+                    npmat = np.asfortranarray(npmat, dtype = dtype)
+                elif npmat.dtype != dtype:
+                    npmat = npmat.astype(dtype)
+                elif not ovwrt:
+                    # ugly here: copy makes always C-array, no way to tell it
+                    # to make a Fortran array.
+                    npmat = np.asfortranarray(npmat.copy())
+            elif npmat.ndim == 1:
+                if not npmat.flags["C_CONTIGUOUS"]:
+                    npmat = np.ascontiguousarray(npmat, dtype = dtype)
+                elif npmat.dtype != dtype:
+                    npmat = npmat.astype(dtype)
+                elif not ovwrt:
+                    npmat = np.asfortranarray(npmat.copy())
+            else:
+                raise ValueError("Dimensionality of array is not 1 or 2")
+
+        ret.append(npmat)
+
+    return tuple(ret)
+
+
+def assert_fortran_mat(*mats):
+    """Check if the input ndarrays are all proper Fortran matrices."""
+
+    # This is a workaround for a bug in numpy version < 2.0,
+    # where 1x1 matrices do not have the F_Contiguous flag set correctly.
+    for mat in mats:
+        if (mat is not None and (mat.shape[0] > 1 or mat.shape[1] > 1) and
+            not mat.flags["F_CONTIGUOUS"]):
+            raise ValueError("Input matrix must be Fortran contiguous")
+
+
+def assert_fortran_matvec(*arrays):
+    """Check if the input ndarrays are all proper Fortran matrices
+    or vectors."""
+
+    # This is a workaround for a bug in numpy version < 2.0,
+    # where 1x1 matrices do not have the F_Contiguous flag set correctly.
+    for arr in arrays:
+        if not arr.ndim in (1, 2):
+            raise ValueError("Input must be either a vector "
+                             "or a matrix.")
+
+        if (not arr.flags["F_CONTIGUOUS"] or
+            (arr.ndim == 2 and arr.shape[0] == 1 and arr.shape[1] == 1) ):
+            raise ValueError("Input must be a Fortran ordered "
+                             "numpy array")
diff --git a/kwant/linalg/mumps.py b/kwant/linalg/mumps.py
new file mode 100644
index 0000000..9e03925
--- /dev/null
+++ b/kwant/linalg/mumps.py
@@ -0,0 +1,519 @@
+"""Interface to the MUMPS sparse solver library"""
+
+__all__ = ['MUMPSContext', 'schur_complement', 'AnalysisStatistics',
+           'FactorizationStatistics', 'MUMPSError']
+
+import time
+import numpy as np
+import scipy.sparse
+import _mumps
+from fortran_helpers import prepare_for_fortran
+
+orderings = { 'amd' : 0, 'amf' : 2, 'scotch' : 3, 'pord' : 4, 'metis' : 5,
+              'qamd' : 6, 'auto' : 7 }
+
+ordering_name = [ 'amd', 'user-defined', 'amf',
+                  'scotch', 'pord', 'metis', 'qamd']
+
+_possible_orderings = None
+
+
+def possible_orderings():
+    """Return the ordering options that are available in the current
+    installation of MUMPS.
+
+    Which ordering options are actually available depends how MUMPs was
+    compiled. Note that passing an ordering that is not avaialble in the
+    current installation of MUMPS will not fail, instead MUMPS will fall back
+    to a supported one.
+
+    Returns
+    -------
+    orderings : list of strings
+       A list of installed orderings that can be used in the `ordering` option
+       of MUMPS.
+    """
+    global _possible_orderings
+
+    if not _possible_orderings:
+        # Try all orderings on a small test matrix, and check which one was
+        # actually used.
+
+        _possible_orderings = ['auto']
+        for ordering in [0, 2, 3, 4, 5, 6]:
+            data = np.asfortranarray([1, 1], dtype=np.complex128)
+            row = np.asfortranarray([1, 2], dtype=_mumps.int_dtype)
+            col = np.asfortranarray([1, 2], dtype=_mumps.int_dtype)
+
+            instance = _mumps.zmumps()
+            instance.set_assembled_matrix(2, row, col, data)
+            instance.icntl[7] = ordering
+            instance.job = 1
+            instance.call()
+
+            if instance.infog[7] == ordering:
+                _possible_orderings.append(ordering_name[ordering])
+
+    return _possible_orderings
+
+
+class MUMPSError(RuntimeError):
+    def __init__(self, error):
+        self.error = error
+        RuntimeError.__init__(self, "MUMPS failed with error " +
+                              str(error))
+
+
+class AnalysisStatistics(object):
+    def __init__(self, inst, time=None):
+        self.est_mem_incore = inst.infog[17]
+        self.est_mem_ooc = inst.infog[27]
+        self.est_nonzeros = (inst.infog[20] if inst.infog[20] > 0 else
+                             -inst.infog[20] * 1000000)
+        self.est_flops = inst.rinfog[1]
+        self.ordering = ordering_name[inst.infog[7]]
+        self.time = time
+
+    def __str__(self):
+        string = " estimated memory for in-core factorization: " + \
+            str(self.est_mem_incore) + " mbytes\n"
+        string += " estimated memory for out-of-core factorization: " + \
+            str(self.est_mem_ooc) + " mbytes\n"
+        string += " estimated number of nonzeros in factors: " + \
+            str(self.est_nonzeros) + "\n"
+        string += " estimated number of flops: " + str(self.est_flops) + "\n"
+        string += " ordering used: " + self.ordering
+        if hasattr(self, "time"):
+            string += "\n analysis time: " + str(self.time) + " secs"
+
+        return string
+
+
+class FactorizationStatistics(object):
+    def __init__(self, inst, time=None, include_ordering=False):
+        # information about pivoting
+        self.offdiag_pivots = inst.infog[12] if inst.sym == 0 else 0
+        self.delayed_pivots = inst.infog[13]
+        self.tiny_pivots = inst.infog[25]
+
+        # possibly include ordering (used in schur_complement)
+        if include_ordering:
+            self.ordering = ordering_name[inst.infog[7]]
+
+        # information about runtime effiency
+        self.memory = inst.infog[22]
+        self.nonzeros = (inst.infog[29] if inst.infog[29] > 0 else
+                         -inst.infog[29] * 1000000)
+        self.flops = inst.rinfog[3]
+        if time:
+            self.time = time
+
+    def __str__(self):
+        string = " off-diagonal pivots: " + str(self.offdiag_pivots) + "\n"
+        string += " delayed pivots: " + str(self.delayed_pivots) + "\n"
+        string += " tiny pivots: " + str(self.tiny_pivots) + "\n"
+        if hasattr(self, "ordering"):
+            string += " ordering used: " + self.ordering + "\n"
+        string += " memory used during factorization : " + str(self.memory) + \
+            " mbytes\n"
+        string += " nonzeros in factored matrix: " + str(self.nonzeros) + "\n"
+        string += " floating point operations: " + str(self.flops)
+        if hasattr(self, "time"):
+            string += "\n factorization time: " + str(self.time) +" secs"
+
+        return string
+
+
+class MUMPSContext(object):
+    """MUMPSContext contains the internal data structures needed by the
+    MUMPS library and contains a user-friendly interface.
+
+    WARNING: Only complex numbers supported.
+
+    Examples
+    --------
+
+    Solving a small system of equations.
+
+    >>> import scipy.sparse as sp
+    >>> sp.coo_matrix([[1.,0],[0,2.]])
+    >>> ctx = kwant.linalg.mumps.MUMPSContext()
+    >>> ctx.factor(a)
+    >>> ctx.solve([1., 1.])
+    array([ 1.0+0.j,  0.5+0.j])
+
+    Instance variables
+    ------------------
+
+    analysis_stats : `AnalysisStatistics`
+        contains MUMPS statistics after an analysis step (i.e.  after a call to
+        `analyze` or `factor`)
+    factor_stats : `FactorizationStatistics`
+        contains MUMPS statistics after a factorization step (i.e.  after a
+        call to `factor`)
+
+    """
+
+    def __init__(self, verbose=False):
+        """Init the MUMPSContext class
+
+        Parameters
+        ----------
+
+        verbose : True or False
+            control whether MUMPS prints lots of internal statistics
+            and debug information to screen.
+        """
+        self.mumps_instance = None
+        self.dtype = None
+        self.verbose = verbose
+        self.factored = False
+
+    def analyze(self, a, ordering='auto', overwrite_a=False):
+        """Perform analysis step of MUMPS.
+
+        In the analyis step, MUMPS figures out a reordering for the matrix and
+        estimates number of operations and memory needed for the factorization
+        time. This step usually needs not be called separately (it is done
+        automatically by `factor`), but it can be useful to test which ordering
+        would give best performance in the actual factorization, as MUMPS
+        estimates are available in `analysis_stats`.
+
+        Parameters
+        ----------
+
+        a : sparse scipy matrix
+            input matrix. Internally, the matrix is converted to `coo` format
+            (so passing this format is best for performance)
+        ordering : { 'auto', 'amd', 'amf', 'scotch', 'pord', 'metis', 'qamd' }
+            ordering to use in the factorization. The availability of a
+            particular ordering depends on the MUMPS installation.  Default is
+            'auto'.
+        overwrite_a : True or False
+            whether the data in a may be overwritten, which can lead to a small
+            performance gain. Default is False.
+        """
+
+        a = a.tocoo()
+
+        if a.ndim != 2 or a.shape[0] != a.shape[1]:
+            raise ValueError("Input matrix must be square!")
+
+        if not ordering in orderings.keys():
+            raise ValueError("Unknown ordering '"+ordering+"'!")
+
+        dtype, row, col, data = _make_assembled_from_coo(a, overwrite_a)
+
+        if dtype != self.dtype:
+            self.mumps_instance = getattr(_mumps, dtype+"mumps")(self.verbose)
+            self.dtype = dtype
+
+        self.n = a.shape[0]
+        self.row = row
+        self.col = col
+        self.data = data
+        # Note: if I don't store them, they go out of scope and are
+        #       deleted. I however need the memory to stay around!
+
+        self.mumps_instance.set_assembled_matrix(a.shape[0], row, col, data)
+        self.mumps_instance.icntl[7] = orderings[ordering]
+        self.mumps_instance.job = 1
+        t1 = time.clock()
+        self.mumps_instance.call()
+        t2 = time.clock()
+        self.factored = False
+
+        if self.mumps_instance.infog[1] < 0:
+            raise MUMPSError(self.mumps_instance.infog[1])
+
+        self.analysis_stats = AnalysisStatistics(self.mumps_instance,
+                                                 t2 - t1)
+
+    def factor(self, a, ordering='auto', ooc=False, pivot_tol=0.01,
+               reuse_analysis=False, overwrite_a=False):
+        """Perform the LU factorization of the matrix.
+
+        This LU factorization can then later be used to solve a linear system
+        with `solve`. Statistical data of the factorization is stored in
+        `factor_stats`.
+
+        Parameters
+        ----------
+
+        a : sparse scipy matrix
+            input matrix. Internally, the matrix is converted to `coo` format
+            (so passing this format is best for performance)
+        ordering : { 'auto', 'amd', 'amf', 'scotch', 'pord', 'metis', 'qamd' }
+            ordering to use in the factorization. The availability of a
+            particular ordering depends on the MUMPS installation.  Default is
+            'auto'.
+        ooc : True or False
+            whether to use the out-of-core functionality of MUMPS.
+            (out-of-core means that data is written to disk to reduce memory
+            usage.) Default is False.
+        pivot_tol: number in the range [0, 1]
+            pivoting threshold. Pivoting is typically limited in sparse
+            solvers, as too much pivoting destroys sparsity. 1.0 means full
+            pivoting, whereas 0.0 means no pivoting. Default is 0.01.
+        reuse_analysis: True or False
+            whether to reuse the analysis done in a previous call to `analyze`
+            or `factor`. If the structure of the matrix stays the same, and the
+            numerical values do not change much, the previous analysis can be
+            reused, saving some time.  WARNING: There is no check whether the
+            structure of your matrix is compatible with the previous
+            analysis. Also, if the values are not similar enough, there might
+            be loss of accuracy, without a warning. Default is False.
+        overwrite_a : True or False
+            whether the data in a may be overwritten, which can lead to a small
+            performance gain. Default is False.
+        """
+        a = a.tocoo()
+
+        if a.ndim != 2 or a.shape[0] != a.shape[1]:
+            raise ValueError("Input matrix must be square!")
+
+        # Analysis phase must be done before factorization
+        # Note: previous analysis is reused only if reuse_analysis == True
+
+        if reuse_analysis:
+            if mumps_instance is None:
+                warnings.warn("Missing analysis although reuse_analysis=True. "
+                              "New analysis is performed.",
+                              RuntimeWarning)
+
+                self.analyze(a, ordering=ordering, overwrite_a=overwrite_a)
+            else:
+                dtype, row, col, data = _make_assembled_from_coo(a,
+                                                                 overwrite_a)
+
+                if self.dtype != dtype:
+                    raise ValueError("MUMPSContext dtype and matrix dtype "
+                                     "incompatible!")
+
+                self.n = a.shape[0]
+                self.row = row
+                self.col = col
+                self.data = data
+                self.mumps_instance.set_assembled_matrix(a.shape[0],
+                                                         row, col, data)
+        else:
+            self.analyze(a, ordering=ordering, overwrite_a=overwrite_a)
+
+        self.mumps_instance.icntl[22] = 1 if ooc else 0
+        self.mumps_instance.job = 2
+        self.mumps_instance.cntl[1] = pivot_tol
+
+        done = False
+        while not done:
+            t1 = time.clock()
+            self.mumps_instance.call()
+            t2 = time.clock()
+
+            # error -9 (not enough allocated memory) is treated
+            # specially, by increasing the memory relaxation parameter
+            if self.mumps_instance.infog[1] < 0:
+                if self.mumps_instance.infog[1] == -9:
+                    # double the additional memory
+                    self.mumps_instance.icntl[14] *= 2
+                else:
+                    raise MUMPSError(self.mumps_instance.infog[1])
+            else:
+                done = True
+
+        self.factored = True
+        self.factor_stats = FactorizationStatistics(self.mumps_instance,
+                                                    t2 - t1)
+
+    def _solve_sparse(self, b):
+        b = b.tocsc()
+        x = np.empty((b.shape[0], b.shape[1]),
+                     order='F', dtype=self.data.dtype)
+
+        dtype, col_ptr, row_ind, data = \
+            _make_sparse_rhs_from_csc(b, self.data.dtype)
+
+        if b.shape[0] != self.n:
+            raise ValueError("Right hand side has wrong size")
+
+        if self.dtype != dtype:
+            raise ValueError("Data type of right hand side is not "
+                             "compatible with the dtype of the "
+                             "linear system")
+
+        self.mumps_instance.set_sparse_rhs(col_ptr, row_ind, data)
+        self.mumps_instance.set_dense_rhs(x)
+        self.mumps_instance.job = 3
+        self.mumps_instance.icntl[20] = 1
+        self.mumps_instance.call()
+
+        return x
+
+    def _solve_dense(self, b, overwrite_b=False):
+        dtype, b = prepare_for_fortran(overwrite_b, b,
+                                       np.zeros(1, dtype=self.data.dtype))[:2]
+
+        if b.shape[0] != self.n:
+            raise ValueError("Right hand side has wrong size")
+
+        if self.dtype != dtype:
+            raise ValueError("Data type of right hand side is not "
+                             "compatible with the dtype of the "
+                             "linear system")
+
+        self.mumps_instance.set_dense_rhs(b)
+        self.mumps_instance.job = 3
+        self.mumps_instance.call()
+
+        return b
+
+    def solve(self, b, overwrite_b=False):
+        """Solve a linear system after the LU factorization has previously
+        been performed by `factor`.
+
+        Supports both dense and sparse right hand sides.
+
+        Parameters
+        ----------
+
+        b : dense (numpy) matrix or vector or sparse (scipy) matrix
+            the right hand side to solve. Accepts both dense and sparse input;
+            if the input is sparse 'csc' format is used internally (so passing
+            a 'csc' matrix gives best performance).
+        overwrite_b : True or False
+            whether the data in b may be overwritten, which can lead to a small
+            performance gain. Default is False.
+
+        Returns
+        -------
+
+        x : numpy array
+            the solution to the linear system as a dense matrix (a vector is
+            returned if b was a vector, otherwise a matrix is returned).
+        """
+
+        if not self.factored:
+            raise RuntimeError("Factorization must be done before solving!")
+
+        if scipy.sparse.isspmatrix(b):
+            return self._solve_sparse(b)
+        else:
+            return self._solve_dense(b, overwrite_b)
+
+
+def schur_complement(a, indices, ordering='auto', ooc=False, pivot_tol=0.01,
+                     calc_stats=False, overwrite_a=False):
+    """Compute the Schur complement block of matrix a using MUMPS.
+
+    Parameters:
+    a : sparse matrix
+        input matrix. Internally, the matrix is converted to `coo` format (so
+        passing this format is best for performance)
+    indices : 1d array
+        indices (row and column) of the desired Schur complement block.  (The
+        Schur complement block is square, so that the indices are both row and
+        column indices.)
+    ordering : { 'auto', 'amd', 'amf', 'scotch', 'pord', 'metis', 'qamd' }
+        ordering to use in the factorization. The availability of a particular
+        ordering depends on the MUMPS installation.  Default is 'auto'.
+    ooc : True or False
+        whether to use the out-of-core functionality of MUMPS.  (out-of-core
+        means that data is written to disk to reduce memory usage.) Default is
+        False.
+    pivot_tol: number in the range [0, 1]
+        pivoting threshold. Pivoting is typically limited in sparse solvers, as
+        too much pivoting destroys sparsity. 1.0 means full pivoting, whereas
+        0.0 means no pivoting. Default is 0.01.
+    calc_stats: True or False
+        whether to return the analysis and factorization statistics collected
+        by MUMPS. Default is False.
+    overwrite_a : True or False
+        whether the data in a may be overwritten, which can lead to a small
+        performance gain. Default is False.
+
+    Returns
+    -------
+
+    s : numpy array
+        Schur complement block
+    factor_stats: `FactorizationStatistics`
+        statistics of the factorization as collected by MUMPS.  Only returned
+        if ``calc_stats==True``.
+    """
+
+    if not scipy.sparse.isspmatrix(a):
+        raise ValueError("a must be a sparse scipy matrix!")
+
+    a = a.tocoo()
+
+    if a.ndim != 2 or a.shape[0] != a.shape[1]:
+        raise ValueError("Input matrix must be square!")
+
+    indices = np.asanyarray(indices)
+
+    if indices.ndim != 1:
+        raise ValueError("Schur indices must be specified in a 1d array!")
+
+    if not ordering in orderings.keys():
+        raise ValueError("Unknown ordering '"+ordering+"'!")
+
+    dtype, row, col, data = _make_assembled_from_coo(a, overwrite_a)
+    indices = _make_mumps_index_array(indices)
+
+    mumps_instance = getattr(_mumps, dtype+"mumps")()
+
+    mumps_instance.set_assembled_matrix(a.shape[0], row, col, data)
+    mumps_instance.icntl[7] = orderings[ordering]
+    mumps_instance.icntl[19] = 1
+    mumps_instance.icntl[31] = 1  # discard factors, from 4.10.0
+                                  # has no effect in earlier versions
+
+    schur_compl = np.empty((indices.size, indices.size),
+                           order='C', dtype=data.dtype)
+    mumps_instance.set_schur(schur_compl, indices)
+
+    mumps_instance.job = 4   # job=4 -> 1 and 2 after each other
+    t1 = time.clock()
+    mumps_instance.call()
+    t2 = time.clock()
+
+    if not calc_stats:
+        return schur_compl
+    else:
+        return schur_compl, \
+            FactorizationStatistics(mumps_instance, time=t2 - t1,
+                                    include_ordering=True)
+
+
+# Some internal helper functions
+def _make_assembled_from_coo(a, overwrite_a):
+    dtype, data = prepare_for_fortran(overwrite_a, a.data)
+
+    row = np.asfortranarray(a.row.astype(_mumps.int_dtype))
+    col = np.asfortranarray(a.col.astype(_mumps.int_dtype))
+
+    # MUMPS uses Fortran indices.
+    row += 1
+    col += 1
+
+    return dtype, row, col, data
+
+
+def _make_sparse_rhs_from_csc(b, dtype):
+    dtype, data = prepare_for_fortran(True, b.data,
+                                      np.zeros(1, dtype=dtype))[:2]
+
+    col_ptr = np.asfortranarray(b.indptr.astype(_mumps.int_dtype))
+    row_ind = np.asfortranarray(b.indices.astype(_mumps.int_dtype))
+
+    # MUMPS uses Fortran indices.
+    col_ptr += 1
+    row_ind += 1
+
+    return dtype, col_ptr, row_ind, data
+
+
+def _make_mumps_index_array(a):
+    a = np.asfortranarray(a.astype(_mumps.int_dtype))
+    a += 1                      # Fortran indices
+
+    return a
diff --git a/kwant/linalg/tests/_test_utils.py b/kwant/linalg/tests/_test_utils.py
new file mode 100644
index 0000000..9c14d10
--- /dev/null
+++ b/kwant/linalg/tests/_test_utils.py
@@ -0,0 +1,76 @@
+import numpy as np
+
+class _Random:
+    def __init__(self):
+        self._x = 13
+
+    def _set_seed(self, seed):
+        self._x = seed
+
+    def _randf(self):
+        # A very bad random number generator returning numbers between -1 and
+        # +1.  Just for making some matrices, and being sure that they are the
+        # same on any architecture.
+        m = 2**16
+        a = 11929
+        c = 36491
+
+        self._x = (a * self._x + c) % m
+
+        return (float(self._x)/m-0.5)*2
+
+    def _randi(self):
+        # A very bad random number generator returning number between 0 and 20.
+        # Just for making some matrices, and being sure that they are the same
+        # on any architecture.
+        m = 2**16
+        a = 11929
+        c = 36491
+
+        self._x = (a * self._x + c) % m
+
+        return self._x % 21
+
+    def randmat(self, n, m, dtype):
+        mat = np.empty((n, m), dtype = dtype)
+
+        if issubclass(dtype, np.complexfloating):
+            for i in xrange(n):
+                for j in xrange(m):
+                    mat[i,j] = self._randf() + 1j * self._randf()
+        elif issubclass(dtype, np.floating):
+            for i in xrange(n):
+                for j in xrange(m):
+                    mat[i,j] = self._randf()
+        elif issubclass(dtype, np.integer):
+            for i in xrange(n):
+                for j in xrange(m):
+                    mat[i,j] = self._randi()
+
+        return mat
+
+    def randvec(self, n, dtype):
+        vec = np.empty(n, dtype = dtype)
+
+        if issubclass(dtype, np.complexfloating):
+            for i in xrange(n):
+                vec[i] = self._randf() + 1j * self._randf()
+        elif issubclass(dtype, np.floating):
+            for i in xrange(n):
+                vec[i] = self._randf()
+        elif issubclass(dtype, np.integer):
+            for i in xrange(n):
+                vec[i] = self._randi()
+
+        return vec
+
+# Improved version of assert_arrays_almost_equal that uses the dtype to set the
+# precision (The default precision of assert_arrays_almost_equal is sometimes
+# too small for single-precision comparisions.)
+def assert_array_almost_equal(dtype, a, b):
+    if dtype == np.float32 or dtype == np.complex64:
+        prec = 5
+    else:
+        prec = 10
+
+    np.testing.assert_array_almost_equal(a, b, decimal=prec)
diff --git a/kwant/linalg/tests/test_linalg.py b/kwant/linalg/tests/test_linalg.py
index 19febb6..bc43d8a 100644
--- a/kwant/linalg/tests/test_linalg.py
+++ b/kwant/linalg/tests/test_linalg.py
@@ -3,84 +3,7 @@ from kwant.linalg import lu_factor, lu_solve, rcond_from_lu, gen_eig, schur, \
     convert_r2c_gen_schur, order_gen_schur, evecs_from_gen_schur
 from nose.tools import assert_equal, assert_true
 import numpy as np
-
-class _Random:
-    def __init__(self):
-        self._x = 13
-
-    def _set_seed(self, seed):
-        self._x = seed
-
-    def _randf(self):
-        #a very bad random number generator returning
-        #number between -1 and +1
-        #Just for making some matrices, and being sure that they
-        #are the same on any architecture
-        m = 2**16
-        a = 11929
-        c = 36491
-
-        self._x = (a * self._x + c) % m
-
-        return (float(self._x)/m-0.5)*2
-
-    def _randi(self):
-        #a very bad random number generator returning
-        #number between 0 and 20
-        #Just for making some matrices, and being sure that they
-        #are the same on any architecture
-        m = 2**16
-        a = 11929
-        c = 36491
-
-        self._x = (a * self._x + c) % m
-
-        return self._x % 21
-
-    def randmat(self, n, m, dtype):
-        mat = np.empty((n, m), dtype = dtype)
-
-        if issubclass(dtype, np.complexfloating):
-            for i in xrange(n):
-                for j in xrange(m):
-                    mat[i,j] = self._randf() + 1j * self._randf()
-        elif issubclass(dtype, np.floating):
-            for i in xrange(n):
-                for j in xrange(m):
-                    mat[i,j] = self._randf()
-        elif issubclass(dtype, np.integer):
-            for i in xrange(n):
-                for j in xrange(m):
-                    mat[i,j] = self._randi()
-
-        return mat
-
-    def randvec(self, n, dtype):
-        vec = np.empty(n, dtype = dtype)
-
-        if issubclass(dtype, np.complexfloating):
-            for i in xrange(n):
-                vec[i] = self._randf() + 1j * self._randf()
-        elif issubclass(dtype, np.floating):
-            for i in xrange(n):
-                vec[i] = self._randf()
-        elif issubclass(dtype, np.integer):
-            for i in xrange(n):
-                vec[i] = self._randi()
-
-        return vec
-
-#Improved version of assert_arrays_almost_equal that
-#uses the dtype to set the precision
-#(The default precision of assert_arrays_almost_equal is sometimes
-# too small for single-precision comparisions)
-def assert_array_almost_equal(dtype, a, b):
-    if dtype == np.float32 or dtype == np.complex64:
-        prec = 5
-    else:
-        prec = 10
-
-    np.testing.assert_array_almost_equal(a, b, decimal=prec)
+from _test_utils import _Random, assert_array_almost_equal
 
 def test_gen_eig():
     def _test_gen_eig(dtype):
diff --git a/kwant/linalg/tests/test_mumps.py b/kwant/linalg/tests/test_mumps.py
new file mode 100644
index 0000000..db9842f
--- /dev/null
+++ b/kwant/linalg/tests/test_mumps.py
@@ -0,0 +1,76 @@
+try:
+    from kwant.linalg.mumps import MUMPSContext, schur_complement
+    _no_mumps = False
+except ImportError:
+    _no_mumps = True
+
+from kwant.lattice import Honeycomb
+from kwant import Builder
+from nose.tools import assert_equal, assert_true
+from numpy.testing.decorators import skipif
+import numpy as np
+import scipy.sparse as sp
+from _test_utils import _Random, assert_array_almost_equal
+
+@skipif(_no_mumps)
+def test_lu_with_dense():
+    def _test_lu_with_dense(dtype):
+        rand = _Random()
+        a = rand.randmat(5, 5, dtype)
+        bmat = rand.randmat(5, 5, dtype)
+        bvec = rand.randvec(5, dtype)
+
+        ctx = MUMPSContext()
+        ctx.factor(sp.coo_matrix(a))
+
+        xvec = ctx.solve(bvec)
+        xmat = ctx.solve(bmat)
+
+        assert_array_almost_equal(dtype, np.dot(a, xmat), bmat)
+        assert_array_almost_equal(dtype, np.dot(a, xvec), bvec)
+
+        # now "sparse" right hand side
+
+        xvec = ctx.solve(sp.csc_matrix(bvec.reshape(5,1)))
+        xmat = ctx.solve(sp.csc_matrix(bmat))
+
+        assert_array_almost_equal(dtype, np.dot(a, xmat), bmat)
+        assert_array_almost_equal(dtype, np.dot(a, xvec),
+                                  bvec.reshape(5,1))
+
+    _test_lu_with_dense(np.complex128)
+
+
+@skipif(_no_mumps)
+def test_schur_complement_with_dense():
+    def _test_schur_complement_with_dense(dtype):
+        rand = _Random()
+        a = rand.randmat(10, 10, dtype)
+        s = schur_complement(sp.coo_matrix(a), range(3))
+        assert_array_almost_equal(dtype, np.linalg.inv(s),
+                                  np.linalg.inv(a)[:3, :3])
+
+    _test_schur_complement_with_dense(np.complex128)
+
+
+@skipif(_no_mumps)
+def test_error_minus_9(r=10):
+    """Test if MUMPSError -9 is properly caught by increasing memory"""
+
+    graphene = Honeycomb()
+    a, b = graphene.sublattices
+
+    def circle(pos):
+        x, y = pos
+        return x**2 + y**2 < r**2
+
+    sys = Builder()
+    sys[graphene.shape(circle, (0,0))] = -0.0001
+    hoppings = (((0, 0), b, a), ((0, 1), b, a), ((-1, 1), b, a))
+    for hopping in hoppings:
+        sys[sys.possible_hoppings(*hopping)] = - 1
+
+    ham = sys.finalized().hamiltonian_submatrix(sparse=True)[0]
+
+    # No need to check result, it's enough if no exception is raised
+    MUMPSContext().factor(ham)
diff --git a/setup.py b/setup.py
index d3b6298..f7300da 100755
--- a/setup.py
+++ b/setup.py
@@ -86,7 +86,12 @@ extensions = [ # (["kwant.graph.scotch", ["kwant/graph/scotch.pyx"]],
                               "kwant/graph/c_slicer/slicer.h"]}),
                (["kwant.linalg.lapack", ["kwant/linalg/lapack.pyx"]],
                 {"libraries" : ["lapack", "blas"],
-                 "depends" : ["kwant/linalg/f_lapack.pxd"]}) ]
+                 "depends" : ["kwant/linalg/f_lapack.pxd"]}),
+               (["kwant.linalg._mumps", ["kwant/linalg/_mumps.pyx"]],
+                {"libraries" : ["zmumps", "mumps_common", "pord",
+                                "metis", "mpiseq", "lapack", "blas",
+                                "gfortran"],
+                 "depends" : ["kwant/linalg/cmumps.pxd"]}) ]
 
 ext_modules = []
 for args, keywords in extensions:
-- 
GitLab