Commit 946518ae authored by Olaf's avatar Olaf

reduce calculations

parent 4b06bec4
......@@ -25,7 +25,7 @@ def create_node(root):
pos_ref = pos - origin # set reference origin at 0
node = [[], [], [], [], (np.sum(pos.T*mass, axis = 1)/np.sum(mass)).tolist(), np.sum(mass), origin.tolist(), 2*r] # initialize node with origin, center of mass, total mass, width
node = [[], [], [], [], (np.sum(pos.T*mass, axis = 1)/np.sum(mass)).tolist(), np.sum(mass), 2*r] # initialize node with origin, center of mass, total mass, width
east = (np.dot(pos_ref, [1,0])>0)
west = (np.dot(pos_ref, [1,0])<=0)
......@@ -92,7 +92,7 @@ def create_root(pos, mass, origin):
contains 4 lists containing all information concerning the 4 new nodes, the center of mass, the total mass and the origin of the root
"""
r = np.max(np.abs(pos - origin))
nbody = np.array(range(mass.shape[0]))
nbody = np.arange(mass.shape[0])
root_data = [pos, mass, origin, nbody, 2*r]
root = create_node(root_data)
......@@ -115,7 +115,7 @@ def build_tree(root):
node: list
contains 4 lists containing all information concerning the 4 new nodes, the center of mass, the total mass and the origin of the "root"
"""
for i in range(len(root) - 4):
for i in range(len(root) - 3):
if len(root[i]) == 5:
root[i] = build_tree(create_node(root[i]))
......@@ -167,7 +167,7 @@ def force_tree(theta, tree, pos):
U: np.array([N,])
Potential energy
"""
force, U = force_node(pos, tree, theta, np.array(range(pos.shape[0])), pos.shape)
force, U = force_node(pos, tree, theta, np.arange(pos.shape[0]), pos.shape)
# return
return force, U
......@@ -211,11 +211,11 @@ def force_node(pos, node, theta, nbody, Mshape):
force[nbody], U[nbody] = force_cal(pos, node[0], node[1])
else:
nbody_for = np.delete(nbody, np.array(range(pos.shape[0]))[nbody == node[-1]])
pos_for = np.delete(pos, np.array(range(pos.shape[0]))[nbody == node[-1]], 0)
nbody_for = np.delete(nbody, np.arange(pos.shape[0])[nbody == node[-1]])
pos_for = np.delete(pos, np.arange(pos.shape[0])[nbody == node[-1]], 0)
force[nbody_for], U[nbody_for] = force_cal(pos_for, node[0], node[1])
elif len(node) == 8:
elif len(node) == (2**Mshape[1] + 3):
d = np.linalg.norm(pos - np.asarray(node[4]), axis = 1)
nbody_calc = nbody[node[-1]/d < theta]
......@@ -225,7 +225,7 @@ def force_node(pos, node, theta, nbody, Mshape):
pos_cont = pos[node[-1]/d >= theta]
nbody_cont = nbody[node[-1]/d >= theta]
for i in range(len(node) - 4):
for i in range(len(node) - 3):
force1, U1 = force_node(pos_cont, node[i], theta, nbody_cont, Mshape)
force += force1
U += U1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment