Commit 6ff26513 authored by Joseph Weston's avatar Joseph Weston
Browse files

WIP

parent 09bda66c
Pipeline #426 failed with stage
......@@ -7,10 +7,13 @@
// top-level directory of this distribution and at
// https://gitlab.kwant-project.org/kwant/tinyarray.
#include <iostream>
#include <Python.h>
#include <cstddef>
#include <sstream>
#include <limits>
#include <assert.h>
#include "array.hh"
#include "arithmetic.hh"
#include "functions.hh"
......@@ -180,7 +183,7 @@ int examine_sequence(PyObject *arraylike, int *ndim, size_t *shape,
PyObject *p = arraylike;
int d = -1;
assert(PySequence_Check(p));
for (bool is_sequence = true; ; is_sequence = PySequence_Check(p)) {
for (bool is_sequence = true;; is_sequence = PySequence_Check(p)) {
if (is_sequence) {
++d;
if (d == ptrdiff_t(max_ndim)) {
......@@ -910,48 +913,125 @@ Hash hash(PyObject *obj)
}
template <typename T>
bool is_equal_data(PyObject *a_, PyObject *b_, size_t size)
bool compare_scalar(const int op, const T a, const T b) {
switch(op){
case Py_EQ: return a == b;
case Py_NE: return a != b;
case Py_LE: return a <= b;
case Py_GE: return a >= b;
case Py_LT: return a < b;
case Py_GT: return a > b;
default:
assert(false); // if we get here something is very wrong
return false; // stop the compiler complaining
}
}
template <>
bool compare_scalar<Complex>(const int op, const Complex a, const Complex b) {
switch(op){
case Py_EQ: return a == b;
case Py_NE: return a != b;
// this function is never called in a context where
// the following code path is run -- fall through
case Py_LE:
case Py_GT:
case Py_LT:
case Py_GE:
default:
assert(false);
return false; // stop the compiler complaining
}
}
template <typename T>
bool compare_data(int op, PyObject *a_, PyObject *b_, size_t size)
{
assert(Array<T>::check_exact(a_)); Array<T> *a = (Array<T>*)a_;
assert(Array<T>::check_exact(b_)); Array<T> *b = (Array<T>*)b_;
T *data_a = a->data();
T *data_b = b->data();
for (size_t i = 0; i < size; ++i)
if (data_a[i] != data_b[i]) return false;
return true;
const T *data_a = a->data();
const T *data_b = b->data();
// sequences are ordered the same as their first differing elements, see:
// https://docs.python.org/2/reference/expressions.html#not-in
// comparison for "multidimensional" sequences is identical to comparing
// the flattened sequences when they have the same shape (the present case)
size_t i = 0;
for (; i < size; ++i)
if (data_a[i] != data_b[i]) break;
// we got to the end without breaking: arrays are equal
if (i == size) return (op == Py_EQ);
// encapsulate this into a function to handle the COMPLEX case
return compare_scalar<T>(op, data_a[i], data_b[i]);
}
bool (*is_equal_data_dtable[])(PyObject*, PyObject*, size_t) =
DTYPE_DISPATCH(is_equal_data);
// don't generate dispatch table for COMPLEX datatype as it will never be used
// in `rich_compare` (COMPLEX is unorderable), and the compiler complains about
// generating `compare_scalar<complex>` because some operations are undefined
bool (*compare_data_dtable[])(int, PyObject*, PyObject*, size_t) =
DTYPE_DISPATCH(compare_data);
PyObject *richcompare(PyObject *a, PyObject *b, int op)
{
if (op != Py_EQ && op != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
PyObject *result;
const bool equality_comparison = (op == Py_EQ || op == Py_NE);
// short circuit when we are comparing the same object
bool equal = (a == b);
if (equal) goto done;
if (equal) {
result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
goto done;
}
Dtype dtype;
if (coerce_to_arrays(&a, &b, &dtype) < 0) return 0;
// obviate the need for `compare_scalar<Complex` to
// handle the case of an undefined comparison
if (dtype == COMPLEX && !equality_comparison) {
PyErr_SetString(PyExc_TypeError, "unorderable type: complex()");
return 0;
}
int ndim_a, ndim_b;
size_t *shape_a, *shape_b;
reinterpret_cast<Array_base*>(a)->ndim_shape(&ndim_a, &shape_a);
reinterpret_cast<Array_base*>(b)->ndim_shape(&ndim_b, &shape_b);
if (ndim_a != ndim_b) goto decref_then_done;
for (int d = 0; d < ndim_a; ++d)
if (shape_a[d] != shape_b[d]) goto decref_then_done;
equal = is_equal_data_dtable[int(dtype)](a, b, calc_size(ndim_a, shape_a));
// TODO: enable array comparisons between arrays of differing
// dimensions
if (ndim_a != ndim_b) {
if (equality_comparison) {
goto equality_then_done;
} else {
result = Py_NotImplemented;
goto decref_then_done;
}
}
for (int d = 0; d < ndim_a; ++d) {
if (shape_a[d] != shape_b[d]) {
if (equality_comparison) {
goto equality_then_done;
} else {
result = Py_NotImplemented;
goto decref_then_done;
}
}
}
// actually compare the data
std::cout << "COMPARING DATA\n";
equal = compare_data_dtable[int(dtype)](op, a, b, calc_size(ndim_a, shape_a));
result = equal ? Py_True : Py_False;
goto done;
// exit points from this function
equality_then_done:
result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
decref_then_done:
Py_DECREF(a);
Py_DECREF(b);
done:
PyObject *result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
Py_INCREF(result);
return result;
}
......
......@@ -9,6 +9,7 @@
import operator, warnings
import platform
import itertools as it
import tinyarray as ta
from nose.tools import assert_raises
import numpy as np
......@@ -414,17 +415,31 @@ def test_sizeof():
def test_comparison():
ops = operator
# test comparison for 0D and 1D arrays
# TODO: once higher-dimensional arrays are implemented, test them here
for shape in [(), (1,), (2,)]: # should cover all code branches
for op in [ops.ge, ops.gt, ops.le, ops.lt, ops.eq, ops.ne]:
for dtype in (int, float, complex):
a = ta.ones(shape, dtype)
b = ta.zeros(shape, dtype)
if dtype is complex and op not in [ops.eq, ops.ne]:
assert_raises(TypeError, op, a, b)
else:
assert_equal(op(a, b), op(a.tolist(), b.tolist()))
for op in [ops.ge, ops.gt, ops.le, ops.lt, ops.eq, ops.ne]:
for dtype in (int, float, complex):
for left, right in it.product((np.zeros, np.ones), repeat=2):
for shape in [(), (1,), (2,), (2, 2), (2, 2, 2), (2, 3)]:
a = left(shape, dtype)
b = right(shape, dtype)
if dtype is complex and op not in [ops.eq, ops.ne]:
assert_raises(TypeError, op, a, b)
else:
# passing the same object
same = ta.array(a)
print(op, a, b)
assert_equal(op(same, same),
op(a.tolist(), a.tolist()))
# passing different objects
assert_equal(op(ta.array(a), ta.array(a)),
op(a.tolist(), a.tolist()))
# test different ndims and different shapes
for shp1, shp2 in [((2,), (2, 2)), ((2, 2), (2, 3))]:
a = left(shp1, dtype)
b = right(shp2, dtype)
if dtype is complex and op not in [ops.eq, ops.ne]:
assert_raises(TypeError, op, a, b)
elif op not in (ops.eq, ops.ne):
assert_raises(NotImplementedError, op, a, b)
def test_pickle():
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment