From 4e6117c7916fffec1b0a8e97bd2e90c20a9654d8 Mon Sep 17 00:00:00 2001
From: Kostas Vilkelis <kostasvilkelis@gmail.com>
Date: Tue, 7 May 2024 13:00:49 +0200
Subject: [PATCH] make tb type hidden; rm redundant ks type; adjust import
 order

---
 pymf/kwant_helper/utils.py      | 9 +++++----
 pymf/mf.py                      | 8 ++++----
 pymf/model.py                   | 6 +++---
 pymf/observables.py             | 4 ++--
 pymf/params/param_transforms.py | 6 +++---
 pymf/params/rparams.py          | 6 +++---
 pymf/solvers.py                 | 4 ++--
 pymf/tb/tb.py                   | 6 +++---
 pymf/tb/transforms.py           | 8 +++-----
 pymf/tb/utils.py                | 6 +++---
 10 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/pymf/kwant_helper/utils.py b/pymf/kwant_helper/utils.py
index 7f4b491..650b636 100644
--- a/pymf/kwant_helper/utils.py
+++ b/pymf/kwant_helper/utils.py
@@ -2,18 +2,19 @@ import inspect
 from copy import copy
 from itertools import product
 from typing import Callable
-from pymf.tb.tb import tb_type
 
+import numpy as np
+from scipy.sparse import coo_array
 import kwant
 import kwant.lattice
 import kwant.builder
-import numpy as np
-from scipy.sparse import coo_array
+
+from pymf.tb.tb import _tb_type
 
 
 def builder_to_tb(
     builder: kwant.builder.Builder, params: dict = {}, return_data: bool = False
-) -> tb_type:
+) -> _tb_type:
     """Construct a tight-binding dictionary from a `kwant.builder.Builder` system.
 
     Parameters
diff --git a/pymf/mf.py b/pymf/mf.py
index 33b6168..28fab3e 100644
--- a/pymf/mf.py
+++ b/pymf/mf.py
@@ -2,7 +2,7 @@ import numpy as np
 from scipy.fftpack import ifftn
 from typing import Tuple
 
-from pymf.tb.tb import add_tb, tb_type
+from pymf.tb.tb import add_tb, _tb_type
 from pymf.tb.transforms import ifftn_to_tb, tb_to_khamvector
 
 
@@ -36,8 +36,8 @@ def construct_density_matrix_kgrid(
 
 
 def construct_density_matrix(
-    h: tb_type, filling: float, nk: int
-) -> Tuple[tb_type, float]:
+    h: _tb_type, filling: float, nk: int
+) -> Tuple[_tb_type, float]:
     """Compute the real-space density matrix tight-binding dictionary.
 
     Parameters
@@ -69,7 +69,7 @@ def construct_density_matrix(
         return {(): density_matrix}, fermi
 
 
-def meanfield(density_matrix: tb_type, h_int: tb_type) -> tb_type:
+def meanfield(density_matrix: _tb_type, h_int: _tb_type) -> _tb_type:
     """Compute the mean-field correction from the density matrix.
 
     Parameters
diff --git a/pymf/model.py b/pymf/model.py
index 21e6783..89ea730 100644
--- a/pymf/model.py
+++ b/pymf/model.py
@@ -4,7 +4,7 @@ from pymf.mf import (
     construct_density_matrix,
     meanfield,
 )
-from pymf.tb.tb import add_tb, tb_type
+from pymf.tb.tb import add_tb, _tb_type
 
 
 def _check_hermiticity(h):
@@ -57,7 +57,7 @@ class Model:
     the interaction is the same between all internal degrees of freedom.
     """
 
-    def __init__(self, h_0: tb_type, h_int: tb_type, filling: float) -> None:
+    def __init__(self, h_0: _tb_type, h_int: _tb_type, filling: float) -> None:
         _tb_type_check(h_0)
         self.h_0 = h_0
         _tb_type_check(h_int)
@@ -76,7 +76,7 @@ class Model:
         _check_hermiticity(h_0)
         _check_hermiticity(h_int)
 
-    def mfield(self, mf: tb_type, nk: int = 200) -> tb_type:
+    def mfield(self, mf: _tb_type, nk: int = 200) -> _tb_type:
         """Computes a new mean-field correction from a given one.
 
         Parameters
diff --git a/pymf/observables.py b/pymf/observables.py
index 9c15c34..bd47485 100644
--- a/pymf/observables.py
+++ b/pymf/observables.py
@@ -1,9 +1,9 @@
 import numpy as np
 
-from pymf.tb.tb import tb_type
+from pymf.tb.tb import _tb_type
 
 
-def expectation_value(density_matrix: tb_type, observable: tb_type) -> complex:
+def expectation_value(density_matrix: _tb_type, observable: _tb_type) -> complex:
     """Compute the expectation value of an observable with respect to a density matrix.
 
     Parameters
diff --git a/pymf/params/param_transforms.py b/pymf/params/param_transforms.py
index 998ab96..2095e34 100644
--- a/pymf/params/param_transforms.py
+++ b/pymf/params/param_transforms.py
@@ -1,9 +1,9 @@
 import numpy as np
 
-from pymf.tb.tb import tb_type
+from pymf.tb.tb import _tb_type
 
 
-def tb_to_flat(tb: tb_type) -> np.ndarray:
+def tb_to_flat(tb: _tb_type) -> np.ndarray:
     """Parametrise a hermitian tight-binding dictionary by a flat complex vector.
 
     Parameters
@@ -29,7 +29,7 @@ def flat_to_tb(
     tb_param_complex: np.ndarray,
     ndof: int,
     tb_keys: list[tuple[None] | tuple[int, ...]],
-) -> tb_type:
+) -> _tb_type:
     """Reverse operation to `tb_to_flat`.
 
     It takes a flat complex 1d array and return the tight-binding dictionary.
diff --git a/pymf/params/rparams.py b/pymf/params/rparams.py
index cbb70be..dcb0479 100644
--- a/pymf/params/rparams.py
+++ b/pymf/params/rparams.py
@@ -6,10 +6,10 @@ from pymf.params.param_transforms import (
     real_to_complex,
     tb_to_flat,
 )
-from pymf.tb.tb import tb_type
+from pymf.tb.tb import _tb_type
 
 
-def tb_to_rparams(tb: tb_type) -> np.ndarray:
+def tb_to_rparams(tb: _tb_type) -> np.ndarray:
     """Parametrise a hermitian tight-binding dictionary by a real vector.
 
     Parameters
@@ -27,7 +27,7 @@ def tb_to_rparams(tb: tb_type) -> np.ndarray:
 
 def rparams_to_tb(
     tb_params: np.ndarray, tb_keys: list[tuple[None] | tuple[int, ...]], ndof: int
-) -> tb_type:
+) -> _tb_type:
     """Extract a hermitian tight-binding dictionary from a real vector parametrisation.
 
     Parameters
diff --git a/pymf/solvers.py b/pymf/solvers.py
index 40b4fea..dc73dc5 100644
--- a/pymf/solvers.py
+++ b/pymf/solvers.py
@@ -4,7 +4,7 @@ import scipy
 from typing import Optional, Callable
 
 from pymf.params.rparams import rparams_to_tb, tb_to_rparams
-from pymf.tb.tb import add_tb, tb_type
+from pymf.tb.tb import add_tb, _tb_type
 from pymf.model import Model
 from pymf.tb.utils import calculate_fermi_energy
 
@@ -43,7 +43,7 @@ def solver(
     nk: int = 100,
     optimizer: Optional[Callable] = scipy.optimize.anderson,
     optimizer_kwargs: Optional[dict[str, str]] = {"M": 0},
-) -> tb_type:
+) -> _tb_type:
     """Solve for the mean-field correction through self-consistent root finding.
 
     Parameters
diff --git a/pymf/tb/tb.py b/pymf/tb/tb.py
index 52678b8..445a99b 100644
--- a/pymf/tb/tb.py
+++ b/pymf/tb/tb.py
@@ -1,9 +1,9 @@
 import numpy as np
 
-tb_type = dict[tuple[()] | tuple[int, ...], np.ndarray]
+_tb_type = dict[tuple[()] | tuple[int, ...], np.ndarray]
 
 
-def add_tb(tb1: tb_type, tb2: tb_type) -> tb_type:
+def add_tb(tb1: _tb_type, tb2: _tb_type) -> _tb_type:
     """Add up two tight-binding dictionaries together.
 
     Parameters
