Commit accb8c15 authored by Joseph Weston's avatar Joseph Weston
Browse files

implement comparison between arrays with the same shape

If the arrays do not have the same shape then we return NotImplemented.
parent fc33353c
......@@ -11,6 +11,7 @@
#include <cstddef>
#include <sstream>
#include <limits>
#include <assert.h>
#include "array.hh"
#include "arithmetic.hh"
#include "functions.hh"
......@@ -910,48 +911,128 @@ Hash hash(PyObject *obj)
}
template <typename T>
bool is_equal_data(PyObject *a_, PyObject *b_, size_t size)
bool compare_scalar(const int op, const T a, const T b) {
switch(op){
case Py_EQ: return a == b;
case Py_NE: return a != b;
case Py_LE: return a <= b;
case Py_GE: return a >= b;
case Py_LT: return a < b;
case Py_GT: return a > b;
default:
assert(false); // if we get here something is very wrong
return false; // stop the compiler complaining
}
}
template <>
bool compare_scalar<Complex>(const int op, const Complex a, const Complex b) {
switch(op){
case Py_EQ: return a == b;
case Py_NE: return a != b;
// this function is never called in a context where
// the following code path is run -- fall through
case Py_LE:
case Py_GT:
case Py_LT:
case Py_GE:
default:
assert(false);
return false; // stop the compiler complaining
}
}
template <typename T>
bool compare_data(int op, PyObject *a_, PyObject *b_, size_t size)
{
assert(Array<T>::check_exact(a_)); Array<T> *a = (Array<T>*)a_;
assert(Array<T>::check_exact(b_)); Array<T> *b = (Array<T>*)b_;
T *data_a = a->data();
T *data_b = b->data();
for (size_t i = 0; i < size; ++i)
if (data_a[i] != data_b[i]) return false;
return true;
const T *data_a = a->data();
const T *data_b = b->data();
// sequences are ordered the same as their first differing elements, see:
// https://docs.python.org/2/reference/expressions.html#not-in
// comparison for "multidimensional" sequences is identical to comparing
// the flattened sequences when they have the same shape (the present case)
size_t i = 0;
for (; i < size; ++i)
if (data_a[i] != data_b[i]) break;
// any of these operations should return true when objects are equal
if (i == size) return ((op == Py_EQ) || (op == Py_LE) || (op == Py_GE));
// encapsulate this into a function to handle the COMPLEX case
return compare_scalar<T>(op, data_a[i], data_b[i]);
}
bool (*is_equal_data_dtable[])(PyObject*, PyObject*, size_t) =
DTYPE_DISPATCH(is_equal_data);
// don't generate dispatch table for COMPLEX datatype as it will never be used
// in `rich_compare` (COMPLEX is unorderable), and the compiler complains about
// generating `compare_scalar<complex>` because some operations are undefined
bool (*compare_data_dtable[])(int, PyObject*, PyObject*, size_t) =
DTYPE_DISPATCH(compare_data);
PyObject *richcompare(PyObject *a, PyObject *b, int op)
{
if (op != Py_EQ && op != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
PyObject *result;
const bool equality_comparison = (op == Py_EQ || op == Py_NE);
// short circuit when we are comparing the same object
bool equal = (a == b);
if (equal) goto done;
if (equal) {
// any of these operations should return true when objects are equal
equal = (op == Py_EQ) || (op == Py_GE) || (op == Py_LE);
result = equal ? Py_True : Py_False;
goto done;
}
Dtype dtype;
if (coerce_to_arrays(&a, &b, &dtype) < 0) return 0;
// obviate the need for `compare_scalar<Complex` to
// handle the case of an undefined comparison
if (dtype == COMPLEX && !equality_comparison) {
result = Py_NotImplemented;
goto decref_then_done;
}
int ndim_a, ndim_b;
size_t *shape_a, *shape_b;
reinterpret_cast<Array_base*>(a)->ndim_shape(&ndim_a, &shape_a);
reinterpret_cast<Array_base*>(b)->ndim_shape(&ndim_b, &shape_b);
if (ndim_a != ndim_b) goto decref_then_done;
for (int d = 0; d < ndim_a; ++d)
if (shape_a[d] != shape_b[d]) goto decref_then_done;
equal = is_equal_data_dtable[int(dtype)](a, b, calc_size(ndim_a, shape_a));
// TODO: enable array comparisons between arrays of differing
// dimensions
if (ndim_a != ndim_b) {
if (equality_comparison) {
goto equality_then_done;
} else {
result = Py_NotImplemented;
goto decref_then_done;
}
}
for (int d = 0; d < ndim_a; ++d) {
if (shape_a[d] != shape_b[d]) {
if (equality_comparison) {
goto equality_then_done;
} else {
result = Py_NotImplemented;
goto decref_then_done;
}
}
}
// actually compare the data
equal = compare_data_dtable[int(dtype)](op, a, b, calc_size(ndim_a, shape_a));
result = equal ? Py_True : Py_False;
goto decref_then_done;
// non error-path exit points from this function
equality_then_done:
result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
decref_then_done:
Py_DECREF(a);
Py_DECREF(b);
done:
PyObject *result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
Py_INCREF(result);
return result;
}
......
......@@ -9,6 +9,7 @@
import operator, warnings
import platform
import itertools as it
import tinyarray as ta
from nose.tools import assert_raises
import numpy as np
......@@ -412,6 +413,37 @@ def test_sizeof():
assert_equal(sizeof, sizeof_should_be)
def test_comparison():
ops = operator
for op in [ops.ge, ops.gt, ops.le, ops.lt, ops.eq, ops.ne]:
for dtype in (int, float, complex):
for left, right in it.product((np.zeros, np.ones), repeat=2):
for shape in [(), (1,), (2,), (2, 2), (2, 2, 2), (2, 3)]:
a = left(shape, dtype)
b = right(shape, dtype)
if dtype is complex and op not in [ops.eq, ops.ne]:
# unorderable types
assert_raises(TypeError, op, ta.array(a), ta.array(b))
else:
# passing the same object
same = ta.array(a)
assert_equal(op(same, same),
op(a.tolist(), a.tolist()))
# passing different objects, but equal
assert_equal(op(ta.array(a), ta.array(a)),
op(a.tolist(), a.tolist()))
# passing different objects, not equal
assert_equal(op(ta.array(a), ta.array(b)),
op(a.tolist(), b.tolist()))
# test different ndims and different shapes
for shp1, shp2 in [((2,), (2, 2)), ((2, 2), (2, 3))]:
a = left(shp1, dtype)
b = right(shp2, dtype)
if op not in (ops.eq, ops.ne):
# unorderable types
assert_raises(TypeError, op, ta.array(a), ta.array(b))
def test_pickle():
import pickle
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment