Skip to content
Snippets Groups Projects
Commit 606bacb3 authored by Anton Akhmerov's avatar Anton Akhmerov Committed by Christoph Groth
Browse files

solvers: avoid performing factorization when not needed

parent 0b702110
Branches
Tags
No related merge requests found
......@@ -363,6 +363,16 @@ class SparseSolver(object):
check_hermiticity,
args, kwargs)
# Do not perform factorization if no calculation is to be done.
len_rhs = sum(i.shape[1] for i in linsys.rhs)
if isinstance(linsys.kept_vars, slice):
len_kv = len(xrange(*slice.indices(linsys.lhs.shape[0])))
else:
len_kv = len(linsys.kept_vars)
if not(len_rhs and len_kv):
return BlockResult(np.zeros((len_kv, len_rhs)),
lead_info, out_leads, in_leads)
flhs = self._factorized(linsys.lhs)
data = self._solve_linear_sys(flhs, linsys.rhs, linsys.kept_vars)
......@@ -410,16 +420,18 @@ class SparseSolver(object):
ldos = np.zeros(num_orb, float)
factored = None
factored = self._factorized(h)
for mat in rhs:
if mat.shape[1] == 0:
continue
# Do not perform factorization if no further calculation is needed.
if sum(i.shape[1] for i in rhs):
factored = self._factorized(h)
for mat in rhs:
if mat.shape[1] == 0:
continue
for j in xrange(0, mat.shape[1], self.nrhs):
jend = min(j + self.nrhs, mat.shape[1])
psi = self._solve_linear_sys(factored, [mat[:, j:jend]],
slice(num_orb))
ldos += np.sum(np.square(abs(psi)), axis=1)
for j in xrange(0, mat.shape[1], self.nrhs):
jend = min(j + self.nrhs, mat.shape[1])
psi = self._solve_linear_sys(factored, [mat[:, j:jend]],
slice(num_orb))
ldos += np.sum(np.square(abs(psi)), axis=1)
return ldos * (0.5 / np.pi)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment