diff --git a/codes/tb/transforms.py b/codes/tb/transforms.py
index 683f503fd6f207ded11332307943aa9cdc328737..30696812a5b5a58a375a45e0e86dc14ec02107c2 100644
--- a/codes/tb/transforms.py
+++ b/codes/tb/transforms.py
@@ -25,7 +25,7 @@ def tb_to_khamvector(tb, nk, ndim):
 
     ks = np.linspace(-np.pi, np.pi, nk, endpoint=False)
     ks = np.concatenate((ks[nk // 2 :], ks[: nk // 2]), axis=0)  # shift for ifft
-    kgrid = np.meshgrid(ks, ks, indexing="ij")
+    kgrid = np.meshgrid(*([ks] * ndim), indexing="ij")
 
     num_keys = len(list(tb.keys()))
     tb_array = np.array(list(tb.values()))