Commit 25520bf3 authored by Christoph Groth's avatar Christoph Groth
Browse files

unlock tests that have been commented-out until now

parent f698c75f
......@@ -33,6 +33,7 @@ PyObject *array_scalar_product(PyObject *a_, PyObject *b_)
T *data_a = a->data(), *data_b = b->data();
// It's important not to start with result = 0. This leads to wrong
// results with regard to the sign of zero as 0.0 + -0.0 is 0.0.
if (n == 0) return pyobject_from_number(T(0));
assert(n > 0);
T result = data_a[0] * data_b[0];
for (size_t i = 1; i < n; ++i) {
......@@ -88,24 +89,29 @@ PyObject *array_matrix_product(PyObject *a_, PyObject *b_)
return 0;
}
Array<T> *result = Array<T>::make(ndim, shape);
size_t size;
Array<T> *result = Array<T>::make(ndim, shape, &size);
if (!result) return 0;
const T *data_a = a->data(), *data_b = b->data();
T *dest = result->data();
const T *src_a = data_a;
for (size_t i = 0; i < a0; ++i, src_a += n) {
const T *src_b = data_b;
for (size_t j = 0; j < b0; ++j, src_b += (n - 1) * b1) {
for (size_t k = 0; k < b1; ++k, ++src_b) {
// It's important not to start with sum = 0. This leads to
// wrong results with regard to the sign of zero as 0.0 + -0.0
// is 0.0.
assert(n > 0);
T sum = src_a[0] * src_b[0];
for (size_t l = 1; l < n; ++l)
sum += src_a[l] * src_b[l * b1];
*dest++ = sum;
if (n == 0) {
for (size_t i = 0; i < size; ++i) dest[i] = 0;
} else {
assert(n > 0);
const T *data_a = a->data(), *data_b = b->data();
const T *src_a = data_a;
for (size_t i = 0; i < a0; ++i, src_a += n) {
const T *src_b = data_b;
for (size_t j = 0; j < b0; ++j, src_b += (n - 1) * b1) {
for (size_t k = 0; k < b1; ++k, ++src_b) {
// It's important not to start with sum = 0. This leads to
// wrong results with regard to the sign of zero as 0.0 +
// -0.0 is 0.0.
T sum = src_a[0] * src_b[0];
for (size_t l = 1; l < n; ++l)
sum += src_a[l] * src_b[l * b1];
*dest++ = sum;
}
}
}
}
......
......@@ -41,8 +41,8 @@ def test_array():
b = ta.array(l)
b_shape = shape_of_seq(l)
# a_shape and b_shape are not always equal. Example: a_shape ==
# (0, 0), b_shape = (0,).
# a_shape and b_shape are not always equal.
# Example: a_shape == (0, 0), b_shape = (0,).
assert isinstance(repr(b), str)
assert_equal(b.ndim, len(b_shape))
......@@ -151,12 +151,10 @@ def test_dot():
assert_equal(ta.dot([1, 2], (3, 4)), 11)
for dtype in dtypes:
# The commented testcases can be added once there is support for
# creating tinyarrays via the buffer interface.
shape_pairs = [(1, 1), (2, 2), (3, 3),
# (0, 0),
# (0, (0, 1)), ((0, 1), 1),
# (0, (0, 2)), ((0, 2), 2),
(0, 0),
(0, (0, 1)), ((0, 1), 1),
(0, (0, 2)), ((0, 2), 2),
(1, (1, 2)), ((2, 1), 1),
(2, (2, 1)), ((1, 2), 2),
(2, (2, 3)), ((3, 2), 2),
......@@ -187,9 +185,7 @@ def test_dot():
assert_almost_equal(ta.dot(ta.array(a), ta.array(b)), np.dot(a, b),
13)
# The commented out testcases do not work due to a bug in numpy:
# PySequence_Check should return 0 for 0-d arrays.
shape_pairs = [#((), 2), (2, ()),
shape_pairs = [((), 2), (2, ()),
(1, 2),
(1, (2, 2)), ((1, 1), 2),
((2, 2), (3, 2)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment