Commit 3171f49f authored by Christoph Groth's avatar Christoph Groth
Browse files

optimize creation of tinyarrays from tinyarrays

parent 17395996
......@@ -41,7 +41,7 @@ Dtype dtype_of_scalar(PyObject *obj)
}
template<typename O, typename I>
Array<O> *promote_array(Array<I> *in)
Array<O> *convert_array(Array<I> *in)
{
int ndim;
size_t *shape, size;
......@@ -54,21 +54,32 @@ Array<O> *promote_array(Array<I> *in)
return out;
}
PyObject *promote_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
PyObject *convert_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
{
if (in_dtype == Dtype::NONE)
in_dtype = get_dtype(in);
assert(get_dtype(in) == get_dtype(in));
if (out_dtype == Dtype::DOUBLE) {
assert(in_dtype == Dtype::LONG);
return (PyObject*)promote_array<double>((Array<long>*)in);
assert(in_dtype != out_dtype);
if (out_dtype == Dtype::LONG) {
if (in_dtype == Dtype::COMPLEX) {
PyErr_SetString(PyExc_TypeError, "Cannot convert complex to int.");
return 0;
}
return (PyObject*)convert_array<long>((Array<double>*)in);
} else if (out_dtype == Dtype::DOUBLE) {
if (in_dtype == Dtype::COMPLEX) {
PyErr_SetString(PyExc_TypeError,
"Cannot convert complex to float.");
return 0;
}
return (PyObject*)convert_array<double>((Array<long>*)in);
} else {
assert(out_dtype == Dtype::COMPLEX);
if (in_dtype == Dtype::LONG) {
return (PyObject*)promote_array<Complex>((Array<long>*)in);
return (PyObject*)convert_array<Complex>((Array<long>*)in);
} else {
assert(in_dtype == Dtype::DOUBLE);
return (PyObject*)promote_array<Complex>((Array<double>*)in);
return (PyObject*)convert_array<Complex>((Array<double>*)in);
}
}
}
......@@ -900,6 +911,18 @@ Py_ssize_t load_index_seq_as_ulong(PyObject *obj, unsigned long *uout,
// an array of the given dtype.
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype, Dtype min_dtype)
{
Dtype src_dtype = get_dtype(src);
if (src_dtype != Dtype::NONE) {
// src is already an array
if (*dtype == Dtype::NONE)
*dtype = Dtype(std::max(int(src_dtype), int(min_dtype)));
if (*dtype == src_dtype) {
Py_INCREF(src);
return src;
}
return convert_array(*dtype, src, src_dtype);
}
int ndim;
size_t shape[max_ndim];
PyObject *seqs[max_ndim];
......@@ -943,34 +966,26 @@ PyObject *array_from_arraylike(PyObject *src, Dtype *dtype, Dtype min_dtype)
int coerce_to_arrays(PyObject **a_, PyObject **b_, Dtype *coerced_dtype)
{
PyObject *a = *a_, *b = *b_;
Dtype dtype_a = get_dtype(a), dtype_b = get_dtype(b), dtype;
// Make sure a and b are tinyarrays.
if (dtype_a != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype_a);
if (!a) return -1;
}
if (dtype_b != Dtype::NONE) {
Py_INCREF(b);
} else {
b = array_from_arraylike(b, &dtype_b, dtype_a);
if (!b) {
Py_DECREF(a);
return -1;
}
Dtype dtype_a = Dtype::NONE, dtype_b = Dtype::NONE, dtype;
a = array_from_arraylike(a, &dtype_a);
if (!a) return -1;
b = array_from_arraylike(b, &dtype_b, dtype_a);
if (!b) {
Py_DECREF(a);
return -1;
}
// Promote to a common dtype.
dtype = Dtype(std::max(int(dtype_a), int(dtype_b)));
if (dtype_a != dtype) {
PyObject *temp = promote_array(dtype, a, dtype_a);
PyObject *temp = convert_array(dtype, a, dtype_a);
if (temp == 0) goto fail;
Py_DECREF(a);
a = temp;
} else if (dtype_b != dtype) {
PyObject *temp = promote_array(dtype, b, dtype_b);
PyObject *temp = convert_array(dtype, b, dtype_b);
if (temp == 0) goto fail;
Py_DECREF(b);
b = temp;
......
......@@ -135,13 +135,9 @@ PyObject *unary_ufunc(PyObject *, PyObject *args)
PyObject *a;
if (!PyArg_ParseTuple(args, "O", &a)) return 0;
Dtype dtype = get_dtype(a);
if (dtype != Dtype::NONE) {
Py_INCREF(a);
} else {
a = array_from_arraylike(a, &dtype);
if (!a) return 0;
}
Dtype dtype = Dtype::NONE;
a = array_from_arraylike(a, &dtype);
if (!a) return 0;
PyObject *result = operation_dtable[int(dtype)](a);
Py_DECREF(a);
return result;
......
......@@ -71,6 +71,21 @@ def test_array():
assert_raises(ValueError, ta.array, [[0, 0], [[0], [0]]], dtype)
def test_conversion():
for src_dtype in dtypes:
for dest_dtype in dtypes:
src = ta.zeros(3, src_dtype)
tsrc = tuple(src)
impossible = src_dtype is complex and dest_dtype in [int, float]
for s in [src, tsrc]:
if impossible:
assert_raises(TypeError, ta.array, s, dest_dtype)
else:
dest = ta.array(s, dest_dtype)
assert isinstance(dest[0], dest_dtype)
assert_equal(src, dest)
def test_special_constructors():
for dtype in dtypes:
for shape in some_shapes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment