Commit eddb7e15 authored by Christoph Groth's avatar Christoph Groth
Browse files

also fix matrix constructor (by merging it with array constructor)

parent 4272ad13
......@@ -348,7 +348,7 @@ PyObject *(*make_and_readin_array_dtable[])(
DTYPE_DISPATCH(make_and_readin_array);
template <typename T>
PyObject *make_and_readin_scalar(PyObject *in, bool exact)
PyObject *make_and_readin_scalar(PyObject *in, bool exact, int ndim = 0)
{
T value;
if (exact)
......@@ -358,13 +358,17 @@ PyObject *make_and_readin_scalar(PyObject *in, bool exact)
if (value == T(-1) && PyErr_Occurred()) return 0;
Array<T> *result = Array<T>::make(0, 1);
Array<T> *result = Array<T>::make(ndim, 1);
*result->data() = value;
size_t *shape;
result->ndim_shape(0, &shape);
for (int d = 0; d < ndim; ++d) shape[d] = 1;
return (PyObject*)result;
}
PyObject *(*make_and_readin_scalar_dtable[])(PyObject*, bool) =
PyObject *(*make_and_readin_scalar_dtable[])(PyObject*, bool, int) =
DTYPE_DISPATCH(make_and_readin_scalar);
int examine_buffer(PyObject *in, Py_buffer *view, Dtype *dtype)
......@@ -1341,7 +1345,8 @@ int load_index_seq_as_ulong(PyObject *obj, unsigned long *uout,
// If *dtype == NONE the simplest fitting dtype (at least dtype_min)
// will be used and written back to *dtype. Any other value of *dtype requests
// an array of the given dtype.
PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min,
bool as_matrix)
{
Dtype dtype_in = get_dtype(in), dt = *dtype;
int ndim;
......@@ -1351,133 +1356,44 @@ PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
// `in` is already an array.
if (dt == NONE)
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
if (dt == dtype_in)
Py_INCREF(result = in);
else
result = convert_array(dt, in, dtype_in);
*dtype = dt;
return result;
} else if(PySequence_Check(in)) {
// `in` is not an array, but is a sequence
bool find_type = (dt == NONE);
// Try if buffer interface is supported
Py_buffer view;
if (examine_buffer(in, &view, find_type ? &dt : 0) == 0) {
if (find_type && int(dt) < int(dtype_min)) dt = dtype_min;
for (int i = 0; i < view.ndim; i++)
shape[i] = view.shape[i];
result = make_and_readin_buffer_dtable[int(dt)](&view, view.ndim,
shape);
PyBuffer_Release(&view);
*dtype = dt;
return result;
}
if (examine_arraylike(in, &ndim, shape, seqs,
find_type ? &dt : 0) == 0) {
if (find_type) {
PyObject *seqs_copy[max_ndim];
for (int d = 0; d < ndim; ++d)
Py_INCREF(seqs_copy[d] = seqs[d]);
if (dt == NONE) {
assert(shape[ndim - 1] == 0);
dt = default_dtype;
}
if (int(dt) < int(dtype_min)) dt = dtype_min;
while (true) {
result = make_and_readin_array_dtable[int(dt)](
in, ndim, ndim, shape, seqs, true);
if (result) break;
dt = Dtype(int(dt) + 1);
if (dt == NONE) {
result = 0;
break;
}
PyErr_Clear();
for (int d = 0; d < ndim; ++d)
Py_INCREF(seqs[d] = seqs_copy[d]);
}
for (int d = 0; d < ndim; ++d) Py_DECREF(seqs_copy[d]);
if (as_matrix) {
size_t *in_shape;
reinterpret_cast<Array_base*>(in)->ndim_shape(&ndim, &in_shape);
if (ndim == 2) {
if (dt == dtype_in)
Py_INCREF(result = in);
else
result = convert_array(dt, in, dtype_in);
} else if (ndim < 2) {
shape[1] = (ndim == 0) ? 1 : in_shape[0];
shape[0] = 1;
result = convert_array(dt, in, dtype_in, 2, shape);
} else {
// A specific dtype has been requested.
result = make_and_readin_array_dtable[int(dt)](
in, ndim, ndim, shape, seqs, false);
PyErr_SetString(PyExc_ValueError,
"Matrix must be 2-dimensional.");
result = 0;
}
*dtype = dt;
return result;
}
} else {
// `in` is a scalar
dtype_in = dtype_of_scalar(in);
bool find_type = (dt == NONE);
if (dtype_in == NONE) {
PyErr_SetString(PyExc_TypeError, "Expecting a number.");
return 0;
}
if (find_type) {
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
result = make_and_readin_scalar_dtable[int(dt)](in, true);
} else {
result = make_and_readin_scalar_dtable[int(dt)](in, false);
}
*dtype = dt;
return result;
}
return 0;
}
// If *dtype == NONE the simplest fitting dtype (at least dtype_min)
// will be used and written back to *dtype. Any other value of *dtype requests
// an array of the given dtype.
PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
{
Dtype dtype_in = get_dtype(in), dt = *dtype;
int ndim;
size_t shape[max_ndim];
PyObject *seqs[max_ndim], *result;
if (dtype_in != NONE) {
// `in` is already an array.
if (dt == NONE)
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
size_t *in_shape;
reinterpret_cast<Array_base*>(in)->ndim_shape(&ndim, &in_shape);
if (ndim == 2) {
if (dt == dtype_in)
Py_INCREF(result = in);
else
result = convert_array(dt, in, dtype_in);
} else if (ndim < 2) {
shape[0] = 1;
shape[1] = (ndim == 0) ? 1 : in_shape[0];
result = convert_array(dt, in, dtype_in, 2, shape);
} else {
PyErr_SetString(PyExc_ValueError, "Matrix must be 2-dimensional.");
result = 0;
}
*dtype = dt;
return result;
} else {
// `in` is not an array.
} else if(PySequence_Check(in)) {
// `in` is not an array, but is a sequence
const bool find_type = (*dtype == NONE);
bool find_type = (dt == NONE);
// Try if buffer interface is supported
Py_buffer view;
if (examine_buffer(in, &view, find_type ? &dt : 0) == 0) {
if (find_type && int(dt) < int(dtype_min)) dt = dtype_min;
for (int i = 0; i < view.ndim; i++)
shape[i] = view.shape[i];
if (view.ndim != 2) {
if (as_matrix && view.ndim != 2) {
if (view.ndim > 2) {
PyErr_SetString(PyExc_ValueError,
"Matrix must be 2-dimensional.");
......@@ -1486,9 +1402,8 @@ PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
shape[1] = (view.ndim == 0) ? 1 : shape[0];
shape[0] = 1;
}
if (find_type && int(dt) < int(dtype_min)) dt = dtype_min;
result = make_and_readin_buffer_dtable[int(dt)](&view, view.ndim,
shape);
result = make_and_readin_buffer_dtable[int(dt)](
&view, (as_matrix ? 2 : view.ndim), shape);
PyBuffer_Release(&view);
*dtype = dt;
......@@ -1497,7 +1412,7 @@ PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
if (examine_arraylike(in, &ndim, shape, seqs,
find_type ? &dt : 0) == 0) {
if (ndim != 2) {
if (as_matrix && ndim != 2) {
if (ndim > 2) {
PyErr_SetString(PyExc_ValueError,
"Matrix must be 2-dimensional.");
......@@ -1507,19 +1422,17 @@ PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
shape[0] = 1;
}
if (find_type) {
// No specific dtype has been requested. It will be
// determined by the input.
PyObject *seqs_copy[max_ndim];
for (int d = 0; d < ndim; ++d)
Py_INCREF(seqs_copy[d] = seqs[d]);
if (dt == NONE) {
assert(shape[1] == 0);
assert(shape[(as_matrix ? 2 : ndim) - 1] == 0);
dt = default_dtype;
}
if (int(dt) < int(dtype_min)) dt = dtype_min;
while (true) {
result = make_and_readin_array_dtable[int(dt)](
in, ndim, 2, shape, seqs, true);
in, ndim, (as_matrix ? 2 : ndim), shape, seqs, true);
if (result) break;
dt = Dtype(int(dt) + 1);
if (dt == NONE) {
......@@ -1534,12 +1447,34 @@ PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
} else {
// A specific dtype has been requested.
result = make_and_readin_array_dtable[int(dt)](
in, ndim, 2, shape, seqs, false);
in, ndim, (as_matrix ? 2 : ndim), shape, seqs, false);
}
*dtype = dt;
return result;
}
} else {
// `in` is a scalar
dtype_in = dtype_of_scalar(in);
bool find_type = (dt == NONE);
if (dtype_in == NONE) {
PyErr_SetString(PyExc_TypeError, "Expecting a number.");
return 0;
}
if (find_type) {
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
result = make_and_readin_scalar_dtable[int(dt)](
in, true, (as_matrix ? 2 : 0));
} else {
result = make_and_readin_scalar_dtable[int(dt)](
in, false, (as_matrix ? 2 : 0));
}
*dtype = dt;
return result;
}
return 0;
......
......@@ -149,9 +149,8 @@ inline Dtype get_dtype(PyObject *obj)
}
PyObject *array_from_arraylike(PyObject *in, Dtype *dtype,
Dtype dtype_min = Dtype(0));
PyObject *matrix_from_arraylike(PyObject *in, Dtype *dtype,
Dtype dtype_min = Dtype(0));
Dtype dtype_min = Dtype(0),
bool as_matrix = false);
// Coerced_dtype will contain the common dtype of the coerced arrays.
int coerce_to_arrays(PyObject **a, PyObject **b, Dtype *coerced_dtype);
......
......@@ -199,7 +199,7 @@ PyObject *matrix(PyObject *, PyObject *args)
Dtype dtype = NONE;
if (!PyArg_ParseTuple(args, "O|O&", &src, dtype_converter, &dtype))
return 0;
return matrix_from_arraylike(src, &dtype);
return array_from_arraylike(src, &dtype, Dtype(0), true);
}
PyDoc_STRVAR(matrix_doc,
......
......@@ -114,8 +114,13 @@ def test_matrix():
a = ta.matrix(l)
b = np.matrix(l)
assert_equal(a, b)
assert_equal(a.shape, b.shape)
a = ta.matrix(ta.array(l))
assert_equal(a, b)
assert_equal(a.shape, b.shape)
a = ta.matrix(np.array(l))
assert_equal(a, b)
assert_equal(a.shape, b.shape)
if sys.version_info[:2] > (2, 6):
# Creation of tinyarrays from NumPy matrices only works for Python >
......@@ -367,6 +372,7 @@ def test_other_scalar_types():
for t in types:
a = t(123.456)
assert_equal(ta.array(a), np.array(a))
assert_equal(ta.matrix(a), np.matrix(a))
def test_pickle():
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment