Skip to content
Snippets Groups Projects
Commit e4133517 authored by Christoph Groth's avatar Christoph Groth
Browse files

builder: make group ids deterministic

This makes the order of sites in finalized builders reproducible.
parent b30ab3bc
No related branches found
No related tags found
No related merge requests found
......@@ -91,9 +91,8 @@ class Site(object):
raise ValueError('Dimensionality mismatch')
return group(*tuple(a + b for a, b in izip(tag, delta)))
def __hash__(self):
return id(self.group) ^ hash(self.tag)
return self.group.group_id ^ hash(self.tag)
def __eq__(self, other):
return self.group is other.group and self.tag == other.tag
......@@ -110,6 +109,10 @@ class Site(object):
return self.group.pos(self.tag)
# Counter used to give each newly created group an unique id.
next_group_id = 0
class SiteGroup(object):
"""
Abstract base class for site groups.
......@@ -125,10 +128,14 @@ class SiteGroup(object):
__metaclass__ = abc.ABCMeta
def __init__(self):
self.packed_group_id = pgid_of_group(self)
global next_group_id
self.group_id = next_group_id
next_group_id += 1
self.packed_group_id = struct.pack(gid_pack_fmt, self.group_id)
def __repr__(self):
return '<{0} at {1}>'.format(self.__class__.__name__, hex(id(self)))
return '<{0} object: Site group {1}>'.format(
self.__class__.__name__, self.group_id)
@abc.abstractmethod
def pack_tag(self, tag):
......@@ -183,11 +190,7 @@ class SimpleSiteGroup(SiteGroup):
# This is used for packing and unpacking group ids (gids).
gid_pack_fmt = '@P'
gid_pack_size = len(struct.pack(gid_pack_fmt, id(None)))
def pgid_of_group(group):
assert isinstance(group, SiteGroup)
return struct.pack(gid_pack_fmt, id(group))
gid_pack_size = len(struct.pack(gid_pack_fmt, 0))
# The reason why this is a global function and not a method of Builder is that
......
......@@ -54,8 +54,6 @@ def test_graph():
def test_site_groups():
pgid = builder.pgid_of_group
sys = builder.Builder()
assert_equal(sys._group_by_pgid, {})
sg = builder.SimpleSiteGroup()
......@@ -72,9 +70,10 @@ def test_site_groups():
assert_equal(sys[sg(1)], 123)
assert_raises(KeyError, sys.__getitem__, osg(1))
assert_equal(sys._group_by_pgid, {pgid(sg) : sg})
assert_equal(sys._group_by_pgid, {sg.packed_group_id : sg})
sys[osg(1)] = 321
assert_equal(sys._group_by_pgid, {pgid(sg) : sg, pgid(osg) : osg})
assert_equal(sys._group_by_pgid, {sg.packed_group_id : sg,
osg.packed_group_id : osg})
assert_equal(sys[osg(1)], 321)
assert_equal(sg(-5).shifted((-2,), osg), osg(-7))
......
......@@ -19,7 +19,7 @@ def test_make_lattice():
def test_pack_unpack():
for dim in [1, 2, 3, 5, 10, 99]:
group = lattice.make_lattice(np.identity(dim))
group_by_pgid = {builder.pgid_of_group(group) : group}
group_by_pgid = {group.packed_group_id : group}
tag = tuple(xrange(dim))
site = group(*tag)
psite = site.packed()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment