Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
Joseph Weston
tinyarray
Commits
6ff26513
Commit
6ff26513
authored
May 23, 2016
by
Joseph Weston
Browse files
WIP
parent
09bda66c
Pipeline
#426
failed with stage
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
src/array.cc
View file @
6ff26513
...
...
@@ -7,10 +7,13 @@
// top-level directory of this distribution and at
// https://gitlab.kwant-project.org/kwant/tinyarray.
#include
<iostream>
#include
<Python.h>
#include
<cstddef>
#include
<sstream>
#include
<limits>
#include
<assert.h>
#include
"array.hh"
#include
"arithmetic.hh"
#include
"functions.hh"
...
...
@@ -180,7 +183,7 @@ int examine_sequence(PyObject *arraylike, int *ndim, size_t *shape,
PyObject
*
p
=
arraylike
;
int
d
=
-
1
;
assert
(
PySequence_Check
(
p
));
for
(
bool
is_sequence
=
true
;
;
is_sequence
=
PySequence_Check
(
p
))
{
for
(
bool
is_sequence
=
true
;;
is_sequence
=
PySequence_Check
(
p
))
{
if
(
is_sequence
)
{
++
d
;
if
(
d
==
ptrdiff_t
(
max_ndim
))
{
...
...
@@ -910,48 +913,125 @@ Hash hash(PyObject *obj)
}
template
<
typename
T
>
bool
is_equal_data
(
PyObject
*
a_
,
PyObject
*
b_
,
size_t
size
)
bool
compare_scalar
(
const
int
op
,
const
T
a
,
const
T
b
)
{
switch
(
op
){
case
Py_EQ
:
return
a
==
b
;
case
Py_NE
:
return
a
!=
b
;
case
Py_LE
:
return
a
<=
b
;
case
Py_GE
:
return
a
>=
b
;
case
Py_LT
:
return
a
<
b
;
case
Py_GT
:
return
a
>
b
;
default:
assert
(
false
);
// if we get here something is very wrong
return
false
;
// stop the compiler complaining
}
}
template
<
>
bool
compare_scalar
<
Complex
>
(
const
int
op
,
const
Complex
a
,
const
Complex
b
)
{
switch
(
op
){
case
Py_EQ
:
return
a
==
b
;
case
Py_NE
:
return
a
!=
b
;
// this function is never called in a context where
// the following code path is run -- fall through
case
Py_LE
:
case
Py_GT
:
case
Py_LT
:
case
Py_GE
:
default:
assert
(
false
);
return
false
;
// stop the compiler complaining
}
}
template
<
typename
T
>
bool
compare_data
(
int
op
,
PyObject
*
a_
,
PyObject
*
b_
,
size_t
size
)
{
assert
(
Array
<
T
>::
check_exact
(
a_
));
Array
<
T
>
*
a
=
(
Array
<
T
>*
)
a_
;
assert
(
Array
<
T
>::
check_exact
(
b_
));
Array
<
T
>
*
b
=
(
Array
<
T
>*
)
b_
;
T
*
data_a
=
a
->
data
();
T
*
data_b
=
b
->
data
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
if
(
data_a
[
i
]
!=
data_b
[
i
])
return
false
;
return
true
;
const
T
*
data_a
=
a
->
data
();
const
T
*
data_b
=
b
->
data
();
// sequences are ordered the same as their first differing elements, see:
// https://docs.python.org/2/reference/expressions.html#not-in
// comparison for "multidimensional" sequences is identical to comparing
// the flattened sequences when they have the same shape (the present case)
size_t
i
=
0
;
for
(;
i
<
size
;
++
i
)
if
(
data_a
[
i
]
!=
data_b
[
i
])
break
;
// we got to the end without breaking: arrays are equal
if
(
i
==
size
)
return
(
op
==
Py_EQ
);
// encapsulate this into a function to handle the COMPLEX case
return
compare_scalar
<
T
>
(
op
,
data_a
[
i
],
data_b
[
i
]);
}
bool
(
*
is_equal_data_dtable
[])(
PyObject
*
,
PyObject
*
,
size_t
)
=
DTYPE_DISPATCH
(
is_equal_data
);
// don't generate dispatch table for COMPLEX datatype as it will never be used
// in `rich_compare` (COMPLEX is unorderable), and the compiler complains about
// generating `compare_scalar<complex>` because some operations are undefined
bool
(
*
compare_data_dtable
[])(
int
,
PyObject
*
,
PyObject
*
,
size_t
)
=
DTYPE_DISPATCH
(
compare_data
);
PyObject
*
richcompare
(
PyObject
*
a
,
PyObject
*
b
,
int
op
)
{
if
(
op
!=
Py_EQ
&&
op
!=
Py_NE
)
{
Py_INCREF
(
Py_NotImplemented
);
return
Py_NotImplemented
;
}
PyObject
*
result
;
const
bool
equality_comparison
=
(
op
==
Py_EQ
||
op
==
Py_NE
);
// short circuit when we are comparing the same object
bool
equal
=
(
a
==
b
);
if
(
equal
)
goto
done
;
if
(
equal
)
{
result
=
((
op
==
Py_EQ
)
==
equal
)
?
Py_True
:
Py_False
;
goto
done
;
}
Dtype
dtype
;
if
(
coerce_to_arrays
(
&
a
,
&
b
,
&
dtype
)
<
0
)
return
0
;
// obviate the need for `compare_scalar<Complex` to
// handle the case of an undefined comparison
if
(
dtype
==
COMPLEX
&&
!
equality_comparison
)
{
PyErr_SetString
(
PyExc_TypeError
,
"unorderable type: complex()"
);
return
0
;
}
int
ndim_a
,
ndim_b
;
size_t
*
shape_a
,
*
shape_b
;
reinterpret_cast
<
Array_base
*>
(
a
)
->
ndim_shape
(
&
ndim_a
,
&
shape_a
);
reinterpret_cast
<
Array_base
*>
(
b
)
->
ndim_shape
(
&
ndim_b
,
&
shape_b
);
if
(
ndim_a
!=
ndim_b
)
goto
decref_then_done
;
for
(
int
d
=
0
;
d
<
ndim_a
;
++
d
)
if
(
shape_a
[
d
]
!=
shape_b
[
d
])
goto
decref_then_done
;
equal
=
is_equal_data_dtable
[
int
(
dtype
)](
a
,
b
,
calc_size
(
ndim_a
,
shape_a
));
// TODO: enable array comparisons between arrays of differing
// dimensions
if
(
ndim_a
!=
ndim_b
)
{
if
(
equality_comparison
)
{
goto
equality_then_done
;
}
else
{
result
=
Py_NotImplemented
;
goto
decref_then_done
;
}
}
for
(
int
d
=
0
;
d
<
ndim_a
;
++
d
)
{
if
(
shape_a
[
d
]
!=
shape_b
[
d
])
{
if
(
equality_comparison
)
{
goto
equality_then_done
;
}
else
{
result
=
Py_NotImplemented
;
goto
decref_then_done
;
}
}
}
// actually compare the data
std
::
cout
<<
"COMPARING DATA
\n
"
;
equal
=
compare_data_dtable
[
int
(
dtype
)](
op
,
a
,
b
,
calc_size
(
ndim_a
,
shape_a
));
result
=
equal
?
Py_True
:
Py_False
;
goto
done
;
// exit points from this function
equality_then_done:
result
=
((
op
==
Py_EQ
)
==
equal
)
?
Py_True
:
Py_False
;
decref_then_done:
Py_DECREF
(
a
);
Py_DECREF
(
b
);
done:
PyObject
*
result
=
((
op
==
Py_EQ
)
==
equal
)
?
Py_True
:
Py_False
;
Py_INCREF
(
result
);
return
result
;
}
...
...
test_tinyarray.py
View file @
6ff26513
...
...
@@ -9,6 +9,7 @@
import
operator
,
warnings
import
platform
import
itertools
as
it
import
tinyarray
as
ta
from
nose.tools
import
assert_raises
import
numpy
as
np
...
...
@@ -414,17 +415,31 @@ def test_sizeof():
def
test_comparison
():
ops
=
operator
# test comparison for 0D and 1D arrays
# TODO: once higher-dimensional arrays are implemented, test them here
for
shape
in
[(),
(
1
,),
(
2
,)]:
# should cover all code branches
for
op
in
[
ops
.
ge
,
ops
.
gt
,
ops
.
le
,
ops
.
lt
,
ops
.
eq
,
ops
.
ne
]:
for
dtype
in
(
int
,
float
,
complex
):
a
=
ta
.
ones
(
shape
,
dtype
)
b
=
ta
.
zeros
(
shape
,
dtype
)
if
dtype
is
complex
and
op
not
in
[
ops
.
eq
,
ops
.
ne
]:
assert_raises
(
TypeError
,
op
,
a
,
b
)
else
:
assert_equal
(
op
(
a
,
b
),
op
(
a
.
tolist
(),
b
.
tolist
()))
for
op
in
[
ops
.
ge
,
ops
.
gt
,
ops
.
le
,
ops
.
lt
,
ops
.
eq
,
ops
.
ne
]:
for
dtype
in
(
int
,
float
,
complex
):
for
left
,
right
in
it
.
product
((
np
.
zeros
,
np
.
ones
),
repeat
=
2
):
for
shape
in
[(),
(
1
,),
(
2
,),
(
2
,
2
),
(
2
,
2
,
2
),
(
2
,
3
)]:
a
=
left
(
shape
,
dtype
)
b
=
right
(
shape
,
dtype
)
if
dtype
is
complex
and
op
not
in
[
ops
.
eq
,
ops
.
ne
]:
assert_raises
(
TypeError
,
op
,
a
,
b
)
else
:
# passing the same object
same
=
ta
.
array
(
a
)
print
(
op
,
a
,
b
)
assert_equal
(
op
(
same
,
same
),
op
(
a
.
tolist
(),
a
.
tolist
()))
# passing different objects
assert_equal
(
op
(
ta
.
array
(
a
),
ta
.
array
(
a
)),
op
(
a
.
tolist
(),
a
.
tolist
()))
# test different ndims and different shapes
for
shp1
,
shp2
in
[((
2
,),
(
2
,
2
)),
((
2
,
2
),
(
2
,
3
))]:
a
=
left
(
shp1
,
dtype
)
b
=
right
(
shp2
,
dtype
)
if
dtype
is
complex
and
op
not
in
[
ops
.
eq
,
ops
.
ne
]:
assert_raises
(
TypeError
,
op
,
a
,
b
)
elif
op
not
in
(
ops
.
eq
,
ops
.
ne
):
assert_raises
(
NotImplementedError
,
op
,
a
,
b
)
def
test_pickle
():
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment