Commit 1b14b07d authored by Christoph Groth's avatar Christoph Groth
Browse files

add matrix constructor

parent c3d309cd
* Fix this: ta.array(np.matrix([1,2,3]))
* Implement missing arithmetic operations.
* Implement missing comparisons.
......
......@@ -14,6 +14,8 @@ const char *Array<double>::pyformat = "d";
template <>
const char *Array<Complex>::pyformat = "Zd";
const char *dtype_names[] = {"int", "float", "complex"};
namespace {
PyObject *int_str, *long_str, *float_str, *complex_str, *index_str;
......@@ -42,47 +44,57 @@ Dtype dtype_of_scalar(PyObject *obj)
}
template<typename O, typename I>
Array<O> *convert_array(Array<I> *in)
PyObject *convert_array(PyObject *in_, int ndim, size_t *shape)
{
int ndim;
size_t *shape, size;
in->ndim_shape(&ndim, &shape);
assert(Array<I>::check_exact(in_)); Array<I> *in = (Array<I>*)in_;
size_t size;
if (ndim == -1) {
assert(shape == 0);
in->ndim_shape(&ndim, &shape);
} else {
#ifndef NDEBUG
int in_ndim;
size_t *in_shape;
in->ndim_shape(&in_ndim, &in_shape);
assert(shape);
assert(calc_size(ndim, shape) == calc_size(in_ndim, in_shape));
#endif
}
Array<O> *out = Array<O>::make(ndim, shape, &size);
if (!out) return 0;
I *src = in->data();
O *dest = out->data();
for (size_t i = 0; i < size; ++i) dest[i] = src[i];
return out;
return (PyObject*)out;
}
PyObject *convert_array(Dtype out_dtype, PyObject *in, Dtype in_dtype)
typedef PyObject *Convert_array(PyObject*, int, size_t*);
Convert_array *convert_array_dtable[][3] = {
{convert_array<long, long>,
convert_array<long, double>,
0},
{convert_array<double, long>,
convert_array<double, double>,
0},
{convert_array<Complex, long>,
convert_array<Complex, double>,
convert_array<Complex, Complex>}
};
PyObject *convert_array(Dtype dtype_out, PyObject *in, Dtype dtype_in,
int ndim = -1, size_t *shape = 0)
{
if (in_dtype == Dtype::NONE)
in_dtype = get_dtype(in);
if (dtype_in == Dtype::NONE)
dtype_in = get_dtype(in);
assert(get_dtype(in) == get_dtype(in));
assert(in_dtype != out_dtype);
if (out_dtype == Dtype::LONG) {
if (in_dtype == Dtype::COMPLEX) {
PyErr_SetString(PyExc_TypeError, "Cannot convert complex to int.");
return 0;
}
return (PyObject*)convert_array<long>((Array<double>*)in);
} else if (out_dtype == Dtype::DOUBLE) {
if (in_dtype == Dtype::COMPLEX) {
PyErr_SetString(PyExc_TypeError,
"Cannot convert complex to float.");
return 0;
}
return (PyObject*)convert_array<double>((Array<long>*)in);
} else {
assert(out_dtype == Dtype::COMPLEX);
if (in_dtype == Dtype::LONG) {
return (PyObject*)convert_array<Complex>((Array<long>*)in);
} else {
assert(in_dtype == Dtype::DOUBLE);
return (PyObject*)convert_array<Complex>((Array<double>*)in);
}
Convert_array *func = convert_array_dtable[int(dtype_out)][int(dtype_in)];
if (!func) {
PyErr_Format(PyExc_TypeError, "Cannot convert %s to %s.",
dtype_names[int(dtype_in)], dtype_names[int(dtype_out)]);
return 0;
}
return func(in, ndim, shape);
}
const char *seq_err_msg =
......@@ -247,13 +259,19 @@ fail:
}
template <typename T>
PyObject *make_and_readin_array(PyObject *src, int ndim,
const size_t *shape, PyObject **seqs,
PyObject *make_and_readin_array(PyObject *in, int ndim_in, int ndim_out,
const size_t *shape_out, PyObject **seqs,
bool exact)
{
Array<T> *result = Array<T>::make(ndim, shape);
Array<T> *result = Array<T>::make(ndim_out, shape_out);
assert(ndim_out >= ndim_in);
#ifndef NDEBUG
for (int d = 0, e = ndim_out - ndim_in; d < e; ++d)
assert(shape_out[d] == 1);
#endif
if (result == 0) return 0;
if (readin_arraylike<T>(result->data(), ndim, shape, src, seqs, exact)
if (readin_arraylike<T>(result->data(), ndim_in,
shape_out + ndim_out - ndim_in, in, seqs, exact)
== -1) {
Py_DECREF(result);
return 0;
......@@ -262,7 +280,7 @@ PyObject *make_and_readin_array(PyObject *src, int ndim,
}
PyObject *(*make_and_readin_array_dtable[])(
PyObject*, int, const size_t*, PyObject**, bool) =
PyObject*, int, int, const size_t*, PyObject**, bool) =
DTYPE_DISPATCH(make_and_readin_array);
template <typename T>
......@@ -907,56 +925,130 @@ Py_ssize_t load_index_seq_as_ulong(PyObject *obj, unsigned long *uout,
return len;
}
// If *dtype == Dtype::NONE the simplest fitting dtype (at least min_dtype)
// If *dtype == Dtype::NONE the simplest fitting dtype (at least dtype_min)
// will be used and written back to *dtype. Any other value of *dtype requests
// an array of the given dtype.
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype, Dtype min_dtype)
PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
{
Dtype src_dtype = get_dtype(src), dt = *dtype;
Dtype dtype_in = get_dtype(in), dt = *dtype;
int ndim;
size_t shape[max_ndim];
PyObject *seqs[max_ndim], *result;
if (src_dtype != Dtype::NONE) {
// src is already an array
if (dtype_in != Dtype::NONE) {
// in is already an array
if (dt == Dtype::NONE)
dt = Dtype(std::max(int(src_dtype), int(min_dtype)));
if (dt == src_dtype)
Py_INCREF(result = src);
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
if (dt == dtype_in)
Py_INCREF(result = in);
else
result = convert_array(dt, src, src_dtype);
result = convert_array(dt, in, dtype_in);
} else if (dt == Dtype::NONE) {
// No specific dtype has been requested. It will be determined by the
// input.
PyObject *seqs_copy[max_ndim];
if (examine_arraylike(src, &ndim, shape, seqs, &dt) == -1)
if (examine_arraylike(in, &ndim, shape, seqs, &dt) == -1)
return 0;
for (int d = 0; d < ndim; ++d) Py_INCREF(seqs_copy[d] = seqs[d]);
if (dt == Dtype::NONE) {
assert(shape[ndim - 1] == 0);
dt = default_dtype;
}
if (int(dt) < int(min_dtype)) dt = min_dtype;
if (int(dt) < int(dtype_min)) dt = dtype_min;
while (true) {
result = make_and_readin_array_dtable[int(dt)](
src, ndim, shape, seqs, true);
in, ndim, ndim, shape, seqs, true);
if (result) break;
PyErr_Clear();
dt = Dtype(int(dt) + 1);
if (dt == Dtype::NONE) {
PyErr_SetString(PyExc_TypeError, "Expecting a number.");
result = 0;
break;
}
PyErr_Clear();
for (int d = 0; d < ndim; ++d) Py_INCREF(seqs[d] = seqs_copy[d]);
}
for (int d = 0; d < ndim; ++d) Py_DECREF(seqs_copy[d]);
} else {
// A specific dtype has been requested.
if (examine_arraylike(src, &ndim, shape, seqs, 0) == -1)
result = 0;
if (examine_arraylike(in, &ndim, shape, seqs, 0) == -1)
return 0;
else
result = make_and_readin_array_dtable[int(dt)](
src, ndim, shape, seqs, false);
in, ndim, ndim, shape, seqs, false);
}
*dtype = dt;
return result;
}
// If *dtype == Dtype::NONE the simplest fitting dtype (at least dtype_min)
// will be used and written back to *dtype. Any other value of *dtype requests
// an array of the given dtype.
PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
{
Dtype dtype_in = get_dtype(in), dt = *dtype;
int ndim;
size_t shape[max_ndim];
PyObject *seqs[max_ndim], *result;
if (dtype_in != Dtype::NONE) {
// `in` is already an array.
if (dt == Dtype::NONE)
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
size_t *in_shape;
reinterpret_cast<Array_base*>(in)->ndim_shape(&ndim, &in_shape);
if (ndim == 2) {
if (dt == dtype_in)
Py_INCREF(result = in);
else
result = convert_array(dt, in, dtype_in);
} else if (ndim < 2) {
shape[0] = 1;
shape[1] = (ndim == 0) ? 1 : in_shape[0];
result = convert_array(dt, in, dtype_in, 2, shape);
} else {
PyErr_SetString(PyExc_ValueError, "Matrix must be 2-dimensional.");
result = 0;
}
} else {
// `in` is not an array.
if (examine_arraylike(in, &ndim, shape, seqs,
dt == Dtype::NONE ? &dt : 0) == -1) return 0;
if (ndim != 2) {
if (ndim > 2) {
PyErr_SetString(PyExc_ValueError,
"Matrix must be 2-dimensional.");
return 0;
}
shape[1] = (ndim == 0) ? 1 : shape[0];
shape[0] = 1;
}
if (*dtype == Dtype::NONE) {
// No specific dtype has been requested. It will be determined by
// the input.
PyObject *seqs_copy[max_ndim];
for (int d = 0; d < ndim; ++d) Py_INCREF(seqs_copy[d] = seqs[d]);
if (dt == Dtype::NONE) {
assert(shape[1] == 0);
dt = default_dtype;
}
if (int(dt) < int(dtype_min)) dt = dtype_min;
while (true) {
result = make_and_readin_array_dtable[int(dt)](
in, ndim, 2, shape, seqs, true);
if (result) break;
dt = Dtype(int(dt) + 1);
if (dt == Dtype::NONE) {
result = 0;
break;
}
PyErr_Clear();
for (int d = 0; d < ndim; ++d)
Py_INCREF(seqs[d] = seqs_copy[d]);
}
for (int d = 0; d < ndim; ++d) Py_DECREF(seqs_copy[d]);
} else {
// A specific dtype has been requested.
result = make_and_readin_array_dtable[int(dt)](
in, ndim, 2, shape, seqs, false);
}
}
*dtype = dt;
return result;
......
......@@ -9,6 +9,8 @@ const int max_ndim = 16;
enum class Dtype : char {LONG = 0, DOUBLE, COMPLEX, NONE};
const Dtype default_dtype = Dtype::DOUBLE;
extern const char *dtype_names[];
#define DTYPE_DISPATCH(func) {func<long>, func<double>, func<Complex>}
// We use the ob_size field in a clever way to encode either the length of a
......@@ -93,8 +95,10 @@ inline Dtype get_dtype(PyObject *obj)
return Dtype::NONE;
}
PyObject *array_from_arraylike(PyObject *src, Dtype *dtype,
Dtype min_dtype = Dtype(0));
PyObject *array_from_arraylike(PyObject *in, Dtype *dtype,
Dtype dtype_min = Dtype(0));
PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype,
Dtype dtype_min = Dtype(0));
// Coerced_dtype will contain the common dtype of the coerced arrays.
int coerce_to_arrays(PyObject **a, PyObject **b, Dtype *coerced_dtype);
......
......@@ -106,6 +106,15 @@ PyObject *array(PyObject *, PyObject *args)
return array_from_arraylike(src, &dtype);
}
PyObject *matrix(PyObject *, PyObject *args)
{
PyObject *src;
Dtype dtype = Dtype::NONE;
if (!PyArg_ParseTuple(args, "O|O&", &src, dtype_converter, &dtype))
return 0;
return matrix_from_arraylike(src, &dtype);
}
PyObject *(*transpose_dtable[])(PyObject*) = DTYPE_DISPATCH(transpose);
PyObject *transpose(PyObject *, PyObject *args)
......@@ -191,6 +200,7 @@ PyMethodDef functions[] = {
{"ones", ones, METH_VARARGS},
{"identity", identity, METH_VARARGS},
{"array", array, METH_VARARGS},
{"matrix", matrix, METH_VARARGS},
{"transpose", transpose, METH_VARARGS},
{"dot", dot, METH_VARARGS},
......
......@@ -72,6 +72,18 @@ def test_array():
assert_raises(ValueError, ta.array, [[0, 0], [[0], [0]]], dtype)
def test_matrix():
for l in [(), 3, (3,), ((3,)), (1, 2), ((1, 2), (3, 4))]:
a = ta.matrix(l)
b = np.matrix(l)
assert_equal(a, b)
a = ta.matrix(ta.array(l))
assert_equal(a, b)
for l in [(((),),), ((3,), ()), ((1, 2), (3,))]:
assert_raises(ValueError, ta.matrix, l)
def test_conversion():
for src_dtype in dtypes:
for dest_dtype in dtypes:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment