Commit 33c3a965 authored by Joseph Weston's avatar Joseph Weston Committed by Christoph Groth
Browse files

fix bug for when numpy scalars are passed to the array constructor

parent a07cd6bd
...@@ -70,11 +70,6 @@ Dtype dtype_of_scalar(PyObject *obj) ...@@ -70,11 +70,6 @@ Dtype dtype_of_scalar(PyObject *obj)
if (PyLong_Check(obj)) return LONG; if (PyLong_Check(obj)) return LONG;
if (PyObject_HasAttr(obj, index_str)) return LONG; if (PyObject_HasAttr(obj, index_str)) return LONG;
// I'm not sure about this paragraph. Does the existence of a __complex__
// method already signify that the number is to be interpreted as a complex
// number? What about __float__? Perhaps the following code is useless.
// In practice (with built-in and numpy numerical types) it never plays a
// role anyway.
if (PyObject_HasAttr(obj, complex_str)) return COMPLEX; if (PyObject_HasAttr(obj, complex_str)) return COMPLEX;
if (PyObject_HasAttr(obj, float_str)) return DOUBLE; if (PyObject_HasAttr(obj, float_str)) return DOUBLE;
if (PyObject_HasAttr(obj, int_str)) return LONG; if (PyObject_HasAttr(obj, int_str)) return LONG;
...@@ -352,6 +347,26 @@ PyObject *(*make_and_readin_array_dtable[])( ...@@ -352,6 +347,26 @@ PyObject *(*make_and_readin_array_dtable[])(
PyObject*, int, int, const size_t*, PyObject**, bool) = PyObject*, int, int, const size_t*, PyObject**, bool) =
DTYPE_DISPATCH(make_and_readin_array); DTYPE_DISPATCH(make_and_readin_array);
template <typename T>
PyObject *make_and_readin_scalar(PyObject *in, bool exact)
{
T value;
if (exact)
value = number_from_pyobject_exact<T>(in);
else
value = number_from_pyobject<T>(in);
if (value == T(-1) && PyErr_Occurred()) return 0;
Array<T> *result = Array<T>::make(0, 1);
*result->data() = value;
return (PyObject*)result;
}
PyObject *(*make_and_readin_scalar_dtable[])(PyObject*, bool) =
DTYPE_DISPATCH(make_and_readin_scalar);
int examine_buffer(PyObject *in, Py_buffer *view, Dtype *dtype) int examine_buffer(PyObject *in, Py_buffer *view, Dtype *dtype)
{ {
Dtype dt = NONE; Dtype dt = NONE;
...@@ -1343,8 +1358,8 @@ PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min) ...@@ -1343,8 +1358,8 @@ PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
*dtype = dt; *dtype = dt;
return result; return result;
} else { } else if(PySequence_Check(in)) {
// `in` is not an array. // `in` is not an array, but is a sequence
bool find_type = (dt == NONE); bool find_type = (dt == NONE);
...@@ -1396,6 +1411,26 @@ PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min) ...@@ -1396,6 +1411,26 @@ PyObject *array_from_arraylike(PyObject *in, Dtype *dtype, Dtype dtype_min)
*dtype = dt; *dtype = dt;
return result; return result;
} }
} else {
// `in` is a scalar
dtype_in = dtype_of_scalar(in);
bool find_type = (dt == NONE);
if (dtype_in == NONE) {
PyErr_SetString(PyExc_TypeError, "Expecting a number.");
return 0;
}
if (find_type) {
dt = Dtype(std::max(int(dtype_in), int(dtype_min)));
result = make_and_readin_scalar_dtable[int(dt)](in, true);
} else {
result = make_and_readin_scalar_dtable[int(dt)](in, false);
}
*dtype = dt;
return result;
} }
return 0; return 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment