From 3641258227502b413e912b6012dd938f4b98256c Mon Sep 17 00:00:00 2001
From: pacome <pacome.armagnat@gmail.com>
Date: Mon, 17 Sep 2018 19:46:23 +0200
Subject: [PATCH] Add center argument to Rectangle definition

---
 poisson/continuous/shapes.py | 98 ++++++++++++++++++++++++++++--------
 1 file changed, 77 insertions(+), 21 deletions(-)

diff --git a/poisson/continuous/shapes.py b/poisson/continuous/shapes.py
index 1a322fa..99988fe 100644
--- a/poisson/continuous/shapes.py
+++ b/poisson/continuous/shapes.py
@@ -461,13 +461,12 @@ class General(Shape):
     def geometry(self, x):
         return self.func(x)
 
-
 class Rectangle(General):
     '''
         Define a rectangle
     '''
 
-    def __init__(self, length, corner=None):
+    def __init__(self, length, corner=None, center=None):
         '''
         Class that defines a rectangular shape.
         When called and a numpy array is given as parameter it returns a boolean
@@ -482,6 +481,11 @@ class Rectangle(General):
             the lengths of the sides
             if Ly (Lz) is None, Ly (Lz) will be equal to Lx
             (usefull to quickly draw cubes)
+        center = (x0, y0, z0) : numbers
+            overwrites the corner argument
+            if center=None the corner will be used
+            the lower left corner of the (if z0 si None will be 2d)
+
         Returns:
         --------
         func: function
@@ -491,9 +495,9 @@ class Rectangle(General):
         super().__init__()
         self.length=length
         self.corner=corner
+        self.center=center
         self.prepare_param()
 
-
     def prepare_param(self):
 
         if isinstance(self.length, (int, float)):
@@ -508,6 +512,9 @@ class Rectangle(General):
 
         self.corner = np.asanyarray(self.corner)
 
+        if self.center is not None:
+            self.center = np.asarray(self.center)
+
         assert len(self.length) == len(self.corner), (
                 'The corner does not have the same'
                 + ' size as length / or a corner has'
@@ -525,32 +532,61 @@ class Rectangle(General):
         if len(x.shape) > 1:
 
             if len(self.length) == 3:
-                return ((self.corner[0] <= x[:, 0])
-                        * (x[:, 0] <= self.corner[0] + self.length[0])
-                        * (self.corner[1] <= x[:, 1])
-                        * (x[:, 1] <= self.corner[1] + self.length[1])
-                        * (self.corner[2] <= x[:, 2])
-                        * (x[:, 2] <= self.corner[2] + self.length[2]))
+                if self.center is None:
+                    return ((self.corner[0] <= x[:, 0])
+                            * (x[:, 0] <= self.corner[0] + self.length[0])
+                            * (self.corner[1] <= x[:, 1])
+                            * (x[:, 1] <= self.corner[1] + self.length[1])
+                            * (self.corner[2] <= x[:, 2])
+                            * (x[:, 2] <= self.corner[2] + self.length[2]))
+                else:
+                    return ((self.center[0] - self.length[0]/2 <= x[:, 0])
+                            * (x[:, 0] <= self.center[0] + self.length[0]/2)
+                            * (self.center[1] - self.length[1]/2 <= x[:, 1])
+                            * (x[:, 1] <= self.center[1] + self.length[1]/2)
+                            * (self.center[2] - self.length[2]/2 <= x[:, 2])
+                            * (x[:, 2] <= self.center[2] + self.length[2]/2))
+
             else:
-                return ((self.corner[0] <= x[:, 0])
+                if self.center is None:
+                    return ((self.corner[0] <= x[:, 0])
                         * (x[:, 0] <= self.corner[0] + self.length[0])
                         * (self.corner[1] <= x[:, 1])
                         * (x[:, 1] <= self.corner[1] + self.length[1]))
+                else:
+                    return ((self.center[0] - self.length[0]/2 <= x[:, 0])
+                            * (x[:, 0] <= self.center[0] + self.length[0]/2)
+                            * (self.center[1] - self.length[1]/2 <= x[:, 1])
+                            * (x[:, 1] <= self.center[1] + self.length[1]/2))
 
         else:
             if len(self.length) == 3:
-                return ((self.corner[0] <= x[0]
-                         <= self.corner[0] + self.length[0])
-                        and (self.corner[1] <= x[1]
-                             <= self.corner[1] + self.length[1])
-                        and (self.corner[2] <= x[2]
-                             <= self.corner[2] + self.length[2]))
-            else:
-                return ((self.corner[0] <= x[0]
-                         <= self.corner[0] + self.length[0])
-                        and (self.corner[1] <= x[1]
-                             <= self.corner[1] + self.length[1]))
+                if self.center is None:
+                    return ((self.corner[0] <= x[0]
+                             <= self.corner[0] + self.length[0])
+                            and (self.corner[1] <= x[1]
+                                 <= self.corner[1] + self.length[1])
+                            and (self.corner[2] <= x[2]
+                                 <= self.corner[2] + self.length[2]))
+                else:
+                    return ((self.center[0] - self.length[0]/2 <= x[0]
+                          <= self.center[0] + self.length[0]/2)
+                         and (self.center[1] - self.length[0]/2 <= x[1]
+                              <= self.center[1] + self.length[1]/2)
+                         and (self.center[2] - self.length[0]/2 <= x[2]
+                              <= self.center[2] + self.length[2]/2))
 
+            else:
+                if self.center is None:
+                    return ((self.corner[0] <= x[0]
+                             <= self.corner[0] + self.length[0])
+                            and (self.corner[1] <= x[1]
+                                 <= self.corner[1] + self.length[1]))
+                else:
+                    return ((self.center[0] - self.length[0]/2 <= x[0]
+                          <= self.center[0] + self.length[0]/2)
+                         and (self.center[1] - self.length[0]/2 <= x[1]
+                              <= self.center[1] + self.length[1]/2))
 
 class Ellipsoid(General):
     '''
@@ -994,6 +1030,26 @@ def __test_plot_rectangle():
     plt.scatter(x, y, c=mesh.point_label)
     plt.show()
 
+def __test_plot_rectangle_center():
+    rect1 = Rectangle(length=4, center=[2, 2])
+    rect2 = Rectangle(length=2, center=[2, 2])
+    rect3 = Rectangle(length=1, center=[2, 2])
+
+    bbox1 = [0, 4, 0, 4]
+    bbox2 = [1, 3, 1, 3]
+
+    mesh1 = (bbox1, 0.15, rect1, 0)
+    mesh2 = (bbox2, 0.1, rect2, 1)
+    hole1 = (rect3)
+    points = ([[2, 2]], 2)
+
+    mesh = GridBuilder(meshs=[mesh1, mesh2], holes=[hole1],
+                                  points=[points])
+
+    x, y = list(zip(*mesh.points))
+    plt.scatter(x, y, c=mesh.point_label)
+    plt.show()
+
 
 def __test_plot_ellipse_circ():
 
-- 
GitLab