Commit b491f572 authored by Christoph Groth's avatar Christoph Groth
Browse files

factor-out coercion, use it in dot

parent c14515e4
......@@ -124,50 +124,17 @@ typedef PyObject *Binary_ufunc(int, const size_t*,
PyObject *apply_binary_ufunc(Binary_ufunc **ufunc_dtable,
PyObject *a, PyObject *b)
{
Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b), dtype;
// Make sure a and b are tinyarrays.
if (dtype_a != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype_a);
if (!a) return 0;
}
if (dtype_b != Dtype::NONE) {
Py_INCREF(b);
} else {
b = array_from_arraylike(b, &dtype_b, dtype_a);
if (!b) {
Py_DECREF(a);
return 0;
}
}
PyObject *result = 0;
// Promote to a common dtype.
dtype = Dtype(std::max(int(dtype_a), int(dtype_b)));
if (dtype_a != dtype) {
PyObject *temp = promote_array(dtype, a, dtype_a);
if (temp == 0) goto end;
Py_DECREF(a);
a = temp;
} else if (dtype_b != dtype) {
PyObject *temp = promote_array(dtype, b, dtype_b);
if (temp == 0) goto end;
Py_DECREF(b);
b = temp;
}
Dtype dtype;
if (coerce_to_arrays(&a, &b, &dtype) < 0) return 0;
int ndim_a, ndim_b;
size_t *shape_a, *shape_b;
reinterpret_cast<Array_base*>(a)->ndim_shape(&ndim_a, &shape_a);
reinterpret_cast<Array_base*>(b)->ndim_shape(&ndim_b, &shape_b);
int ndim;
size_t stride_a, stride_b, shape[max_ndim];;
ndim = std::max(ndim_a, ndim_b);
stride_a = stride_b = 1;
PyObject *result = 0;
int ndim = std::max(ndim_a, ndim_b);
size_t stride_a = 1, stride_b = 1, shape[max_ndim];;
ptrdiff_t hops_a[max_ndim], hops_b[max_ndim];
for (int d = ndim - 1, d_a = ndim_a - 1, d_b = ndim_b - 1;
d >= 0; --d, --d_a, --d_b) {
......@@ -409,24 +376,8 @@ bool Divide<long>::operator()(long &result, long x, long y)
PyObject *dot_product(PyObject *a, PyObject *b)
{
Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b);
// Make sure a and b are tinyarrays.
if (dtype_a != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype_a);
if (!a) return 0;
}
if (dtype_b != Dtype::NONE) {
Py_INCREF(b);
} else {
b = array_from_arraylike(b, &dtype_b);
if (!b) {
Py_DECREF(a);
return 0;
}
}
Dtype dtype;
if (coerce_to_arrays(&a, &b, &dtype) < 0) return 0;
PyObject *result = 0;
int ndim_a, ndim_b;
......@@ -437,16 +388,11 @@ PyObject *dot_product(PyObject *a, PyObject *b)
"dot does not support zero-dimensional arrays yet.");
goto end;
}
if (dtype_a != dtype_b) {
PyErr_SetString(PyExc_ValueError,
"Dtype must be the same for now.");
goto end;
}
if (ndim_a == 1 && ndim_b == 1)
result = array_scalar_product_dtable[int(dtype_a)](a, b);
result = array_scalar_product_dtable[int(dtype)](a, b);
else
result = array_matrix_product_dtable[int(dtype_a)](a, b);
result = array_matrix_product_dtable[int(dtype)](a, b);
end:
Py_DECREF(a);
......
......@@ -40,6 +40,39 @@ Dtype dtype_of_scalar(PyObject *obj)
return Dtype::NONE;
}
template<typename O, typename I>
Array<O> *promote_array(Array<I> *in)
{
int ndim;
size_t *shape, size;
in->ndim_shape(&ndim, &shape);
Array<O> *out = Array<O>::make(ndim, shape, &size);
if (!out) return 0;
I *src = in->data();
O *dest = out->data();
for (size_t i = 0; i < size; ++i) dest[i] = src[i];
return out;
}
PyObject *promote_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
{
if (in_dtype == Dtype::NONE)
in_dtype = get_dtype(in);
assert(get_dtype(in) == get_dtype(in));
if (out_dtype == Dtype::DOUBLE) {
assert(in_dtype == Dtype::LONG);
return (PyObject*)promote_array<double>((Array<long>*)in);
} else {
assert(out_dtype == Dtype::COMPLEX);
if (in_dtype == Dtype::LONG) {
return (PyObject*)promote_array<Complex>((Array<long>*)in);
} else {
assert(in_dtype == Dtype::DOUBLE);
return (PyObject*)promote_array<Complex>((Array<double>*)in);
}
}
}
const char *seq_err_msg =
"A sequence does not support sequence protocol - "
"this is probably due to a bug in numpy for 0-d arrays.";
......@@ -837,37 +870,50 @@ PyObject *array_from_arraylike(PyObject *src, Dtype *dtype, Dtype min_dtype)
}
}
template<typename O, typename I>
Array<O> *promote_array(Array<I> *in)
int coerce_to_arrays(PyObject **a_, PyObject **b_, Dtype *coerced_dtype)
{
int ndim;
size_t *shape, size;
in->ndim_shape(&ndim, &shape);
Array<O> *out = Array<O>::make(ndim, shape, &size);
if (!out) return 0;
I *src = in->data();
O *dest = out->data();
for (size_t i = 0; i < size; ++i) dest[i] = src[i];
return out;
}
PyObject *a = *a_, *b = *b_;
Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b), dtype;
PyObject *promote_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
{
if (in_dtype == Dtype::NONE)
in_dtype = get_dtype(in);
assert(get_dtype(in) == get_dtype(in));
if (out_dtype == Dtype::DOUBLE) {
assert(in_dtype == Dtype::LONG);
return (PyObject*)promote_array<double>((Array<long>*)in);
// Make sure a and b are tinyarrays.
if (dtype_a != Dtype::NONE) {
Py_INCREF(a);
} else {
assert(out_dtype == Dtype::COMPLEX);
if (in_dtype == Dtype::LONG) {
return (PyObject*)promote_array<Complex>((Array<long>*)in);
} else {
assert(in_dtype == Dtype::DOUBLE);
return (PyObject*)promote_array<Complex>((Array<double>*)in);
a = array_from_arraylike(a, &dtype_a);
if (!a) return -1;
}
if (dtype_b != Dtype::NONE) {
Py_INCREF(b);
} else {
b = array_from_arraylike(b, &dtype_b, dtype_a);
if (!b) {
Py_DECREF(a);
return -1;
}
}
// Promote to a common dtype.
dtype = Dtype(std::max(int(dtype_a), int(dtype_b)));
if (dtype_a != dtype) {
PyObject *temp = promote_array(dtype, a, dtype_a);
if (temp == 0) goto fail;
Py_DECREF(a);
a = temp;
} else if (dtype_b != dtype) {
PyObject *temp = promote_array(dtype, b, dtype_b);
if (temp == 0) goto fail;
Py_DECREF(b);
b = temp;
}
// Success
*a_ = a; *b_ = b; *coerced_dtype = dtype;
return 0;
fail:
Py_DECREF(a);
Py_DECREF(b);
return -1;
}
template <typename T>
......
......@@ -95,10 +95,7 @@ inline Dtype get_dtype(PyObject *obj)
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype,
Dtype min_dtype = Dtype(0));
template<typename O, typename I>
Array<O> *promote_array(Array<I> *in);
PyObject *promote_array(Dtype out_dtype, PyObject *in,
Dtype in_dtype = Dtype::NONE);
// Coerced_dtype will contain the common dtype of the coerced arrays.
int coerce_to_arrays(PyObject **a, PyObject **b, Dtype *coerced_dtype);
#endif // !ARRAY_HH
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment