Commit 3ffa951e authored by Christoph Groth

merge Joe's hashing fixes

Notably, this fixes the hash equality between tinyarrays and tuples on
64-bit Windows, where long is only 32 bits wide.  The changes shouldn't have
any observable effect (other than a minor speedup) on Unix.
parents d541a800 7146ff15
Pipeline #12760 passed with stage
in 23 seconds
......@@ -806,67 +806,45 @@ fail:
// in Python. As tinyarrays compare equal to equivalent tuples it is important
// for the hashes to agree. If not, there will be problems with dictionaries.
// Python-2-style hash of an integer: the value itself, except that -1 is
// never produced -- it is replaced by -2.
long old_hash(long x)
{
    if (x == -1) {
        return -2;
    }
    return x;
}
#if PY_MAJOR_VERSION >= 3
// The only documentation for this is in the Python source code.
typedef Py_hash_t Hash;
const Hash HASH_IMAG = _PyHASH_IMAG;
const Py_hash_t HASH_IMAG = _PyHASH_IMAG;
Hash hash(long x_)
Py_hash_t hash(long x)
{
int negative = x_ < 0;
unsigned long x = (negative ? -x_ : x_);
// x is the absolute value of x_.
if (x < PyLong_BASE) {
if (negative)
return x == 1 ? -2 : -long(x);
else
return x;
}
// PyLong_SHIFT is 15 for 32-bit architectures and 30 for 64-bit. So a
// long contains at most 3 digits. Therefore, a starting value of n = 2 is
// sufficient.
Py_uhash_t r = 0;
for (int n = 2; n >= 0; --n) {
unsigned dig = (x >> (n * PyLong_SHIFT)) & PyLong_MASK;
if (r == 0 && dig == 0) continue;
r = ((r << PyLong_SHIFT) & _PyHASH_MODULUS) |
(r >> (_PyHASH_BITS - PyLong_SHIFT));
r += dig;
if (r >= _PyHASH_MODULUS) r -= _PyHASH_MODULUS;
}
if (negative) r = -r;
if (r == Py_uhash_t(-1)) r = Py_uhash_t(-2);
return r;
// For integers the hash is just the integer itself modulo _PyHASH_MODULUS
// except for the singular case of -1.
// define 'sign' of the correct width to avoid overflow
Py_hash_t sign = x < 0 ? -1 : 1;
Py_hash_t result = sign * ((sign * x) % _PyHASH_MODULUS);
return result == -1 ? -2 : result;
}
#else
typedef long Hash;
const Hash HASH_IMAG = 1000003L;
/* In Python 2 hashes were long integers, as indicated by
https://github.com/python/cpython/blob/2.7/Include/object.h#L314
*/
typedef long Py_hash_t;
typedef unsigned long Py_uhash_t;
const Py_hash_t HASH_IMAG = 1000003L;
Hash hash(long x)
Py_hash_t hash(long x)
{
return old_hash(x);
return x != -1 ? x : -2;
}
#endif
Hash hash(double x)
Py_hash_t hash(double x)
{
// We used to have our own implementation of this, but the extra function
// call is quite negligible compared to the execution time of the function.
return _Py_HashDouble(x);
}
Hash hash(Complex x)
Py_hash_t hash(Complex x)
{
// x.imag == 0 => hash(x.imag) == 0 => hash(x) == hash(x.real)
return hash(x.real()) + HASH_IMAG * hash(x.imag());
......@@ -875,7 +853,7 @@ Hash hash(Complex x)
// This routine calculates the hash of a multi-dimensional array. The hash is
// equal to that of an arrangement of nested tuples equivalent to the array.
template <typename T>
Hash hash(PyObject *obj)
Py_hash_t hash(PyObject *obj)
{
int ndim;
size_t *shape;
......@@ -884,10 +862,10 @@ Hash hash(PyObject *obj)
T *p = self->data();
if (ndim == 0) return hash(*p);
const long mult_init = 1000003, r_init = 0x345678;
const long mul_addend = 82520, r_addend = 97531;
const Py_uhash_t mult_init = 1000003, r_init = 0x345678;
const Py_uhash_t mul_addend = 82520, r_addend = 97531;
Py_ssize_t i[max_ndim];
long mult[max_ndim], r[max_ndim];
Py_uhash_t mult[max_ndim], r[max_ndim];
--ndim; // For convenience.
int d = 0;
i[0] = shape[0];
......@@ -906,9 +884,14 @@ Hash hash(PyObject *obj)
r[d] = r_init;
}
} else {
if (d == 0) return old_hash(r[0] + r_addend);
if (d == 0) {
Py_uhash_t r_next = r[0] + r_addend;
return r_next == Py_uhash_t(-1) ? -2 : r_next;
}
--d;
r[d] = (r[d] ^ old_hash(r[d+1] + r_addend)) * mult[d];
Py_uhash_t r_next = r[d+1] + r_addend;
r_next = r_next == Py_uhash_t(-1) ? -2 : r_next;
r[d] = (r[d] ^ r_next) * mult[d];
mult[d] += mul_addend + 2 * i[d];
}
}
......@@ -1252,9 +1235,9 @@ template PyObject *str<long>(PyObject*);
template PyObject *str<double>(PyObject*);
template PyObject *str<Complex>(PyObject*);
template Hash hash<long>(PyObject*);
template Hash hash<double>(PyObject*);
template Hash hash<Complex>(PyObject*);
template Py_hash_t hash<long>(PyObject*);
template Py_hash_t hash<double>(PyObject*);
template Py_hash_t hash<Complex>(PyObject*);
template int getbuffer<long>(PyObject*, Py_buffer*, int);
template int getbuffer<double>(PyObject*, Py_buffer*, int);
......@@ -1360,6 +1343,19 @@ MOD_INIT_FUNC(tinyarray)
PyModule_AddObject(m, "ndarray_complex",
(PyObject *)&Array<Complex>::pytype);
// export information on the sizes of different dtypes in bytes
PyObject *dtype_size = PyDict_New();
PyDict_SetItem(dtype_size,
(PyObject*)&PyInt_Type,
PyInt_FromSize_t(sizeof(long)));
PyDict_SetItem(dtype_size,
(PyObject*)&PyFloat_Type,
PyInt_FromSize_t(sizeof(double)));
PyDict_SetItem(dtype_size,
(PyObject*)&PyComplex_Type,
PyInt_FromSize_t(sizeof(Complex)));
PyModule_AddObject(m, "dtype_size", dtype_size);
// We never release these references but this is not a problem. The Python
// interpreter does the same, see try_complex_special_method in
// complexobject.c
......@@ -1803,7 +1799,7 @@ template <typename T>
PyTypeObject Array<T>::pytype = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
pyname,
sizeof(Array<T>) - sizeof(T), // tp_basicsize
sizeof(Array_base), // tp_basicsize
sizeof(T), // tp_itemsize
(destructor)PyObject_Del, // tp_dealloc
0, // tp_print
......
......@@ -33,12 +33,6 @@ def machine_wordsize():
dtypes = [int, float, complex]
dtype_size = {
int: machine_wordsize(),
float: 8,
complex: 16
}
some_shapes = [(), 0, 1, 2, 3,
(0, 0), (1, 0), (0, 1), (2, 2), (17, 17),
(0, 0, 0), (1, 1, 1), (2, 2, 1), (2, 0, 3)]
......@@ -282,8 +276,10 @@ def test_as_dict_key():
def test_hash_equality():
random.seed(123)
maxint = sys.maxsize + 1 # will be typically 2**31 or 2**63
int_bits = 63 if maxint > 2**32 else 31
# These refer to the width of integers stored in a tinyarray.ndarray_int.
int_bits = (8 * ta.dtype_size[int]) - 1 # 8 bits per byte, minus 1 sign bit
maxint = 2**(int_bits)
special = [float('nan'), float('inf'), float('-inf'),
0, -1, -1.0, -1 + 0j,
......@@ -426,8 +422,8 @@ def test_sizeof():
# at the start of the buffer
if len(a.shape) > 1:
n_elements += (a.ndim * machine_wordsize() +
dtype_size[dtype] - 1) // dtype_size[dtype]
buffer_size = n_elements * dtype_size[dtype]
ta.dtype_size[dtype] - 1) // ta.dtype_size[dtype]
buffer_size = n_elements * ta.dtype_size[dtype]
# A Basic Python object has 3 pointer-sized members, or 5 if in
# debug mode.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment