From 836a19ca2465821056f083bebaec0cb4eac12dd0 Mon Sep 17 00:00:00 2001
From: Antonio Manesco <am@antoniomanesco.org>
Date: Fri, 29 Dec 2023 16:59:38 +0100
Subject: [PATCH] fix real-space solver

---
 codes/solvers.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/codes/solvers.py b/codes/solvers.py
index f119105..6c5eb98 100644
--- a/codes/solvers.py
+++ b/codes/solvers.py
@@ -35,8 +35,7 @@ def finite_system_solver(model, optimizer, cost_function, optimizer_kwargs):
     partial_cost = partial(cost_function, model=model)
     optimize(initial_mf, partial_cost, optimizer, optimizer_kwargs)
 
-def real_space_cost(mf, model):
-    shape = mf.shape
+def real_space_cost(mf, model, shape):
     mf = utils.flat_to_matrix(utils.real_to_complex(mf), shape)
     mf_dict = {}
     for i, key in enumerate(model.guess.keys()):
@@ -68,8 +67,9 @@ def rspace_solver(model, optimizer, cost_function, optimizer_kwargs):
     """
     model.kgrid_evaluation(nk=model.nk)
     initial_mf = np.array([*model.guess.values()])
+    shape = initial_mf.shape
     initial_mf = utils.complex_to_real(utils.matrix_to_flat(initial_mf))
-    partial_cost = partial(cost_function, model=model)
+    partial_cost = partial(cost_function, model=model, shape=shape)
     optimize(initial_mf, partial_cost, optimizer, optimizer_kwargs)
 
 def kspace_cost(mf, model):
-- 
GitLab