diff --git a/pymf/params/param_transforms.py b/pymf/params/param_transforms.py index 15f28dc807320a7d9cac68ddfe26163b8ab325db..998ab96aa6aa708d0323962415a7d0ba52843daf 100644 --- a/pymf/params/param_transforms.py +++ b/pymf/params/param_transforms.py @@ -27,7 +27,7 @@ def tb_to_flat(tb: tb_type) -> np.ndarray: def flat_to_tb( tb_param_complex: np.ndarray, - shape: tuple[int, int], + ndof: int, tb_keys: list[tuple[None] | tuple[int, ...]], ) -> tb_type: """Reverse operation to `tb_to_flat`. @@ -38,9 +38,8 @@ def flat_to_tb( ---------- tb_param_complex : 1d complex array that parametrises the tb model. - shape : - Tuple (n, n) where n is the number of internal degrees of freedom - (e.g. orbitals, spin, sublattice) within the tight-binding model. + ndof : + Number internal degrees of freedom within the unit cell. tb_keys : List of keys of the tight-binding dictionary. @@ -49,6 +48,7 @@ def flat_to_tb( tb : tight-binding dictionary """ + shape = (len(tb_keys), ndof, ndof) if len(tb_keys[0]) == 0: matrix = np.zeros((shape[-1], shape[-2]), dtype=complex) matrix[np.triu_indices(shape[-1])] = tb_param_complex diff --git a/pymf/params/rparams.py b/pymf/params/rparams.py index 1ca098cc301735bf919bf57533cbaee5170e7c85..cbb70be7cdd0433e9ef79d29e01349642618a612 100644 --- a/pymf/params/rparams.py +++ b/pymf/params/rparams.py @@ -45,4 +45,4 @@ def rparams_to_tb( Tight-biding dictionary. """ flat_matrix = real_to_complex(tb_params) - return flat_to_tb(flat_matrix, (len(tb_keys), ndof, ndof), tb_keys) + return flat_to_tb(flat_matrix, ndof, tb_keys)