diff --git a/codes/solvers.py b/codes/solvers.py index f11910513aa1ca6d65e0ec78241c3c9d29d0bfdd..6c5eb980db2185a236e137b5a0bfc8e0993d6a53 100644 --- a/codes/solvers.py +++ b/codes/solvers.py @@ -35,8 +35,7 @@ def finite_system_solver(model, optimizer, cost_function, optimizer_kwargs): partial_cost = partial(cost_function, model=model) optimize(initial_mf, partial_cost, optimizer, optimizer_kwargs) -def real_space_cost(mf, model): - shape = mf.shape +def real_space_cost(mf, model, shape): mf = utils.flat_to_matrix(utils.real_to_complex(mf), shape) mf_dict = {} for i, key in enumerate(model.guess.keys()): @@ -68,8 +67,9 @@ def rspace_solver(model, optimizer, cost_function, optimizer_kwargs): """ model.kgrid_evaluation(nk=model.nk) initial_mf = np.array([*model.guess.values()]) + shape = initial_mf.shape initial_mf = utils.complex_to_real(utils.matrix_to_flat(initial_mf)) - partial_cost = partial(cost_function, model=model) + partial_cost = partial(cost_function, model=model, shape=shape) optimize(initial_mf, partial_cost, optimizer, optimizer_kwargs) def kspace_cost(mf, model):