diff --git a/tests/test_scans.py b/tests/test_scans.py
index 9eadb2a97fba2a09e59aac9f5e26c045c2f9d1d2..7ad5ff11e54eef0e76f882b4a8e5442af29b2c80 100644
--- a/tests/test_scans.py
+++ b/tests/test_scans.py
@@ -3,6 +3,7 @@ import os
 import pytest
 import numpy as np
 import PIL.Image
+import cv2
 from tempfile import NamedTemporaryFile
 from flask import Flask
 from io import BytesIO
@@ -280,3 +281,25 @@ def test_image_extraction(datadir, filename):
         assert img is not None
         assert np.average(np.array(img)) == 255
     assert page == 2
+
+
+@pytest.mark.parametrize('file_name', ["a4-rotated.png", "a4-3-markers.png", "a4-rotated-3-markers.png"])
+def test_realign_image(datadir, file_name):
+    dir_name = "cornermarkers"
+    epsilon = 2
+
+    test_file = os.path.join(datadir, dir_name, file_name)
+    test_image = np.array(PIL.Image.open(test_file))
+
+    correct_file = os.path.join(datadir, dir_name, "a4.png")
+    correct_image = cv2.imread(correct_file)
+
+    correct_corner_markers = scans.find_corner_marker_keypoints(correct_image)
+
+    result_image, result_corner_markers = scans.realign_image(test_image)
+
+    assert result_corner_markers is not None
+    for i in range(4):
+        diff = np.absolute(np.subtract(correct_corner_markers[i], result_corner_markers[i]))
+        assert diff[0] < epsilon
+        assert diff[1] < epsilon
diff --git a/zesje/images.py b/zesje/images.py
index ff8669e5b04aa77aa882f0bdb3acbaf9901ab450..72ff551973e04ae698109c1e199741361565163e 100644
--- a/zesje/images.py
+++ b/zesje/images.py
@@ -71,28 +71,32 @@ def fix_corner_markers(corner_keypoints, shape):
     bottom_right = [(x, y) for x, y in corner_keypoints if x > x_sep and y > y_sep]
 
     missing_point = ()
-
+    # index = 0
     if not top_left:
         # Top left point is missing
         (dx, dy) = tuple(map(sub, top_right[0], bottom_right[0]))
         missing_point = tuple(map(add, bottom_left[0], (dx, dy)))
+        index = 0
 
     elif not bottom_left:
         # Bottom left point is missing
         (dx, dy) = tuple(map(sub, top_right[0], bottom_right[0]))
         missing_point = tuple(map(sub, top_left[0], (dx, dy)))
+        index = 2
 
     elif not top_right:
         # Top right point is missing
         (dx, dy) = tuple(map(sub, top_left[0], bottom_left[0]))
         missing_point = tuple(map(add, bottom_right[0], (dx, dy)))
+        index = 1
 
     elif not bottom_right:
         # bottom right
         (dx, dy) = tuple(map(sub, top_left[0], bottom_left[0]))
         missing_point = tuple(map(sub, top_right[0], (dx, dy)))
+        index = 3
 
-    corner_keypoints.append(missing_point)
+    corner_keypoints.insert(index, missing_point)
     return corner_keypoints
 
 
diff --git a/zesje/scans.py b/zesje/scans.py
index d2b7197beec95da44c4a36e190a535dceccee033..589238b466ae3979569c1f7e8d3785ff267ebaa8 100644
--- a/zesje/scans.py
+++ b/zesje/scans.py
@@ -792,6 +792,9 @@ def check_corner_keypoints(image_array, keypoints):
 
 def realign_image(image_array, keypoints=None,
                   reference_keypoints=None):
+    """
+    This function realigns an images based on the template image
+    """
 
     if(keypoints is None):
         keypoints = find_corner_marker_keypoints(image_array)
@@ -813,8 +816,11 @@ def realign_image(image_array, keypoints=None,
     # get the transformation matrix
     M = cv2.getPerspectiveTransform(keypoints_32, reference_keypoints_32)
     # apply the transformation matrix
-    return_image = cv2.warpPerspective(image_array, M, (cols, rows))
+    return_image = cv2.warpPerspective(image_array, M, (cols, rows),
+                                       borderValue=(255, 255, 255, 255))
 
     return_keypoints = find_corner_marker_keypoints(return_image)
+    if(len(return_keypoints) != 4):
+        return_keypoints = fix_corner_markers(return_keypoints, return_image.shape)
 
     return return_image, return_keypoints