Commit c14515e4 authored by Christoph Groth's avatar Christoph Groth
Browse files

the most important binary arithmetic operations work

parent 67f7d862
* Add missing arithmetic operations.
* 0-d arrays should not be sequences. Currently there are problems with using
0-d arrays with PySequence_Fast.
......
#include <Python.h>
#include <limits>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <functional>
#include <algorithm>
#include <complex>
#include "array.hh"
#include "arithmetic.hh"
#include "conversion.hh"
// Compile-time guard: integer division must truncate toward zero (-3 / 2 is
// -1, not -2).  The floored integer division and modulo routines in this file
// derive their results from the truncated quotient/remainder.
static_assert(int(-3) / int(2) == -1,
"C99 behavior of division is assumed in this module.");
namespace {
template <typename T>
PyObject *array_scalar_product(PyObject *a_, PyObject *b_)
{
......@@ -33,6 +42,9 @@ PyObject *array_scalar_product(PyObject *a_, PyObject *b_)
return pyobject_from_number(result);
}
// Table of array_scalar_product instantiations; dot_product() indexes it with
// int(dtype).  NOTE(review): DTYPE_DISPATCH is defined elsewhere -- presumably
// it expands to one entry per supported dtype; confirm against its definition.
PyObject *(*array_scalar_product_dtable[])(PyObject*, PyObject*) =
DTYPE_DISPATCH(array_scalar_product);
// This routine is not heavily optimized. Its performance has been measured
// to be adequate, given that it will be called from Python. The actual
// calculation of the matrix product typically uses less than half of the
......@@ -56,7 +68,7 @@ PyObject *array_matrix_product(PyObject *a_, PyObject *b_)
return 0;
}
const size_t n = shape_a[ndim_a - 1];
size_t shape[ndim];
size_t shape[max_ndim];
size_t d = 0, a0 = 1;
for (int id = 0, e = ndim_a - 1; id < e; ++id)
......@@ -102,61 +114,393 @@ PyObject *array_matrix_product(PyObject *a_, PyObject *b_)
return (PyObject*)result;
}
PyNumberMethods as_number = {
(binaryfunc)0/*add*/, // nb_add
(binaryfunc)0, // nb_subtract
(binaryfunc)0, // nb_multiply
(binaryfunc)0, // nb_divide
(binaryfunc)0, // nb_remainder
(binaryfunc)0, // nb_divmod
(ternaryfunc)0, // nb_power
(unaryfunc)0, // nb_negative
(unaryfunc)0, // nb_positive
(unaryfunc)0, // nb_absolute
(inquiry)0, // nb_nonzero
(unaryfunc)0, // nb_invert
(binaryfunc)0, // nb_lshift
(binaryfunc)0, // nb_rshift
(binaryfunc)0, // nb_and
(binaryfunc)0, // nb_xor
(binaryfunc)0, // nb_or
(coercion)0, // nb_coerce
(unaryfunc)0, // nb_int
(unaryfunc)0, // nb_long
(unaryfunc)0, // nb_float
(unaryfunc)0, // nb_oct
(unaryfunc)0, // nb_hex
(binaryfunc)0, // nb_inplace_add
(binaryfunc)0, // nb_inplace_subtract
(binaryfunc)0, // nb_inplace_multiply
(binaryfunc)0, // nb_inplace_divide
(binaryfunc)0, // nb_inplace_remainder
(ternaryfunc)0, // nb_inplace_power
(binaryfunc)0, // nb_inplace_lshift
(binaryfunc)0, // nb_inplace_rshift
(binaryfunc)0, // nb_inplace_and
(binaryfunc)0, // nb_inplace_xor
(binaryfunc)0, // nb_inplace_or
(binaryfunc)0, // nb_floor_divide
(binaryfunc)0, // nb_true_divide
(binaryfunc)0, // nb_inplace_floor_divide
(binaryfunc)0, // nb_inplace_true_divide
(unaryfunc)0 // nb_index
// Table of array_matrix_product instantiations; dot_product() indexes it with
// int(dtype).  NOTE(review): DTYPE_DISPATCH is defined elsewhere -- presumably
// it expands to one entry per supported dtype; confirm against its definition.
PyObject *(*array_matrix_product_dtable[])(PyObject*, PyObject*) =
DTYPE_DISPATCH(array_matrix_product);
// Signature of an element-wise binary kernel.  The arguments are, in order:
// the number of dimensions of the result, the shape of the result, the first
// operand with its per-dimension hops, and the second operand with its
// per-dimension hops.
using Binary_ufunc = PyObject *(int, const size_t*,
                                PyObject*, const ptrdiff_t*,
                                PyObject*, const ptrdiff_t*);
// Apply an element-wise binary operation to a and b with broadcasting.
//
// ufunc_dtable is indexed with int(dtype) of the common dtype of both
// operands and supplies the kernel that does the actual computation.
// a and b may be tinyarrays or array-likes: whatever get_dtype() reports as
// Dtype::NONE is converted with array_from_arraylike().
//
// Returns whatever the selected kernel returns, or 0 if converting an
// operand, promoting it, or broadcasting the two shapes failed.
PyObject *apply_binary_ufunc(Binary_ufunc **ufunc_dtable,
PyObject *a, PyObject *b)
{
Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b), dtype;
// Make sure a and b are tinyarrays.
// Either way we end up holding one reference to each operand, so that the
// cleanup at `end` can DECREF both unconditionally.
if (dtype_a != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype_a);
if (!a) return 0;
}
if (dtype_b != Dtype::NONE) {
Py_INCREF(b);
} else {
// dtype_a is passed as the minimal dtype for the converted b.
b = array_from_arraylike(b, &dtype_b, dtype_a);
if (!b) {
Py_DECREF(a);
return 0;
}
}
PyObject *result = 0;
// Promote to a common dtype.
// The common dtype is the one with the larger enumerator value.
// NOTE(review): this assumes the Dtype enumerators are ordered from the
// narrowest to the widest type -- confirm against the Dtype definition.
dtype = Dtype(std::max(int(dtype_a), int(dtype_b)));
// Since dtype is the maximum of the two, at most one operand can differ
// from it; hence `else if`.
if (dtype_a != dtype) {
PyObject *temp = promote_array(dtype, a, dtype_a);
if (temp == 0) goto end;
Py_DECREF(a);
a = temp;
} else if (dtype_b != dtype) {
PyObject *temp = promote_array(dtype, b, dtype_b);
if (temp == 0) goto end;
Py_DECREF(b);
b = temp;
}
int ndim_a, ndim_b;
size_t *shape_a, *shape_b;
reinterpret_cast<Array_base*>(a)->ndim_shape(&ndim_a, &shape_a);
reinterpret_cast<Array_base*>(b)->ndim_shape(&ndim_b, &shape_b);
int ndim;
size_t stride_a, stride_b, shape[max_ndim];;
ndim = std::max(ndim_a, ndim_b);
stride_a = stride_b = 1;
ptrdiff_t hops_a[max_ndim], hops_b[max_ndim];
// Walk the dimensions from the last to the first, aligning both shapes at
// their trailing ends.  A dimension that an operand lacks counts as having
// extent 1.  After this loop hops_x[d] holds the stride (in elements) of
// operand x along result dimension d, or 0 where x is broadcast.
for (int d = ndim - 1, d_a = ndim_a - 1, d_b = ndim_b - 1;
d >= 0; --d, --d_a, --d_b) {
size_t ext_a = d_a >= 0 ? shape_a[d_a] : 1;
size_t ext_b = d_b >= 0 ? shape_b[d_b] : 1;
if (ext_a == ext_b) {
hops_a[d] = stride_a;
hops_b[d] = stride_b;
shape[d] = ext_a;
stride_a *= ext_a;
stride_b *= ext_b;
} else if (ext_a == 1) {
// a is broadcast along this dimension: it does not advance.
hops_a[d] = 0;
hops_b[d] = stride_b;
stride_b *= shape[d] = ext_b;
} else if (ext_b == 1) {
// b is broadcast along this dimension: it does not advance.
hops_a[d] = stride_a;
hops_b[d] = 0;
stride_a *= shape[d] = ext_a;
} else {
// The extents differ and neither is 1: report both full shapes.
std::ostringstream s;
s << "Operands could not be broadcast together with shapes (";
for (int d = 0; d < ndim_a; ++d) {
s << shape_a[d];
if (d + 1 < ndim_a) s << ", ";
}
s << ") and (";
for (int d = 0; d < ndim_b; ++d) {
s << shape_b[d];
if (d + 1 < ndim_b) s << ", ";
}
s << ").";
PyErr_SetString(PyExc_ValueError, s.str().c_str());
goto end;
}
}
// Turn strides into hops.  Once dimension d has been traversed completely,
// the source pointer has moved by hops[d] * shape[d]; subtracting that from
// the stride of dimension d - 1 yields the correction that takes the pointer
// to the next index of dimension d - 1.  hops[d] is still the plain stride
// when it is read here, because each iteration only modifies hops[d - 1].
for (int d = 1; d < ndim; ++d)
{
hops_a[d - 1] -= hops_a[d] * shape[d];
hops_b[d - 1] -= hops_b[d] * shape[d];
}
result = ufunc_dtable[int(dtype)](ndim, shape, a, hops_a, b, hops_b);
// Common exit: drop the references acquired at the top.
end:
Py_DECREF(a);
Py_DECREF(b);
return result;
}
// Bundles everything needed to expose the element-wise operation Op (a class
// template over the element type) as a binary function on Python objects.
template <template <typename> class Op>
struct Binary_op {
// Kernel for element type T: applies Op<T> to every pair of elements of a_
// and b_, walking both with the given hops, and returns the result.
template <typename T>
static PyObject *ufunc(int ndim, const size_t *shape,
PyObject *a_, const ptrdiff_t *hops_a,
PyObject *b_, const ptrdiff_t *hops_b);
// Entry point taking two arbitrary Python objects; forwards to
// apply_binary_ufunc() together with dtable.
static PyObject *apply(PyObject *a, PyObject *b);
// Kernels of this operation, passed to apply_binary_ufunc(), which indexes
// the table with int(dtype).
static Binary_ufunc *dtable[];
};
// Element-wise kernel for element type T.
//
// shape has ndim entries and describes the result.  hops_a / hops_b give, for
// each dimension, how far (in elements) the respective source pointer moves
// after a step in that dimension.  a_ and b_ must be exactly Array<T>
// (checked by assertion only).
//
// For ndim == 0 the result is a Python number created with
// pyobject_from_number(); otherwise it is a new Array<T> of the given shape.
// Returns 0 if the operation reports failure (returns true) or if the result
// array cannot be created.
template <template <typename> class Op>
template <typename T>
PyObject *Binary_op<Op>::ufunc(int ndim, const size_t *shape,
PyObject *a_, const ptrdiff_t *hops_a,
PyObject *b_, const ptrdiff_t *hops_b)
{
Op<T> operation;
assert(Array<T>::check_exact(a_)); Array<T> *a = (Array<T>*)a_;
assert(Array<T>::check_exact(b_)); Array<T> *b = (Array<T>*)b_;
T *src_a = a->data(), *src_b = b->data();
// Zero-dimensional case: a single pair of elements, no array is allocated.
if (ndim == 0) {
T result;
if (operation(result, *src_a, *src_b)) return 0;
return (PyObject*)pyobject_from_number(result);
}
Array<T> *result = Array<T>::make(ndim, shape);
if (result == 0) return 0;
T *dest = result->data();
// d is the dimension currently being iterated; i[d] counts the iterations
// that remain in dimension d.
int d = 0;
size_t i[max_ndim];
// From here on ndim is the index of the innermost dimension.
--ndim;
i[0] = shape[0];
while (true) {
if (i[d]) {
--i[d];
if (d == ndim) {
// Innermost dimension: compute one element.  The destination is written
// consecutively; the sources advance by their hops.
if (operation(*dest++, *src_a, *src_b)) {
Py_DECREF(result);
return 0;
}
src_a += hops_a[d];
src_b += hops_b[d];
} else {
// Descend into the next inner dimension and start it afresh.
++d;
i[d] = shape[d];
}
} else {
// Dimension d is exhausted.  If it was the outermost one we are done,
// otherwise return to the enclosing dimension and apply its hop.
if (d == 0) return (PyObject*)result;
--d;
src_a += hops_a[d];
src_b += hops_b[d];
}
}
}
// Apply the operation to two arbitrary Python objects by handing this
// operation's kernel table to apply_binary_ufunc().
template <template <typename> class Op>
PyObject *Binary_op<Op>::apply(PyObject *a, PyObject *b)
{
return apply_binary_ufunc(dtable, a, b);
}
// Kernel table of the operation, built from the ufunc member template.
// NOTE(review): DTYPE_DISPATCH is defined elsewhere -- presumably it expands
// to one ufunc instantiation per supported dtype; confirm.
template <template <typename> class Op>
Binary_ufunc *Binary_op<Op>::dtable[] = DTYPE_DISPATCH(ufunc);
// Element-wise addition.  Stores lhs + rhs in `sum` and returns false, i.e.
// the operation never reports failure.
template <typename T>
struct Add {
    bool operator()(T &sum, T lhs, T rhs)
    {
        sum = lhs + rhs;
        return false;
    }
};
// Element-wise subtraction.  Stores lhs - rhs in `difference` and returns
// false, i.e. the operation never reports failure.
template <typename T>
struct Subtract {
    bool operator()(T &difference, T lhs, T rhs)
    {
        difference = lhs - rhs;
        return false;
    }
};
// Element-wise multiplication.  Stores lhs * rhs in `product` and returns
// false, i.e. the operation never reports failure.
template <typename T>
struct Multiply {
    bool operator()(T &product, T lhs, T rhs)
    {
        product = lhs * rhs;
        return false;
    }
};
// Element-wise modulo.  Only declared here: the call operator is provided by
// explicit specializations for the individual element types.  It stores the
// remainder in `result` and returns true to report failure.
template <typename T>
struct Remainder {
bool operator()(T &result, T x, T y);
};
// Floored (Python-style) modulo for longs: a nonzero result has the sign of
// the divisor y, so that x == floor_divide(x, y) * y + result holds.
//
// For y == 0, and for the pair (LONG_MIN, -1) on which the built-in % is
// undefined, a RuntimeWarning is issued and the result is set to 0.
//
// Returns true (failure) only if issuing the warning raised an exception,
// otherwise false.
template <>
bool Remainder<long>::operator()(long &result, long x, long y)
{
    if (y == 0 || (y == -1 && x == std::numeric_limits<long>::min())) {
        const char *msg = (y == 0) ?
            "Integer modulo by zero." : "Integer modulo overflow.";
        if (PyErr_WarnEx(PyExc_RuntimeWarning, msg, 1) < 0) return true;
        result = 0;
        return false;
    }
    // The built-in % truncates toward zero, so its result carries the sign of
    // x.  Whenever it is nonzero and its sign differs from that of y, shift it
    // by one divisor to obtain the floored remainder.  (Merely negating it is
    // wrong in general: -7 % 3 must be 2, not 1.)  The addition cannot
    // overflow because |x_mod_y| < |y| and the two have opposite signs.
    long x_mod_y = x % y;
    if (x_mod_y != 0 && (x_mod_y ^ y) < 0 /*different sign*/) x_mod_y += y;
    result = x_mod_y;
    return false;
}
// Modulo for doubles, computed as x minus the floored quotient times y.
// Never reports failure.
template <>
bool Remainder<double>::operator()(double &result, double x, double y)
{
    const double floored_quotient = std::floor(x / y);
    result = x - floored_quotient * y;
    return false;
}
// Modulo kernel for complex arrays: the operation is not defined, so this
// ignores all arguments, sets a TypeError and returns 0.
template <>
template <>
PyObject *Binary_op<Remainder>::ufunc<Complex>(int, const size_t*,
                                               PyObject*, const ptrdiff_t*,
                                               PyObject*, const ptrdiff_t*)
{
    const char *const message = "Modulo is not defined for complex numbers.";
    PyErr_SetString(PyExc_TypeError, message);
    return 0;
}
// Element-wise floor division.  Only declared here: the call operator is
// provided by explicit specializations for the individual element types.  It
// stores the quotient in `result` and returns true to report failure.
template <typename T>
struct Floor_divide {
bool operator()(T &result, T x, T y);
};
// Floor division for longs: the quotient is rounded toward negative infinity.
//
// Division by zero and the overflowing LONG_MIN / -1 issue a RuntimeWarning
// and yield 0.  Returns true (failure) only if issuing the warning raised an
// exception, otherwise false.
template <>
bool Floor_divide<long>::operator()(long &result, long x, long y)
{
    const bool by_zero = (y == 0);
    if (by_zero || (y == -1 && x == std::numeric_limits<long>::min())) {
        const char *msg = by_zero ?
            "Integer division by zero." : "Integer division overflow.";
        if (PyErr_WarnEx(PyExc_RuntimeWarning, msg, 1) < 0) return true;
        result = 0;
        return false;
    }

    // The built-in division truncates toward zero.  That only differs from
    // flooring when the operands have different signs and the division is
    // inexact; the truncated quotient is then one too large.
    long quotient = x / y;
    if ((x ^ y) < 0 /*different sign*/ && (x % y) != 0) quotient = quotient - 1;
    result = quotient;
    return false;
}
// Floor division for doubles: the floor of the quotient x / y.  Never reports
// failure.
template <>
bool Floor_divide<double>::operator()(double &result, double x, double y)
{
    const double quotient = x / y;
    result = std::floor(quotient);
    return false;
}
// Floor-division kernel for complex arrays: the operation is not defined, so
// this ignores all arguments, sets a TypeError and returns 0.
template <>
template <>
PyObject *Binary_op<Floor_divide>::ufunc<Complex>(int, const size_t*,
                                                  PyObject*, const ptrdiff_t*,
                                                  PyObject*, const ptrdiff_t*)
{
    const char *const message =
        "Floor divide is not defined for complex numbers.";
    PyErr_SetString(PyExc_TypeError, message);
    return 0;
}
// Element-wise division using the built-in operator of T.  Stores
// numerator / denominator in `quotient` and returns false, i.e. this generic
// version never reports failure.
template <typename T>
struct Divide {
    bool operator()(T &quotient, T numerator, T denominator)
    {
        quotient = numerator / denominator;
        return false;
    }
};
// Division of longs is floor division: simply delegate to Floor_divide<long>
// and pass on its failure flag.
template <>
bool Divide<long>::operator()(long &result, long x, long y)
{
    return Floor_divide<long>()(result, x, y);
}
} // Anonymous namespace
// Dot product of a and b.
//
// Both operands may be tinyarrays or array-likes; whatever get_dtype() reports
// as Dtype::NONE is converted with array_from_arraylike().  Two 1-d operands
// are handled by the scalar product routine, everything else by the matrix
// product routine of the operands' dtype.
//
// Returns the result of the selected routine, or 0 with a Python exception
// set if a conversion fails, if either operand is zero-dimensional, or if the
// dtypes of the operands differ.
PyObject *dot_product(PyObject *a, PyObject *b)
{
    Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b);

    // Turn both operands into tinyarrays.  In either case we end up owning
    // one reference to each of them, released at the end of the function.
    if (dtype_a == Dtype::NONE) {
        a = array_from_arraylike(a, &dtype_a);
        if (!a) return 0;
    } else {
        Py_INCREF(a);
    }
    if (dtype_b == Dtype::NONE) {
        b = array_from_arraylike(b, &dtype_b);
        if (!b) {
            Py_DECREF(a);
            return 0;
        }
    } else {
        Py_INCREF(b);
    }

    int rank_a, rank_b;
    reinterpret_cast<Array_base*>(a)->ndim_shape(&rank_a, 0);
    reinterpret_cast<Array_base*>(b)->ndim_shape(&rank_b, 0);

    // Stays 0 (failure) unless one of the product routines is reached.
    PyObject *product = 0;
    if (rank_a == 0 || rank_b == 0) {
        PyErr_SetString(PyExc_ValueError,
                        "dot does not support zero-dimensional arrays yet.");
    } else if (dtype_a != dtype_b) {
        PyErr_SetString(PyExc_ValueError,
                        "Dtype must be the same for now.");
    } else if (rank_a == 1 && rank_b == 1) {
        product = array_scalar_product_dtable[int(dtype_a)](a, b);
    } else {
        product = array_matrix_product_dtable[int(dtype_a)](a, b);
    }

    Py_DECREF(a);
    Py_DECREF(b);
    return product;
}
// Number protocol table of Array<T>.  The slots for add, subtract, multiply,
// divide, remainder and floor divide are bound to the corresponding
// Binary_op<...>::apply entry points; all remaining slots are 0.  The entries
// do not depend on T.
template <typename T>
PyNumberMethods Array<T>::as_number = {
Binary_op<Add>::apply, // nb_add
Binary_op<Subtract>::apply, // nb_subtract
Binary_op<Multiply>::apply, // nb_multiply
Binary_op<Divide>::apply, // nb_divide
Binary_op<Remainder>::apply, // nb_remainder
(binaryfunc)0, // nb_divmod
(ternaryfunc)0, // nb_power
(unaryfunc)0, // nb_negative
(unaryfunc)0, // nb_positive
(unaryfunc)0, // nb_absolute
(inquiry)0, // nb_nonzero
(unaryfunc)0, // nb_invert
(binaryfunc)0, // nb_lshift
(binaryfunc)0, // nb_rshift
(binaryfunc)0, // nb_and
(binaryfunc)0, // nb_xor
(binaryfunc)0, // nb_or
(coercion)0, // nb_coerce
(unaryfunc)0, // nb_int
(unaryfunc)0, // nb_long
(unaryfunc)0, // nb_float
(unaryfunc)0, // nb_oct
(unaryfunc)0, // nb_hex
(binaryfunc)0, // nb_inplace_add
(binaryfunc)0, // nb_inplace_subtract
(binaryfunc)0, // nb_inplace_multiply
(binaryfunc)0, // nb_inplace_divide
(binaryfunc)0, // nb_inplace_remainder
(ternaryfunc)0, // nb_inplace_power
(binaryfunc)0, // nb_inplace_lshift
(binaryfunc)0, // nb_inplace_rshift
(binaryfunc)0, // nb_inplace_and
(binaryfunc)0, // nb_inplace_xor
(binaryfunc)0, // nb_inplace_or
Binary_op<Floor_divide>::apply, // nb_floor_divide
(binaryfunc)0, // nb_true_divide
(binaryfunc)0, // nb_inplace_floor_divide
(binaryfunc)0, // nb_inplace_true_divide
(unaryfunc)0 // nb_index
};
// Explicit instantiations: the product routines and the number protocol
// table are instantiated here for each of the element types long, double and
// Complex.
template
PyObject *array_scalar_product<long>(PyObject*, PyObject*);
template
PyObject *array_scalar_product<double>(PyObject*, PyObject*);
template
PyObject *array_scalar_product<Complex>(PyObject*, PyObject*);
template
PyObject *array_matrix_product<long>(PyObject*, PyObject*);
template
PyObject *array_matrix_product<double>(PyObject*, PyObject*);
template
PyObject *array_matrix_product<Complex>(PyObject*, PyObject*);
template PyNumberMethods Array<long>::as_number;
template PyNumberMethods Array<double>::as_number;
template PyNumberMethods Array<Complex>::as_number;
// Declarations of the arithmetic routines for arrays.
#ifndef ARITHMETIC_HH
#define ARITHMETIC_HH
// Product routine for element type T, used by dot_product() when both
// operands are one-dimensional.
template <typename T>
PyObject *array_scalar_product(PyObject *a, PyObject *b);
// Product routine for element type T, used by dot_product() for all other
// combinations of dimensions.
template <typename T>
PyObject *array_matrix_product(PyObject *a, PyObject *b);
// Number protocol table; only declared here.
extern PyNumberMethods as_number;
// Dot product of two arrays or array-likes; returns 0 on failure.
PyObject *dot_product(PyObject *a, PyObject *b);
#endif // !ARITHMETIC_HH
......@@ -792,10 +792,10 @@ Py_ssize_t load_index_seq_as_ulong(PyObject *obj, unsigned long *uout,
return len;
}
// If *dtype == Dtype::NONE the simplest fitting dtype for the array will be
// used and written back to *dtype. Any other value of *dtype requests an
// array of a given dtype.
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype)
// If *dtype == Dtype::NONE the simplest fitting dtype (at least min_dtype)
// will be used and written back to *dtype. Any other value of *dtype requests
// an array of the given dtype.
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype, Dtype min_dtype)
{
int ndim;
size_t shape[max_ndim];
......@@ -809,6 +809,7 @@ PyObject *array_from_arraylike(PyObject *src, Dtype *dtype)
assert(shape[ndim - 1] == 0);
*dtype = default_dtype;
}
if (int(*dtype) < int(min_dtype)) *dtype = min_dtype;
while (true) {
PyObject *result = make_and_readin_array_dtable[int(*dtype)](
src, ndim, shape, seqs, true);
......@@ -836,6 +837,39 @@ PyObject *array_from_arraylike(PyObject *src, Dtype *dtype)
}
}
// Create a new Array<O> that has the shape of `in` and whose elements are the
// elements of `in` converted from I to O by assignment.  Returns 0 if the new
// array cannot be created.
template<typename O, typename I>
Array<O> *promote_array(Array<I> *in)
{
    int rank;
    size_t *extents;
    in->ndim_shape(&rank, &extents);

    // make() reports the total number of elements through its last argument.
    size_t count;
    Array<O> *const promoted = Array<O>::make(rank, extents, &count);
    if (promoted == 0) return 0;

    const I *from = in->data();
    O *to = promoted->data();
    for (const I *const stop = from + count; from != stop; ++from, ++to)
        *to = *from;
    return promoted;
}
// Return a new array with the contents of `in` converted to out_dtype.
//
// in_dtype is the dtype of `in`; pass Dtype::NONE to have it looked up with
// get_dtype().  Only the widening conversions LONG -> DOUBLE, LONG -> COMPLEX
// and DOUBLE -> COMPLEX are supported; any other combination is a programming
// error that is caught by assertions only.
//
// Returns 0 if the new array cannot be created.
PyObject *promote_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
{
    if (in_dtype == Dtype::NONE)
        in_dtype = get_dtype(in);
    // The dtype claimed by the caller must be the actual dtype of `in`, since
    // `in` is cast to the matching Array type below.  (This assertion used to
    // compare get_dtype(in) with itself and hence could never fail.)
    assert(get_dtype(in) == in_dtype);
    if (out_dtype == Dtype::DOUBLE) {
        assert(in_dtype == Dtype::LONG);
        return (PyObject*)promote_array<double>((Array<long>*)in);
    } else {
        assert(out_dtype == Dtype::COMPLEX);
        if (in_dtype == Dtype::LONG) {
            return (PyObject*)promote_array<Complex>((Array<long>*)in);
        } else {
            assert(in_dtype == Dtype::DOUBLE);
            return (PyObject*)promote_array<Complex>((Array<double>*)in);
        }
    }
}
template <typename T>
Array<T> *Array<T>::make(int ndim, size_t size)
{
......@@ -924,31 +958,25 @@ PyTypeObject Array<T>::pytype = {
0, // tp_setattr
0, // tp_compare
(reprfunc)repr<T>, // tp_repr
0/*&as_number*/, // tp_as_number
&as_number, // tp_as_number
&as_sequence, // tp_as_sequence
&as_mapping, // tp_as_mapping
(hashfunc)hash<T>, // tp_hash
0, // tp_call
(reprfunc)str<T>, // tp_str
PyObject_GenericGetAttr, // tp_getattro
0, // tp_setattro
&as_buffer, // tp_as_buffer
Py_TPFLAGS_DEFAULT
| Py_TPFLAGS_HAVE_NEWBUFFER, // tp_flags
&as_buffer, // tp_as_buffer
Py_TPFLAGS_DEFAULT |
Py_TPFLAGS_HAVE_NEWBUFFER |
Py_TPFLAGS_CHECKTYPES, // tp_flags
doc, // tp_doc
0, // tp_traverse
0, // tp_clear
// richcompare, // tp_richcompare
0, // tp_richcompare
0, // tp_weaklistoffset
(getiterfunc)Array_iter<T>::make, // tp_iter
0, // tp_iternext
0, // tp_methods
0, // tp_members
......
......@@ -64,6 +64,7 @@ private:
static PySequenceMethods as_sequence;
static PyMappingMethods as_mapping;
static PyBufferProcs as_buffer;
static PyNumberMethods as_number;
static PyTypeObject pytype;