diff --git a/kwant/lattice.py b/kwant/lattice.py
index 8bd91dea6ddf0d61ae7c9d36a88227d95bf6c0d9..7183fb5e0ab31a9c303fa0d6205bf49f922d333f 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -298,7 +298,7 @@ class Polyatomic:
         Site = builder.Site
         sls = self.sublattices
         shortest_hopping = sls[0].n_closest(
-            Site(sls[0], ([0] * sls[0].dim)).pos, 2)[-1]
+            sls[0].pos(([0] * sls[0].lattice_dim)), 2)[-1]
         eps *= np.linalg.norm(self.vec(shortest_hopping))
         nvec = len(self._prim_vecs)
         sublat_pairs = [(i, j) for (i, j) in product(sls, sls)
diff --git a/kwant/linalg/tests/test_lll.py b/kwant/linalg/tests/test_lll.py
index c4d96f3091b73351faf7ecb810c28dd13a625b27..2d2616615813faab5b3284095a24dab830cf6c72 100644
--- a/kwant/linalg/tests/test_lll.py
+++ b/kwant/linalg/tests/test_lll.py
@@ -24,10 +24,11 @@ def test_lll():
 
 def test_cvp():
     rng = ensure_rng(0)
-    for i in range(10):
-        mat = rng.randn(4, 4)
-        mat = lll.lll(mat)[0]
-        for j in range(4):
-            point = 50 * rng.randn(4)
-            assert np.array_equal(lll.cvp(point, mat, 10)[:3],
-                                  lll.cvp(point, mat, 3))
+    for i in range(1, 5):
+        for j in range(i, 5):
+            mat = rng.randn(i, j)
+            mat = lll.lll(mat)[0]
+            for k in range(4):
+                point = 50 * rng.randn(j)
+                assert np.array_equal(lll.cvp(point, mat, 10)[:3],
+                                      lll.cvp(point, mat, 3))
diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py
index 957cf2b63f1a606a4a3c5a7050a54f0f86e72d0e..c3ea78e7b17162d25a1b5314de5a51f1c507aa85 100644
--- a/kwant/tests/test_lattice.py
+++ b/kwant/tests/test_lattice.py
@@ -46,7 +46,7 @@ def test_neighbors():
     lat = lattice.honeycomb(1e-10)
     num_nth_nearest = [len(lat.neighbors(n)) for n in range(5)]
     assert num_nth_nearest == [2, 3, 6, 3, 6]
-    lat = lattice.square(1e8)
+    lat = lattice.general([(0, 1e8, 0, 0), (0, 0, 1e8, 0)])
     num_nth_nearest = [len(lat.neighbors(n)) for n in range(5)]
     assert num_nth_nearest == [1, 2, 2, 2, 4]
     lat = lattice.chain(1e-10)