From 2596791fcbfb4ff36148c66886930b1e599ce17d Mon Sep 17 00:00:00 2001
From: Kostas Vilkelis <kostasvilkelis@gmail.com>
Date: Tue, 7 May 2024 00:16:03 +0200
Subject: [PATCH] use ndof instead of shape

---
 pymf/params/param_transforms.py | 8 ++++----
 pymf/params/rparams.py          | 2 +-
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pymf/params/param_transforms.py b/pymf/params/param_transforms.py
index 15f28dc..998ab96 100644
--- a/pymf/params/param_transforms.py
+++ b/pymf/params/param_transforms.py
@@ -27,7 +27,7 @@ def tb_to_flat(tb: tb_type) -> np.ndarray:
 
 def flat_to_tb(
     tb_param_complex: np.ndarray,
-    shape: tuple[int, int],
+    ndof: int,
     tb_keys: list[tuple[None] | tuple[int, ...]],
 ) -> tb_type:
     """Reverse operation to `tb_to_flat`.
@@ -38,9 +38,8 @@ def flat_to_tb(
     ----------
     tb_param_complex :
         1d complex array that parametrises the tb model.
-    shape :
-        Tuple (n, n) where n is the number of internal degrees of freedom
-        (e.g. orbitals, spin, sublattice) within the tight-binding model.
+    ndof :
+        Number internal degrees of freedom within the unit cell.
     tb_keys :
         List of keys of the tight-binding dictionary.
 
@@ -49,6 +48,7 @@ def flat_to_tb(
     tb :
         tight-binding dictionary
     """
+    shape = (len(tb_keys), ndof, ndof)
     if len(tb_keys[0]) == 0:
         matrix = np.zeros((shape[-1], shape[-2]), dtype=complex)
         matrix[np.triu_indices(shape[-1])] = tb_param_complex
diff --git a/pymf/params/rparams.py b/pymf/params/rparams.py
index 1ca098c..cbb70be 100644
--- a/pymf/params/rparams.py
+++ b/pymf/params/rparams.py
@@ -45,4 +45,4 @@ def rparams_to_tb(
         Tight-biding dictionary.
     """
     flat_matrix = real_to_complex(tb_params)
-    return flat_to_tb(flat_matrix, (len(tb_keys), ndof, ndof), tb_keys)
+    return flat_to_tb(flat_matrix, ndof, tb_keys)
-- 
GitLab