Commit 474b3f89 authored by Christoph Groth's avatar Christoph Groth
Browse files

add comparisons for equality and non-equality and make tinyarray's hash agree with tuple's

parent b491f572
* Add missing arithmetic operations.
* Implement missing arithmetic operations.
* Implement missing comparisons.
* Implement creation of arrays from objects supporting the buffer protocol.
* 0-d arrays should not be sequences. Currently there are problems with using
0-d arrays with PySequence_Fast.
......
......@@ -481,52 +481,122 @@ fail:
return -1;
}
// The hash functions are modelled on Python's. It is important that a == b =>
// hash(a) == hash(b), otherwise there will be problems with dictionaries etc.
// Given the same input, these hash functions return the same result as those
// in Python. As tinyarrays compare equal to equivalent tuples it is important
// for the hashes to agree. If not, there will be problems with dictionaries.
long hash(long x)
{
return x;
return x != -1 ? x : -2;
}
long hash(double x)
{
const double two_to_31st = 1L << 31;
double intpart, fractpart;
fractpart = modf(x, &intpart);
if (fractpart == 0 &&
intpart >= std::numeric_limits<long>::min() &&
intpart <= std::numeric_limits<long>::max()) {
// This must return the same hash as an equal long.
return long(intpart);
return hash(long(intpart));
}
// Can't represent the number as a long: interpret its bits as hash!
static_assert(sizeof(double) >= sizeof(size_t),
"hash(double) has to be adopted for this machine.");
return long(*reinterpret_cast<size_t*>(&x));
int expo;
x = frexp(x, &expo) * two_to_31st;
long hipart = x; // Take the top 32 bits.
x = (x - double(hipart)) * two_to_31st; // Get the next 32 bits.
return hash(hipart + (long)x + (expo << 15));
}
long hash(Complex x)
{
// x.imag == 0 => hash(x.imag) == 0 => hash(x) == hash(x.real)
return hash(x.real()) + 1000003 * hash(x.imag());
return hash(x.real()) + 1000003L * hash(x.imag());
}
// This routine calculates the hash of a multi-dimensional array. The hash is
// equal to that of an arrangement of nested tuples equivalent to the array.
template <typename T>
long hash(Array<T> *self)
{
long mult = 1000003, r = 0x345678;
int ndim;
size_t *shape;
self->ndim_shape(&ndim, &shape);
Py_ssize_t size = calc_size(ndim, shape);
T *p = self->data();
while (--size >= 0) {
r = (r ^ hash(*p++)) * mult;
mult += long(82520 + size + size);
if (ndim == 0) return hash(*p);
const long mult_init = 1000003, r_init = 0x345678;
const long mul_addend = 82520, r_addend = 97531;
Py_ssize_t i[max_ndim];
long mult[max_ndim], r[max_ndim];
--ndim; // For convenience.
int d = 0;
i[0] = shape[0];
mult[0] = mult_init;
r[0] = r_init;
while (true) {
if (i[d]) {
--i[d];
if (d == ndim) {
r[d] = (r[d] ^ hash(*p++)) * mult[d];
mult[d] += mul_addend + 2 * i[d];
} else {
++d;
i[d] = shape[d];
mult[d] = mult_init;
r[d] = r_init;
}
} else {
if (d == 0) return hash(r[0] + r_addend);
--d;
r[d] = (r[d] ^ hash(r[d+1] + r_addend)) * mult[d];
mult[d] += mul_addend + 2 * i[d];
}
}
r += 97531;
if (r == -1) r = -2;
return r;
}
template <typename T>
bool is_equal_data(PyObject *a_, PyObject *b_, size_t size)
{
assert(Array<T>::check_exact(a_)); Array<T> *a = (Array<T>*)a_;
assert(Array<T>::check_exact(b_)); Array<T> *b = (Array<T>*)b_;
T *data_a = a->data();
T *data_b = b->data();
for (size_t i = 0; i < size; ++i)
if (data_a[i] != data_b[i]) return false;
return true;
}
bool (*is_equal_data_dtable[])(PyObject*, PyObject*, size_t) =
DTYPE_DISPATCH(is_equal_data);
PyObject *richcompare(PyObject *a, PyObject *b, int op)
{
if (op != Py_EQ && op != Py_NE) {
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
bool equal = (a == b);
if (equal) goto done;
Dtype dtype;
if (coerce_to_arrays(&a, &b, &dtype) < 0) return 0;
int ndim_a, ndim_b;
size_t *shape_a, *shape_b;
reinterpret_cast<Array_base*>(a)->ndim_shape(&ndim_a, &shape_a);
reinterpret_cast<Array_base*>(b)->ndim_shape(&ndim_b, &shape_b);
if (ndim_a != ndim_b) goto done;
for (int d = 0; d < ndim_a; ++d)
if (shape_a[d] != shape_b[d]) goto done;
equal = is_equal_data_dtable[int(dtype)](a, b, calc_size(ndim_a, shape_a));
done:
PyObject *result = ((op == Py_EQ) == equal) ? Py_True : Py_False;
Py_INCREF(result);
return result;
}
PyObject *get_dtype_py(PyObject *self, void *)
......@@ -1019,8 +1089,7 @@ PyTypeObject Array<T>::pytype = {
doc, // tp_doc
0, // tp_traverse
0, // tp_clear
// richcompare, // tp_richcompare
0, // tp_richcompare
(richcmpfunc)richcompare, // tp_richcompare
0, // tp_weaklistoffset
(getiterfunc)Array_iter<T>::make, // tp_iter
0, // tp_iternext
......
......@@ -147,11 +147,27 @@ def test_iteration():
assert_equal(np.array(ta.array(tuple(t))), np.array(t))
def test_hash():
def test_as_dict_key():
n = 100
for dtype in dtypes:
s = set(hash(ta.array(range(i), dtype)) for i in range(n))
assert_equal(len(s), n)
d = {}
for dtype in dtypes + dtypes:
for i in xrange(n):
d[ta.array(xrange(i), dtype)] = i
assert_equal(len(d), n)
for i in xrange(n):
assert_equal(d[tuple(xrange(i))], i)
def test_hash_equality():
for tup in [0, -1, -1.0, -1 + 0j, -0.3, 1.7, 0.4j,
-12.3j, 1 - 12.3j, 1.3 - 12.3j,
(), (-1,), (2,),
(0, 0), (-1, -1), (-5, 7), (3, -1, 0),
((0, 0), (0, 0)), (((-1,),),)]:
arr = ta.array(tup)
assert arr == tup
assert not (arr != tup)
assert hash(arr) == hash(tup)
def test_broadcasting():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment