diff --git a/codes/tb/transforms.py b/codes/tb/transforms.py
index 30696812a5b5a58a375a45e0e86dc14ec02107c2..1982c6dd1233ce43325c45fdaa61ea7229709bc9 100644
--- a/codes/tb/transforms.py
+++ b/codes/tb/transforms.py
@@ -3,7 +3,7 @@ from scipy.fftpack import ifftn
 import itertools as it
 
 
-def tb_to_khamvector(tb, nk, ndim):
+def tb_to_khamvector(tb, nk, ndim, ks=None):
     """
     Real-space tight-binding model to hamiltonian on k-space grid.
 
@@ -22,9 +22,9 @@ def tb_to_khamvector(tb, nk, ndim):
         Hamiltonian evaluated on a k-point grid.
 
     """
-
-    ks = np.linspace(-np.pi, np.pi, nk, endpoint=False)
-    ks = np.concatenate((ks[nk // 2 :], ks[: nk // 2]), axis=0)  # shift for ifft
+    if ks is None:
+        ks = np.linspace(-np.pi, np.pi, nk, endpoint=False)
+        ks = np.concatenate((ks[nk // 2 :], ks[: nk // 2]), axis=0)  # shift for ifft
     kgrid = np.meshgrid(*([ks] * ndim), indexing="ij")
 
     num_keys = len(list(tb.keys()))
@@ -53,6 +53,10 @@ def tb_to_kfunc(tb):
     -------
     function
         A function that takes a k-space vector and returns a complex np.array.
+
+    Notes
+    -----
+    Function doesn't work for all dimensions
     """
 
     def kfunc(k):