Skip to content
GitLab
Explore
Sign in
Register
Primary navigation
Search or go to…
Project
T
tinyarray
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container Registry
Model registry
Operate
Environments
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
Jörg Behrmann
tinyarray
Commits
accb8c15
Commit
accb8c15
authored
8 years ago
by
Joseph Weston
Browse files
Options
Downloads
Patches
Plain Diff
implement comparison between arrays with the same shape
If the arrays do not have the same shape then we return NotImplemented.
parent
fc33353c
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/array.cc
+100
-19
100 additions, 19 deletions
src/array.cc
test_tinyarray.py
+32
-0
32 additions, 0 deletions
test_tinyarray.py
with
132 additions
and
19 deletions
src/array.cc
+
100
−
19
View file @
accb8c15
...
...
@@ -11,6 +11,7 @@
#include
<cstddef>
#include
<sstream>
#include
<limits>
#include
<assert.h>
#include
"array.hh"
#include
"arithmetic.hh"
#include
"functions.hh"
...
...
@@ -910,48 +911,128 @@ Hash hash(PyObject *obj)
}
template
<
typename
T
>
bool
is_equal_data
(
PyObject
*
a_
,
PyObject
*
b_
,
size_t
size
)
bool
compare_scalar
(
const
int
op
,
const
T
a
,
const
T
b
)
{
switch
(
op
){
case
Py_EQ
:
return
a
==
b
;
case
Py_NE
:
return
a
!=
b
;
case
Py_LE
:
return
a
<=
b
;
case
Py_GE
:
return
a
>=
b
;
case
Py_LT
:
return
a
<
b
;
case
Py_GT
:
return
a
>
b
;
default:
assert
(
false
);
// if we get here something is very wrong
return
false
;
// stop the compiler complaining
}
}
template
<
>
bool
compare_scalar
<
Complex
>
(
const
int
op
,
const
Complex
a
,
const
Complex
b
)
{
switch
(
op
){
case
Py_EQ
:
return
a
==
b
;
case
Py_NE
:
return
a
!=
b
;
// this function is never called in a context where
// the following code path is run -- fall through
case
Py_LE
:
case
Py_GT
:
case
Py_LT
:
case
Py_GE
:
default:
assert
(
false
);
return
false
;
// stop the compiler complaining
}
}
template
<
typename
T
>
bool
compare_data
(
int
op
,
PyObject
*
a_
,
PyObject
*
b_
,
size_t
size
)
{
assert
(
Array
<
T
>::
check_exact
(
a_
));
Array
<
T
>
*
a
=
(
Array
<
T
>*
)
a_
;
assert
(
Array
<
T
>::
check_exact
(
b_
));
Array
<
T
>
*
b
=
(
Array
<
T
>*
)
b_
;
T
*
data_a
=
a
->
data
();
T
*
data_b
=
b
->
data
();
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
if
(
data_a
[
i
]
!=
data_b
[
i
])
return
false
;
return
true
;
const
T
*
data_a
=
a
->
data
();
const
T
*
data_b
=
b
->
data
();
// sequences are ordered the same as their first differing elements, see:
// https://docs.python.org/2/reference/expressions.html#not-in
// comparison for "multidimensional" sequences is identical to comparing
// the flattened sequences when they have the same shape (the present case)
size_t
i
=
0
;
for
(;
i
<
size
;
++
i
)
if
(
data_a
[
i
]
!=
data_b
[
i
])
break
;
// any of these operations should return true when objects are equal
if
(
i
==
size
)
return
((
op
==
Py_EQ
)
||
(
op
==
Py_LE
)
||
(
op
==
Py_GE
));
// encapsulate this into a function to handle the COMPLEX case
return
compare_scalar
<
T
>
(
op
,
data_a
[
i
],
data_b
[
i
]);
}
bool
(
*
is_equal_data_dtable
[])(
PyObject
*
,
PyObject
*
,
size_t
)
=
DTYPE_DISPATCH
(
is_equal_data
);
// don't generate dispatch table for COMPLEX datatype as it will never be used
// in `rich_compare` (COMPLEX is unorderable), and the compiler complains about
// generating `compare_scalar<complex>` because some operations are undefined
bool
(
*
compare_data_dtable
[])(
int
,
PyObject
*
,
PyObject
*
,
size_t
)
=
DTYPE_DISPATCH
(
compare_data
);
PyObject
*
richcompare
(
PyObject
*
a
,
PyObject
*
b
,
int
op
)
{
if
(
op
!=
Py_EQ
&&
op
!=
Py_NE
)
{
Py_INCREF
(
Py_NotImplemented
);
return
Py_NotImplemented
;
}
PyObject
*
result
;
const
bool
equality_comparison
=
(
op
==
Py_EQ
||
op
==
Py_NE
);
// short circuit when we are comparing the same object
bool
equal
=
(
a
==
b
);
if
(
equal
)
goto
done
;
if
(
equal
)
{
// any of these operations should return true when objects are equal
equal
=
(
op
==
Py_EQ
)
||
(
op
==
Py_GE
)
||
(
op
==
Py_LE
);
result
=
equal
?
Py_True
:
Py_False
;
goto
done
;
}
Dtype
dtype
;
if
(
coerce_to_arrays
(
&
a
,
&
b
,
&
dtype
)
<
0
)
return
0
;
// obviate the need for `compare_scalar<Complex` to
// handle the case of an undefined comparison
if
(
dtype
==
COMPLEX
&&
!
equality_comparison
)
{
result
=
Py_NotImplemented
;
goto
decref_then_done
;
}
int
ndim_a
,
ndim_b
;
size_t
*
shape_a
,
*
shape_b
;
reinterpret_cast
<
Array_base
*>
(
a
)
->
ndim_shape
(
&
ndim_a
,
&
shape_a
);
reinterpret_cast
<
Array_base
*>
(
b
)
->
ndim_shape
(
&
ndim_b
,
&
shape_b
);
if
(
ndim_a
!=
ndim_b
)
goto
decref_then_done
;
for
(
int
d
=
0
;
d
<
ndim_a
;
++
d
)
if
(
shape_a
[
d
]
!=
shape_b
[
d
])
goto
decref_then_done
;
equal
=
is_equal_data_dtable
[
int
(
dtype
)](
a
,
b
,
calc_size
(
ndim_a
,
shape_a
));
// TODO: enable array comparisons between arrays of differing
// dimensions
if
(
ndim_a
!=
ndim_b
)
{
if
(
equality_comparison
)
{
goto
equality_then_done
;
}
else
{
result
=
Py_NotImplemented
;
goto
decref_then_done
;
}
}
for
(
int
d
=
0
;
d
<
ndim_a
;
++
d
)
{
if
(
shape_a
[
d
]
!=
shape_b
[
d
])
{
if
(
equality_comparison
)
{
goto
equality_then_done
;
}
else
{
result
=
Py_NotImplemented
;
goto
decref_then_done
;
}
}
}
// actually compare the data
equal
=
compare_data_dtable
[
int
(
dtype
)](
op
,
a
,
b
,
calc_size
(
ndim_a
,
shape_a
));
result
=
equal
?
Py_True
:
Py_False
;
goto
decref_then_done
;
// non error-path exit points from this function
equality_then_done
:
result
=
((
op
==
Py_EQ
)
==
equal
)
?
Py_True
:
Py_False
;
decref_then_done
:
Py_DECREF
(
a
);
Py_DECREF
(
b
);
done
:
PyObject
*
result
=
((
op
==
Py_EQ
)
==
equal
)
?
Py_True
:
Py_False
;
Py_INCREF
(
result
);
return
result
;
}
...
...
This diff is collapsed.
Click to expand it.
test_tinyarray.py
+
32
−
0
View file @
accb8c15
...
...
@@ -9,6 +9,7 @@
import
operator
,
warnings
import
platform
import
itertools
as
it
import
tinyarray
as
ta
from
nose.tools
import
assert_raises
import
numpy
as
np
...
...
@@ -412,6 +413,37 @@ def test_sizeof():
assert_equal
(
sizeof
,
sizeof_should_be
)
def
test_comparison
():
ops
=
operator
for
op
in
[
ops
.
ge
,
ops
.
gt
,
ops
.
le
,
ops
.
lt
,
ops
.
eq
,
ops
.
ne
]:
for
dtype
in
(
int
,
float
,
complex
):
for
left
,
right
in
it
.
product
((
np
.
zeros
,
np
.
ones
),
repeat
=
2
):
for
shape
in
[(),
(
1
,),
(
2
,),
(
2
,
2
),
(
2
,
2
,
2
),
(
2
,
3
)]:
a
=
left
(
shape
,
dtype
)
b
=
right
(
shape
,
dtype
)
if
dtype
is
complex
and
op
not
in
[
ops
.
eq
,
ops
.
ne
]:
# unorderable types
assert_raises
(
TypeError
,
op
,
ta
.
array
(
a
),
ta
.
array
(
b
))
else
:
# passing the same object
same
=
ta
.
array
(
a
)
assert_equal
(
op
(
same
,
same
),
op
(
a
.
tolist
(),
a
.
tolist
()))
# passing different objects, but equal
assert_equal
(
op
(
ta
.
array
(
a
),
ta
.
array
(
a
)),
op
(
a
.
tolist
(),
a
.
tolist
()))
# passing different objects, not equal
assert_equal
(
op
(
ta
.
array
(
a
),
ta
.
array
(
b
)),
op
(
a
.
tolist
(),
b
.
tolist
()))
# test different ndims and different shapes
for
shp1
,
shp2
in
[((
2
,),
(
2
,
2
)),
((
2
,
2
),
(
2
,
3
))]:
a
=
left
(
shp1
,
dtype
)
b
=
right
(
shp2
,
dtype
)
if
op
not
in
(
ops
.
eq
,
ops
.
ne
):
# unorderable types
assert_raises
(
TypeError
,
op
,
ta
.
array
(
a
),
ta
.
array
(
b
))
def
test_pickle
():
import
pickle
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment