diff --git a/kwant/lattice.py b/kwant/lattice.py
index 9b2e0995b76ab4ad8b598ef055868c5049833ce2..d117e04aa443c224f7843aab2d33f70224fb31cd 100644
--- a/kwant/lattice.py
+++ b/kwant/lattice.py
@@ -14,6 +14,7 @@ from math import sqrt
 import numpy as np
 import tinyarray as ta
 from . import builder
+from .linalg import lll
 
 
 def general(prim_vecs, basis=None, name=''):
@@ -96,67 +97,16 @@ class Polyatomic(object):
                             for offset, sname in zip(basis, name)]
         # Sequence of primitive vectors of the lattice.
         self.prim_vecs = prim_vecs
+        # Precalculation of auxiliary arrays for real space calculations.
+        self._reduced_vecs, self._transf = lll.lll(prim_vecs)
+        self._voronoi = ta.dot(lll.voronoi(self._reduced_vecs), self._transf)
 
     def shape(self, function, start):
-        """
-        Yield all the lattice sites which belong to a certain shape.
-
-        Parameters
-        ----------
-        function : a boolean function of real-space coordinates
-            A function which evaluates to True inside the desired shape.
-        start : real-valued vector
-            The starting point to the flood-fill algorithm.  If the site
-            nearest to `start` is not inside the shape, no sites are returned.
+        """Yield sites belonging to a certain shape.
 
-        Returns
-        -------
-        sites : sequence of `Site` objects
-            all the sites that belong to the lattice and fit inside the shape.
+        See `~kwant.lattice.Shape` for more information.
         """
-        Site = builder.Site
-
-        dim = len(start)
-        num_vecs = len(self.prim_vecs)
-        if dim != self.prim_vecs.shape[1]:
-            raise ValueError('Dimensionality of start position does not match'
-                             ' the space dimensionality.')
-        sls = self.sublattices
-        deltas = [ta.array(i * (0,) + (1,) + (num_vecs - 1 - i) * (0,))
-                  for i in xrange(num_vecs)]
-        deltas += [-delta for delta in deltas]
-
-        # Check if no sites are going to be added, to catch a common error.
-        empty = True
-        for sl in sls:
-            if function(sl(*sl.closest(start)).pos):
-                empty = False
-        if empty:
-            msg = 'No sites close to {0} are inside the desired shape.'
-            raise ValueError(msg.format(start))
-
-        # Continue to flood fill.
-        outer_shell = set(sl.closest(start) for sl in sls)
-        inner_shell = set()
-        while outer_shell:
-            tmp = set()
-            for tag in outer_shell:
-                vec = ta.dot(tag, self.prim_vecs)
-                any_hits = False
-                for sl in sls:
-                    if not function(vec + sl.offset):
-                        continue
-                    yield Site(sl, tag, True)
-                    any_hits = True
-                if not any_hits:
-                    continue
-                for shift in deltas:
-                    new_tag = tag + shift
-                    if new_tag not in inner_shell and \
-                       new_tag not in outer_shell:
-                        tmp.add(new_tag)
-            inner_shell = outer_shell
-            outer_shell = tmp
+        return Shape(self, function, start)
 
     def vec(self, int_vec):
         """
@@ -208,6 +158,9 @@ class Monatomic(builder.SiteFamily, Polyatomic):
         self.prim_vecs = prim_vecs
         self.inv_pv = ta.array(np.linalg.pinv(prim_vecs))
         self.offset = offset
+        # Precalculation of auxiliary arrays for real space calculations.
+        self._reduced_vecs, self._transf = lll.lll(prim_vecs)
+        self._voronoi = ta.dot(lll.voronoi(self._reduced_vecs), self._transf)
 
         def short_array_repr(array):
             full = ' '.join([i.lstrip() for i in repr(array).split('\n')])
@@ -245,9 +198,23 @@ class Monatomic(builder.SiteFamily, Polyatomic):
             raise ValueError("Dimensionality mismatch.")
         return tag
 
+    def n_closest(self, pos, n=1):
+        """Find n sites closest to position `pos`.
+
+        Returns
+        -------
+        sites : numpy array
+            An array with sites coordinates.
+        """
+        # TODO (Anton): transform to tinyarrays, once ta indexing is better.
+        return np.dot(lll.cvp(pos - self.offset, self._reduced_vecs, n),
+                      self._transf.T)
+
     def closest(self, pos):
-        """Find the site closest to position `pos`."""
-        return ta.array(ta.round(ta.dot(pos - self.offset, self.inv_pv)), int)
+        """
+        Find the lattice coordinates of the site closest to position `pos`.
+        """
+        return ta.array(self.n_closest(pos)[0])
 
     def pos(self, tag):
         """Return the real-space position of the site with a given tag."""
@@ -423,6 +390,93 @@ class TranslationalSymmetry(builder.Symmetry):
         return result
 
 
+class Shape(object):
+    def __init__(self, lattice, function, start):
+        """A class for finding all the lattice sites in a shape.
+
+        It uses a flood-fill algorithm, and takes into account
+        the symmetry of the builder to which it is provided, or
+        the symmetry, that is supplied to it after initialization.
+
+        Parameters
+        ----------
+        lattice : Polyatomic or Monoatomic lattice
+            Lattice, to which the resulting sites should belong.
+        function : callable
+            A function of real space coordinates, which should
+            return True for coordinates inside the shape,
+            and False otherwise.
+        start : float vector
+            The origin for the flood-fill algorithm.
+        """
+        self.lat, self.func, self.start = lattice, function, start
+
+    def __call__(self, builder_or_symmetry=None):
+        """
+        Yield all the lattice sites which belong to a certain shape.
+
+        Parameters
+        ----------
+        builder_or_symmetry : Builder or Symmetry instance
+            The builder to which the site from the shape are added, or
+            the symmetry, such that the sites from the shape belong to
+            its fundamental domain. If not provided, trivial symmetry is
+            used.
+
+        Returns
+        -------
+        sites : sequence of `Site` objects
+            all the sites that belong to the lattice and fit inside the shape.
+        """
+        Site = builder.Site
+        lat, func, start = self.lat, self.func, self.start
+        try:
+            symmetry = builder_or_symmetry.symmetry
+        except AttributeError:
+            symmetry = builder_or_symmetry
+        if symmetry is None:
+            symmetry = builder.NoSymmetry()
+
+        sym_site = lambda lat, tag: symmetry.to_fd(Site(lat, tag, True))
+
+        dim = len(start)
+        if dim != lat.prim_vecs.shape[1]:
+            raise ValueError('Dimensionality of start position does not match'
+                             ' the space dimensionality.')
+        sls = lat.sublattices
+        deltas = [ta.array(i) for i in lat._voronoi]
+
+        # Check if no sites are going to be added, to catch a common error.
+        empty = True
+        for sl in sls:
+            if func(sym_site(sl, sl.closest(start)).pos):
+                empty = False
+        if empty:
+            msg = 'No sites close to {0} are inside the desired shape.'
+            raise ValueError(msg.format(start))
+
+        # Continue to flood fill.
+        tags = set([sl.closest(start) for sl in sls])
+        new_sites = [sym_site(sl, tag) for sl in sls for tag in tags]
+        new_sites = [i for i in new_sites if func(i.pos)]
+        old_sites = set()
+        while new_sites:
+            tmp = set()
+            for site in new_sites:
+                yield site
+            tags = set((i.tag for i in new_sites))
+            for tag in tags:
+                for shift in deltas:
+                    for sl in sls:
+                        site = sym_site(sl, tag + shift)
+                        if site not in old_sites and \
+                           site not in new_sites and \
+                           func(site.pos):
+                            tmp.add(site)
+            old_sites = new_sites
+            new_sites = tmp
+
+
 ################ Library of lattices (to be extended)
 
 def chain(a=1, name=''):
diff --git a/kwant/linalg/tests/test_mumps.py b/kwant/linalg/tests/test_mumps.py
index 857c5cc66fe9f5b29e46bec1cb0d8b1250527f72..71508633461b1695874d638034e1e2a9c12da9e2 100644
--- a/kwant/linalg/tests/test_mumps.py
+++ b/kwant/linalg/tests/test_mumps.py
@@ -14,7 +14,6 @@ except ImportError:
 
 from kwant.lattice import honeycomb
 from kwant.builder import Builder, HoppingKind
-from nose.tools import assert_equal, assert_true
 from numpy.testing.decorators import skipif
 import numpy as np
 import scipy.sparse as sp
diff --git a/kwant/tests/test_lattice.py b/kwant/tests/test_lattice.py
index 6bb6d6e2f7c890b6607038d6105e4a21f506b5c3..3ef7b918c34e2ab518c2cbac17b709731cbb3ef7 100644
--- a/kwant/tests/test_lattice.py
+++ b/kwant/tests/test_lattice.py
@@ -15,6 +15,19 @@ from numpy.testing import assert_equal
 from kwant import lattice
 
 
+def test_closest():
+    np.random.seed(4)
+    lat = lattice.general(((1, 0), (0.5, sqrt(3)/2)))
+    for i in range(50):
+        point = 20 * np.random.rand(2)
+        closest = lat(*lat.closest(point)).pos
+        assert np.linalg.norm(point - closest) <= 1 / sqrt(3)
+    lat = lattice.general(np.random.randn(3, 3))
+    for i in range(50):
+        tag = np.random.randint(10, size=(3,))
+        assert_equal(lat.closest(lat(*tag).pos), tag)
+
+
 def test_general():
     for lat in (lattice.general(((1, 0), (0.5, 0.5))),
                 lattice.general(((1, 0), (0.5, sqrt(3)/2)),
@@ -36,7 +49,7 @@ def test_shape():
 
     lat = lattice.general(((1, 0), (0.5, sqrt(3) / 2)),
                                ((0, 0), (0, 1 / sqrt(3))))
-    sites = list(lat.shape(in_circle, (0, 0)))
+    sites = list(lat.shape(in_circle, (0, 0))())
     sites_alt = list()
     sl0, sl1 = lat.sublattices
     for x in xrange(-2, 3):
@@ -47,8 +60,15 @@ def test_shape():
                     sites_alt.append(site)
     assert len(sites) == len(sites_alt)
     assert_equal(set(sites), set(sites_alt))
-    assert_raises(ValueError, lat.shape(in_circle, (10, 10)).next)
-
+    assert_raises(ValueError, lat.shape(in_circle, (10, 10))().next)
+    # Check if narrow ribbons work.
+    for period in (0, 1), (1, 0), (1, -1):
+        vec = lat.vec(period)
+        sym = lattice.TranslationalSymmetry(vec)
+        def shape(pos):
+            return abs(pos[0] * vec[1] - pos[1] * vec[0]) < 10
+        sites = list(lat.shape(shape, (0, 0))(sym))
+        assert len(sites) > 35
 
 def test_translational_symmetry():
     ts = lattice.TranslationalSymmetry
diff --git a/kwant/tests/test_plotter.py b/kwant/tests/test_plotter.py
index 7290117cd546e624e04c4d337bbc33495c5ad5b7..62056470a8dbb40a7de1f9910d20765bb766fb8c 100644
--- a/kwant/tests/test_plotter.py
+++ b/kwant/tests/test_plotter.py
@@ -33,9 +33,7 @@ def sys_2d(W=3, r1=3, r2=8):
     lead0 = kwant.Builder(sym_lead0)
     lead2 = kwant.Builder(sym_lead0)
 
-    def lead_shape(pos):
-        (x, y) = pos
-        return (-1 < x < 1) and (-W / 2 < y < W / 2)
+    lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2)
 
     lead0[lat.shape(lead_shape, (0, 0))] = 4 * t
     lead2[lat.shape(lead_shape, (0, 0))] = 4 * t
@@ -62,9 +60,7 @@ def sys_3d(W=3, r1=2, r2=4, a=1, t=1.0):
     sym_lead0 = kwant.TranslationalSymmetry(lat.vec((-1, 0, 0)))
     lead0 = kwant.Builder(sym_lead0)
 
-    def lead_shape(pos):
-        (x, y, z) = pos
-        return (-1 < x < 1) and (-W / 2 < y < W / 2) and abs(z) < 2
+    lead_shape = lambda pos: (-W / 2 < pos[1] < W / 2) and abs(pos[2]) < 2
 
     lead0[lat.shape(lead_shape, (0, 0, 0))] = 4 * t
     lead0[lat.nearest] = - t