Commit f8deb036 authored by Christoph Groth's avatar Christoph Groth
Browse files

expose binary arithmetic operations as functions

parent 92085a65
#include <Python.h>
#include <limits>
#include <cmath>
#include <cstddef>
#include <sstream>
#include <functional>
#include <algorithm>
......@@ -117,10 +116,6 @@ PyObject *array_matrix_product(PyObject *a_, PyObject *b_)
PyObject *(*array_matrix_product_dtable[])(PyObject*, PyObject*) =
DTYPE_DISPATCH(array_matrix_product);
typedef PyObject *Binary_ufunc(int, const size_t*,
PyObject*, const ptrdiff_t*,
PyObject*, const ptrdiff_t*);
PyObject *apply_binary_ufunc(Binary_ufunc **ufunc_dtable,
PyObject *a, PyObject *b)
{
......@@ -186,16 +181,7 @@ end:
return result;
}
template <template <typename> class Op>
struct Binary_op {
template <typename T>
static PyObject *ufunc(int ndim, const size_t *shape,
PyObject *a_, const ptrdiff_t *hops_a,
PyObject *b_, const ptrdiff_t *hops_b);
static PyObject *apply(PyObject *a, PyObject *b);
static Binary_ufunc *dtable[];
};
} // Anonymous namespace
template <template <typename> class Op>
template <typename T>
......@@ -372,8 +358,6 @@ bool Divide<long>::operator()(long &result, long x, long y)
return floor_divide(result, x, y);
}
} // Anonymous namespace
PyObject *dot_product(PyObject *a, PyObject *b)
{
Dtype dtype;
......
#ifndef ARITHMETIC_HH
#define ARITHMETIC_HH
#include <cstddef>
PyObject *dot_product(PyObject *a, PyObject *b);
typedef PyObject *Binary_ufunc(int, const size_t*,
PyObject*, const ptrdiff_t*,
PyObject*, const ptrdiff_t*);
template <template <typename> class Op>
class Binary_op {
public:
static PyObject *apply(PyObject *a, PyObject *b);
private:
template <typename T>
static PyObject *ufunc(int ndim, const size_t *shape,
PyObject *a_, const ptrdiff_t *hops_a,
PyObject *b_, const ptrdiff_t *hops_b);
static Binary_ufunc *dtable[];
};
template <typename T> struct Add;
template <typename T> struct Subtract;
template <typename T> struct Multiply;
template <typename T> struct Divide;
template <typename T> struct Remainder;
template <typename T> struct Floor_divide;
#endif // !ARITHMETIC_HH
......@@ -109,11 +109,18 @@ PyObject *array(PyObject *, PyObject *args)
PyObject *dot(PyObject *, PyObject *args)
{
PyObject *a, *b;
if (!PyArg_ParseTuple(args, "OO", &a, &b))
return 0;
if (!PyArg_ParseTuple(args, "OO", &a, &b)) return 0;
return dot_product(a, b);
}
template <template <typename> class Op>
PyObject *ufunc(PyObject *, PyObject *args)
{
PyObject *a, *b;
if (!PyArg_ParseTuple(args, "OO", &a, &b)) return 0;
return Binary_op<Op>::apply(a, b);
}
} // Anonymous namespace
PyMethodDef functions[] = {
......@@ -122,5 +129,11 @@ PyMethodDef functions[] = {
{"identity", identity, METH_VARARGS},
{"array", array, METH_VARARGS},
{"dot", dot, METH_VARARGS},
{"add", ufunc<Add>, METH_VARARGS},
{"subtract", ufunc<Subtract>, METH_VARARGS},
{"multiply", ufunc<Multiply>, METH_VARARGS},
{"divide", ufunc<Divide>, METH_VARARGS},
{"remainder", ufunc<Remainder>, METH_VARARGS},
{"floor_divide", ufunc<Remainder>, METH_VARARGS},
{0, 0, 0, 0} // Sentinel
};
......@@ -200,3 +200,22 @@ def test_binary_operators():
assert_equal(
op(ta.array(a.tolist()), ta.array(b.tolist())),
op(a, b))
def test_binary_ufuncs():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
for name in ["add", "subtract", "multiply", "divide",
"remainder", "floor_divide"]:
np_func = np.__dict__[name]
ta_func = ta.__dict__[name]
for dtype in dtypes:
for shape in [(), 1, 3]:
if dtype is complex and \
name in ["remainder", "floor_divide"]:
continue
a = make(shape, dtype)
b = make(shape, dtype)
assert_equal(ta_func(a.tolist(), b.tolist()),
np_func(a, b))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment