diff --git a/kwant/solvers/common.py b/kwant/solvers/common.py index 1a4521cbd9c4f8e16bb05edddb62bf17a86a1b64..a1798636783a58ba4093449ecbacca0ba83d3163 100644 --- a/kwant/solvers/common.py +++ b/kwant/solvers/common.py @@ -187,7 +187,8 @@ class SparseSolver(object): u, ulinv, nprop, svd = modes if leadnum in out_leads: - kept_vars.extend(range(lhs.shape[0], lhs.shape[0] + nprop)) + kept_vars.append( + np.arange(lhs.shape[0], lhs.shape[0] + nprop)) u_out, ulinv_out = u[:, nprop:], ulinv[:, nprop:] u_in, ulinv_in = u[:, : nprop], ulinv[:, : nprop] @@ -240,7 +241,7 @@ class SparseSolver(object): lhs.shape) lhs = lhs + sig_sparse # __iadd__ is not implemented in v0.7 if leadnum in out_leads: - kept_vars.extend(list(indices)) + kept_vars.append(indices) if leadnum in in_leads: # defer formation of true rhs until the proper system # size is known @@ -267,6 +268,8 @@ class SparseSolver(object): rhs[i] = sp.bmat(bmat, format=self.rhsformat) + kept_vars = \ + np.concatenate(kept_vars) if kept_vars else np.empty(0, int) return LinearSys(lhs, rhs, kept_vars), lead_info def solve(self, sys, energy=0, out_leads=None, in_leads=None, diff --git a/kwant/solvers/mumps.py b/kwant/solvers/mumps.py index 0e57c891168309f173581962ba3ce7f1183da876..63b33bfeb99136eefe18ae5ad028b4108e2ca474 100644 --- a/kwant/solvers/mumps.py +++ b/kwant/solvers/mumps.py @@ -130,8 +130,11 @@ class Solver(common.SparseSolver): if len(sols): return np.concatenate(sols, axis=1) else: - return np.zeros(shape=(len(kept_vars), 0)) - + if isinstance(kept_vars, slice): + num_vars = len(xrange(*kept_vars.indices(a_shape[1]))) + else: + num_vars = len(kept_vars) + return np.zeros(shape=(num_vars, 0)) default_solver = Solver() diff --git a/kwant/solvers/sparse.py b/kwant/solvers/sparse.py index 66aab5a47231f0d0e721e6374d1855d351d70c7d..927e72cd5357ef465b70063c69ab654e3cada596 100644 --- a/kwant/solvers/sparse.py +++ b/kwant/solvers/sparse.py @@ -120,7 +120,11 @@ class Solver(common.SparseSolver): if len(sols): return np.asarray(sols).transpose() else: - return np.asarray(np.zeros(shape=(len(kept_vars), 0))) + if isinstance(kept_vars, slice): + num_vars = len(xrange(*kept_vars.indices(a_shape[1]))) + else: + num_vars = len(kept_vars) + return np.zeros(shape=(num_vars, 0)) default_solver = Solver() diff --git a/kwant/solvers/tests/_test_sparse.py b/kwant/solvers/tests/_test_sparse.py index 79e8666155d7a9d4a8a983106cce4e792f8ee9ce..1fe2f73bc61e6c90a548a4b7abdea1afc6739027 100644 --- a/kwant/solvers/tests/_test_sparse.py +++ b/kwant/solvers/tests/_test_sparse.py @@ -375,7 +375,6 @@ def test_ldos(ldos): def test_wavefunc_ldos_consistency(wave_func, ldos): L = 2 W = 3 - energy = 0 np.random.seed(31) sys = kwant.Builder() @@ -396,12 +395,13 @@ def test_wavefunc_ldos_consistency(wave_func, ldos): sys.attach_lead(top_lead) sys = sys.finalized() - wf = wave_func(sys, energy) - ldos2 = np.zeros(wf.num_orb, float) - for lead in xrange(len(sys.leads)): - temp = abs(wf(lead)) - temp **= 2 - ldos2 += temp.sum(axis=0) - ldos2 *= (0.5 / np.pi) + for energy in [0, 1000]: + wf = wave_func(sys, energy) + ldos2 = np.zeros(wf.num_orb, float) + for lead in xrange(len(sys.leads)): + temp = abs(wf(lead)) + temp **= 2 + ldos2 += temp.sum(axis=0) + ldos2 *= (0.5 / np.pi) - assert_almost_equal(ldos2, ldos(sys, energy)) + assert_almost_equal(ldos2, ldos(sys, energy))