@@ -21,7 +21,7 @@ def add_tb(tb1: tb_type, tb2: tb_type) -> tb_type:
     return {k: tb1.get(k, 0) + tb2.get(k, 0) for k in frozenset(tb1) | frozenset(tb2)}
 
 
-def scale_tb(tb: tb_type, scale: float) -> tb_type:
+def scale_tb(tb: _tb_type, scale: float) -> _tb_type:
     """Scale a tight-binding dictionary by a constant.
 
     Parameters
diff --git a/pymf/tb/transforms.py b/pymf/tb/transforms.py
index 374526d..bdd8ae2 100644
--- a/pymf/tb/transforms.py
+++ b/pymf/tb/transforms.py
@@ -1,12 +1,10 @@
 import itertools
 import numpy as np
-from typing import Optional
-from pymf.tb.tb import tb_type
 
-ks_type = Optional[np.ndarray]
+from pymf.tb.tb import _tb_type
 
 
-def tb_to_khamvector(tb: tb_type, nk: int) -> np.ndarray:
+def tb_to_khamvector(tb: _tb_type, nk: int) -> np.ndarray:
     """Evaluate a tight-binding dictionary on a k-space grid.
 
     Parameters
@@ -39,7 +37,7 @@ def tb_to_khamvector(tb: tb_type, nk: int) -> np.ndarray:
     return np.sum(tb_array * k_dependency, axis=0)
 
 
-def ifftn_to_tb(ifft_array: np.ndarray) -> tb_type:
+def ifftn_to_tb(ifft_array: np.ndarray) -> _tb_type:
     """
     Convert the result of `scipy.fft.ifftn` to a tight-binding dictionary.
 
diff --git a/pymf/tb/utils.py b/pymf/tb/utils.py
index d8be775..6ed2cbf 100644
--- a/pymf/tb/utils.py
+++ b/pymf/tb/utils.py
@@ -1,14 +1,14 @@
 from itertools import product
 import numpy as np
 
-from pymf.tb.tb import tb_type
+from pymf.tb.tb import _tb_type
 from pymf.mf import fermi_on_grid
 from pymf.tb.transforms import tb_to_khamvector
 
 
 def generate_guess(
     tb_keys: list[tuple[None] | tuple[int, ...]], ndof: int, scale: float = 1
-) -> tb_type:
+) -> _tb_type:
     """Generate guess tight-binding dictionary.
 
     Parameters
@@ -59,7 +59,7 @@ def generate_tb_keys(cutoff: int, dim: int) -> list[tuple[None] | tuple[int, ...
     return [*product(*([[*range(-cutoff, cutoff + 1)]] * dim))]
 
 
-def calculate_fermi_energy(tb: tb_type, filling: float, nk: int = 100):
+def calculate_fermi_energy(tb: _tb_type, filling: float, nk: int = 100):
     """
     Calculate the Fermi energy of a given tight-binding dictionary.
 
-- 
GitLab