Skip to content
Snippets Groups Projects
Commit d8aae75a authored by vlukes's avatar vlukes
Browse files

remove unsed assignments, fix docstrings and format

parent f4e95f8c
No related branches found
No related tags found
1 merge request!17Complete solution to Schur complement
......@@ -508,8 +508,16 @@ class Context:
else:
return self._solve_dense(b, overwrite_b)
def get_schur(self, indices, a=None, ordering='auto', ooc=False,
pivot_tol=0.01, overwrite_a=False, discard_factors=False):
def get_schur(
self,
indices,
a=None,
ordering="auto",
ooc=False,
pivot_tol=0.01,
overwrite_a=False,
discard_factors=False,
):
"""Compute the Schur complement block of matrix a using MUMPS.
Parameters:
......@@ -551,8 +559,9 @@ class Context:
if indices.ndim != 1:
raise ValueError("Schur indices must be specified in a 1d array!")
self.schur_indicies = indices = _makemumps_index_array(indices)
schur_compl = np.empty((indices.size, indices.size), order="C",
dtype=self.data.dtype)
schur_compl = np.empty(
(indices.size, indices.size), order="C", dtype=self.data.dtype
)
self.mumps_instance.icntl[19] = 1
self.mumps_instance.set_schur(schur_compl, indices)
......@@ -611,8 +620,7 @@ class Context:
"""
if self.schur_complement is None:
raise RuntimeError(
"Factorization must be done by calling 'get_schur()' "
"before solving!"
"Factorization must be done by calling 'get_schur()' before solving!"
)
if b.shape[0] != self.n:
......@@ -636,8 +644,7 @@ class Context:
self.mumps_instance.icntl[20] = 1
else:
dt, b = prepare_for_fortran(overwrite_b, b,
np.zeros(1, dtype=dtype))[:2]
dt, b = prepare_for_fortran(overwrite_b, b, np.zeros(1, dtype=dtype))[:2]
self.mumps_instance.set_dense_rhs(b)
x = b
......@@ -650,14 +657,14 @@ class Context:
self.mumps_instance.icntl[26] = 1 # Reduction/condensation phase
self.mumps_instance.job = 3
t = self.call()
self.call()
x2 = la.solve(self.schur_complement, schur_rhs) # solve dense system
schur_rhs[:] = x2
self.mumps_instance.icntl[26] = 2 # Expansion phase
self.mumps_instance.job = 3
t = self.call()
self.call()
return x
......@@ -709,9 +716,15 @@ def schur_complement(
if ``calc_stats==True``.
"""
with Context() as ctx:
schur_compl = ctx.get_schur(indices, a, ordering=ordering, ooc=ooc,
overwrite_a=overwrite_a,
pivot_tol=pivot_tol, discard_factors=True)
schur_compl = ctx.get_schur(
indices,
a,
ordering=ordering,
ooc=ooc,
overwrite_a=overwrite_a,
pivot_tol=pivot_tol,
discard_factors=True,
)
if calc_stats:
return [schur_compl, ctx.analysis_stats, ctx.factor_stats]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment