Commit 17395996 authored by Christoph Groth's avatar Christoph Groth
Browse files

implement unary arithmetic operations and expose them as functions

parent f8deb036
......@@ -384,6 +384,105 @@ end:
return result;
}
template <typename Op>
PyObject *apply_unary_ufunc(PyObject *a_)
{
typedef typename Op::Type T;
Op operation;
typedef decltype(operation(T())) R; // Return type
if (Op::error) {
PyErr_SetString(PyExc_TypeError, Op::error);
return 0;
}
assert(Array<T>::check_exact(a_)); Array<T> *a = (Array<T>*)a_;
int ndim;
size_t *shape;
a->ndim_shape(&ndim, &shape);
if (ndim == 0)
return (PyObject*)pyobject_from_number(operation(*a->data()));
if (Op::unchanged) {
Py_INCREF(a_);
return a_;
}
size_t size;
Array<R> *result = Array<R>::make(ndim, shape, &size);
if (result == 0) return 0;
T *src = a->data();
R *dest = result->data();
for (size_t i = 0; i < size; ++i) dest[i] = operation(src[i]);
return (PyObject*)result;
}
template <typename T>
struct Negative {
typedef T Type;
static constexpr const char *error = 0;
static const bool unchanged = false;
T operator()(T x) { return -x; }
};
template <typename T>
struct Positive {
typedef T Type;
static constexpr const char *error = 0;
static const bool unchanged = true;
T operator()(T x) { return x; }
};
template <typename T>
struct Absolute {
typedef T Type;
static constexpr const char *error = 0;
static const bool unchanged = false;
T operator()(T x) { return std::abs(x); }
};
template <>
struct Absolute<Complex> {
typedef Complex Type;
static constexpr const char *error = 0;
static const bool unchanged = false;
double operator()(Complex x) { return std::abs(x); }
};
// Integers are not changed by any kind of rounding.
template <typename Kind>
struct Round<Kind, long> {
typedef long Type;
static constexpr const char *error = 0;
static const bool unchanged = true;
long operator()(long x) { return x; }
};
template <typename Kind>
struct Round<Kind, double> {
typedef double Type;
static constexpr const char *error = 0;
static const bool unchanged = false;
double operator()(double x) {
Kind rounding_kind;
return rounding_kind(x);
}
};
template <typename Kind>
struct Round<Kind, Complex> {
typedef Complex Type;
static constexpr const char *error =
"Rounding is not defined for complex numbers.";
static const bool unchanged = false;
Complex operator()(Complex) { return 0.0/0.0; }
};
// The following types are to be used as Kind template parameter for Round.
struct Nearest { double operator()(double x) { return std::round(x); } };
struct Floor { double operator()(double x) { return std::floor(x); } };
struct Ceil { double operator()(double x) { return std::ceil(x); } };
template <typename T>
PyNumberMethods Array<T>::as_number = {
Binary_op<Add>::apply, // nb_add
......@@ -393,9 +492,9 @@ PyNumberMethods Array<T>::as_number = {
Binary_op<Remainder>::apply, // nb_remainder
(binaryfunc)0, // nb_divmod
(ternaryfunc)0, // nb_power
(unaryfunc)0, // nb_negative
(unaryfunc)0, // nb_positive
(unaryfunc)0, // nb_absolute
apply_unary_ufunc<Negative<T>>, // nb_negative
apply_unary_ufunc<Positive<T>>, // nb_positive
apply_unary_ufunc<Absolute<T>>, // nb_absolute
(inquiry)0, // nb_nonzero
(unaryfunc)0, // nb_invert
(binaryfunc)0, // nb_lshift
......@@ -434,3 +533,13 @@ PyNumberMethods Array<T>::as_number = {
template PyNumberMethods Array<long>::as_number;
template PyNumberMethods Array<double>::as_number;
template PyNumberMethods Array<Complex>::as_number;
template PyObject *apply_unary_ufunc<Round<Nearest, long>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Nearest, double>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Nearest, Complex>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Floor, long>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Floor, double>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Floor, Complex>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Ceil, long>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Ceil, double>>(PyObject*);
template PyObject *apply_unary_ufunc<Round<Ceil, Complex>>(PyObject*);
......@@ -22,6 +22,7 @@ private:
static Binary_ufunc *dtable[];
};
// Binary operations
template <typename T> struct Add;
template <typename T> struct Subtract;
template <typename T> struct Multiply;
......@@ -29,4 +30,19 @@ template <typename T> struct Divide;
template <typename T> struct Remainder;
template <typename T> struct Floor_divide;
template <typename Op> PyObject *apply_unary_ufunc(PyObject *a);
// Unaray operations
template <typename T> struct Negative;
template <typename T> struct Positive;
template <typename T> struct Absolute;
template <typename Kind, typename T> struct Round;
// Kinds of rounding, to be used with Round.
struct Nearest;
struct Floor;
struct Ceil;
#endif // !ARITHMETIC_HH
......@@ -114,26 +114,65 @@ PyObject *dot(PyObject *, PyObject *args)
}
template <template <typename> class Op>
PyObject *ufunc(PyObject *, PyObject *args)
PyObject *binary_ufunc(PyObject *, PyObject *args)
{
PyObject *a, *b;
if (!PyArg_ParseTuple(args, "OO", &a, &b)) return 0;
return Binary_op<Op>::apply(a, b);
}
template <template <typename> class Op>
PyObject *unary_ufunc(PyObject *, PyObject *args)
{
static_assert(int(Dtype::LONG) == 0 && int(Dtype::DOUBLE) == 1 &&
int(Dtype::COMPLEX) == 2 && int(Dtype::NONE) == 3,
"Update me.");
static PyObject *(*operation_dtable[])(PyObject*) = {
apply_unary_ufunc<Op<long> >,
apply_unary_ufunc<Op<double> >,
apply_unary_ufunc<Op<Complex> >
};
PyObject *a;
if (!PyArg_ParseTuple(args, "O", &a)) return 0;
Dtype dtype = get_dtype(a);
if (dtype != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype);
if (!a) return 0;
}
PyObject *result = operation_dtable[int(dtype)](a);
Py_DECREF(a);
return result;
}
} // Anonymous namespace
template <typename T> using Round_nearest = Round<Nearest, T>;
template <typename T> using Round_floor = Round<Floor, T>;
template <typename T> using Round_ceil = Round<Ceil, T>;
PyMethodDef functions[] = {
{"zeros", zeros, METH_VARARGS},
{"ones", ones, METH_VARARGS},
{"identity", identity, METH_VARARGS},
{"array", array, METH_VARARGS},
{"dot", dot, METH_VARARGS},
{"add", ufunc<Add>, METH_VARARGS},
{"subtract", ufunc<Subtract>, METH_VARARGS},
{"multiply", ufunc<Multiply>, METH_VARARGS},
{"divide", ufunc<Divide>, METH_VARARGS},
{"remainder", ufunc<Remainder>, METH_VARARGS},
{"floor_divide", ufunc<Remainder>, METH_VARARGS},
{"add", binary_ufunc<Add>, METH_VARARGS},
{"subtract", binary_ufunc<Subtract>, METH_VARARGS},
{"multiply", binary_ufunc<Multiply>, METH_VARARGS},
{"divide", binary_ufunc<Divide>, METH_VARARGS},
{"remainder", binary_ufunc<Remainder>, METH_VARARGS},
{"floor_divide", binary_ufunc<Floor_divide>, METH_VARARGS},
{"negative", unary_ufunc<Negative>, METH_VARARGS},
{"abs", unary_ufunc<Absolute>, METH_VARARGS},
{"absolute", unary_ufunc<Absolute>, METH_VARARGS},
{"round", unary_ufunc<Round_nearest>, METH_VARARGS},
{"floor", unary_ufunc<Round_floor>, METH_VARARGS},
{"ceil", unary_ufunc<Round_ceil>, METH_VARARGS},
{0, 0, 0, 0} // Sentinel
};
......@@ -219,3 +219,25 @@ def test_binary_ufuncs():
b = make(shape, dtype)
assert_equal(ta_func(a.tolist(), b.tolist()),
np_func(a, b))
def test_unary_operators():
ops = operator
for op in [ops.neg, ops.pos, ops.abs]:
for dtype in dtypes:
for shape in [(), 1, 3]:
a = make(shape, dtype)
assert_equal(op(ta.array(a.tolist())), op(a))
def test_unary_ufuncs():
for name in ["negative", "abs", "absolute", "round", "floor", "ceil"]:
np_func = np.__dict__[name]
ta_func = ta.__dict__[name]
for dtype in dtypes:
for shape in [(), 1, 3]:
a = make(shape, dtype)
if dtype is complex and name in ["round", "floor", "ceil"]:
assert_raises(TypeError, ta_func, a.tolist())
else:
assert_equal(ta_func(a.tolist()), np_func(a))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment