Commit a455ba70 authored by Christoph Groth's avatar Christoph Groth
Browse files

implement transpose

parent 3171f49f
......@@ -1001,6 +1001,52 @@ fail:
return -1;
}
template <typename T>
PyObject *transpose(PyObject *in_)
{
assert(Array<T>::check_exact(in_)); Array<T> *in = (Array<T>*)in_;
int ndim;
ptrdiff_t hops[max_ndim];
size_t *shape_in, shape_out[max_ndim], stride = 1;
in->ndim_shape(&ndim, &shape_in);
if (ndim == 0) {
Py_INCREF(in_);
return in_;
}
for (int id = ndim - 1, od = 0; id >= 0; --id, ++od) {
size_t ext = shape_in[id];
shape_out[od] = ext;
hops[od] = stride;
stride *= ext;
}
for (int d = 1; d < ndim; ++d) hops[d - 1] -= hops[d] * shape_out[d];
Array<T> *out = Array<T>::make(ndim, shape_out);
if (!out) return 0;
T *src = in->data(), *dest = out->data();
int d = 0;
size_t i[max_ndim];
--ndim;
i[0] = shape_out[0];
while (true) {
if (i[d]) {
--i[d];
if (d == ndim) {
*dest++ = *src;
src += hops[d];
} else {
++d;
i[d] = shape_out[d];
}
} else {
if (d == 0) return (PyObject*)out;
--d;
src += hops[d];
}
}
}
template <typename T>
Array<T> *Array<T>::make(int ndim, size_t size)
{
......@@ -1077,6 +1123,12 @@ PyBufferProcs Array<T>::as_buffer = {
(getbufferproc)getbuffer<T> // bf_getbuffer
};
template <typename T>
PyMethodDef Array<T>::methods[] = {
{"transpose", (PyCFunction)transpose<T>, METH_NOARGS},
{0, 0} // Sentinel
};
template <typename T>
PyTypeObject Array<T>::pytype = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
......@@ -1108,7 +1160,7 @@ PyTypeObject Array<T>::pytype = {
0, // tp_weaklistoffset
(getiterfunc)Array_iter<T>::make, // tp_iter
0, // tp_iternext
0, // tp_methods
methods, // tp_methods
0, // tp_members
getset, // tp_getset
0, // tp_base
......@@ -1126,3 +1178,7 @@ PyTypeObject Array<T>::pytype = {
template class Array<long>;
template class Array<double>;
template class Array<Complex>;
template PyObject *transpose<long>(PyObject*);
template PyObject *transpose<double>(PyObject*);
template PyObject *transpose<Complex>(PyObject*);
......@@ -61,6 +61,7 @@ public:
private:
T ob_item[1];
static PyMethodDef methods[];
static PySequenceMethods as_sequence;
static PyMappingMethods as_mapping;
static PyBufferProcs as_buffer;
......@@ -98,4 +99,6 @@ PyObject *array_from_arraylike(PyObject *src, Dtype *dtype,
// Coerced_dtype will contain the common dtype of the coerced arrays.
int coerce_to_arrays(PyObject **a, PyObject **b, Dtype *coerced_dtype);
template <typename T> PyObject *transpose(PyObject *in);
#endif // !ARRAY_HH
......@@ -106,6 +106,18 @@ PyObject *array(PyObject *, PyObject *args)
return array_from_arraylike(src, &dtype);
}
PyObject *(*transpose_dtable[])(PyObject*) = DTYPE_DISPATCH(transpose);
PyObject *transpose(PyObject *, PyObject *args)
{
PyObject *a;
if (!PyArg_ParseTuple(args, "O", &a)) return 0;
Dtype dtype = Dtype::NONE;
a = array_from_arraylike(a, &dtype);
if (!a) return 0;
return transpose_dtable[int(dtype)](a);
}
PyObject *dot(PyObject *, PyObject *args)
{
PyObject *a, *b;
......@@ -154,6 +166,7 @@ PyMethodDef functions[] = {
{"ones", ones, METH_VARARGS},
{"identity", identity, METH_VARARGS},
{"array", array, METH_VARARGS},
{"transpose", transpose, METH_VARARGS},
{"dot", dot, METH_VARARGS},
{"add", binary_ufunc<Add>, METH_VARARGS},
......
......@@ -53,6 +53,7 @@ def test_array():
assert_raises(TypeError, len, b)
assert_equal(memoryview(b).tobytes(), memoryview(a).tobytes())
assert_equal(np.array(b), np.array(l))
assert_equal(ta.transpose(l), np.transpose(l))
# Here, the tinyarray is created via the buffer interface. It's
# possible to distinguish shape 0 from (0, 0).
......@@ -207,7 +208,7 @@ def test_binary_operators():
for op in [ops.add, ops.sub, ops.mul, ops.div, ops.mod, ops.floordiv]:
for dtype in dtypes:
for shape in [(), 1, 3]:
for shape in [(), 1, 3, (3, 2)]:
if dtype is complex and op in [ops.mod, ops.floordiv]:
continue
a = make(shape, dtype)
......@@ -226,7 +227,7 @@ def test_binary_ufuncs():
np_func = np.__dict__[name]
ta_func = ta.__dict__[name]
for dtype in dtypes:
for shape in [(), 1, 3]:
for shape in [(), 1, 3, (3, 2)]:
if dtype is complex and \
name in ["remainder", "floor_divide"]:
continue
......@@ -240,7 +241,7 @@ def test_unary_operators():
ops = operator
for op in [ops.neg, ops.pos, ops.abs]:
for dtype in dtypes:
for shape in [(), 1, 3]:
for shape in [(), 1, 3, (3, 2)]:
a = make(shape, dtype)
assert_equal(op(ta.array(a.tolist())), op(a))
......@@ -250,7 +251,7 @@ def test_unary_ufuncs():
np_func = np.__dict__[name]
ta_func = ta.__dict__[name]
for dtype in dtypes:
for shape in [(), 1, 3]:
for shape in [(), 1, 3, (3, 2)]:
a = make(shape, dtype)
if dtype is complex and name in ["round", "floor", "ceil"]:
assert_raises(TypeError, ta_func, a.tolist())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment