HW 1/2: Gravity simulation with tree method in O(N log N), 2D/3D


Background

Newton's law of gravity (in units where G = 1):

a = \frac{M}{r^2}

The traditional approach to simulating an N-body system evaluates this formula once for every pair of particles i and j. This requires O(N^2) work per time step.
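For reference, the O(N^2) direct sum can be written as a vectorized NumPy sketch. This is not the script's code. It assumes G = 1 and a Plummer-style softening `eps`, which may differ from what `tl.normalizePow(distance, -2, 1e-3)` does in the Taichi version.

```python
import numpy as np


def direct_accelerations(pos, mass, eps=1e-3):
    """O(N^2) direct sum of a_i = sum_j m_j (r_j - r_i) / (|r_j - r_i|^2 + eps^2)^(3/2).

    G = 1; pos has shape (N, D), mass has shape (N,). Returns an (N, D) array.
    The Plummer softening eps is an assumption of this sketch.
    """
    diff = pos[None, :, :] - pos[:, None, :]           # diff[i, j] = r_j - r_i
    dist2 = np.sum(diff * diff, axis=-1) + eps * eps   # (N, N) softened squared distances
    with np.errstate(divide="ignore"):
        inv_d3 = dist2 ** -1.5                         # inf on the diagonal when eps == 0
    np.fill_diagonal(inv_d3, 0.0)                      # drop the self-interaction terms
    return np.einsum("ij,ijd->id", mass[None, :] * inv_d3, diff)
```

The (N, N, D) intermediate grows quadratically in memory, so this vectorized form is only practical for modest N. It is meant to show the O(N^2) pair structure, not to compete with the script.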

But if two objects are very far from me, I may treat them as a single object.
In formula form:

a = \frac{M_1}{{r_1}^2} + \frac{M_2}{{r_2}^2} \approx \frac{M_1 + M_2}{{r_{CoM}}^2}

where r_{CoM} = \frac{M_1r_1 + M_2r_2}{M_1 + M_2} is the pair's center of mass (CoM), with all positions measured relative to me.
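Here is a quick numerical check of the approximation above as a plain NumPy sketch. It is not part of the original script: it assumes G = 1 and no softening, and the helper names are hypothetical.

```python
import numpy as np


def exact_accel(points, masses, at):
    """Sum of M_i (r_i - at) / |r_i - at|^3 over all point masses (G = 1)."""
    d = points - at
    r = np.linalg.norm(d, axis=1)
    return (masses[:, None] * d / r[:, None] ** 3).sum(axis=0)


def com_accel(points, masses, at):
    """Approximation: the whole group replaced by its total mass placed at its CoM."""
    total = masses.sum()
    com = (masses[:, None] * points).sum(axis=0) / total
    d = com - at
    return total * d / np.linalg.norm(d) ** 3


# A binary pair of size 0.2 seen from a distance of about 10.
pair = np.array([[10.0, -0.1], [10.0, 0.1]])
m = np.array([1.0, 3.0])
me = np.zeros(2)
exact, approx = exact_accel(pair, m, me), com_accel(pair, m, me)
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
```

Here `rel_err` is of order (size / distance)^2, roughly 1e-4. Moving `me` closer to the pair makes it grow quickly.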

But how far is far?
We define the pair as far when |r_{CoM}| > \lambda |r_1 - r_2|,
i.e. the distance to the CoM of a binary star pair is much larger than the pair's size.

The \lambda here is called the shape factor. The larger it is, the more accurate the result and the more computation is required. In practice, \lambda = 1 usually gives results that are good enough.

To find near and far groups of objects quickly, we store the particle positions in a quadtree (2D) or octree (3D), where each node also records the total mass and CoM of the particles inside it. The criterion above then becomes:

We define distance_to_node_CoM > lambda * node_size as far.
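The criterion can be written as a tiny helper. This is a hypothetical function: the script inlines the same test in get_tree_gravity_at and compares squared lengths to avoid a square root. The Barnes-Hut literature usually writes this test as node_size / distance < θ, so λ plays the role of 1/θ.

```python
def is_far(node_size, node_com, point, shape_factor=1.0):
    """True when distance(point, node CoM) > shape_factor * node_size.

    Squared lengths are compared, so no square root is needed.
    Works for 2D or 3D tuples alike.
    """
    dist2 = sum((c - p) ** 2 for c, p in zip(node_com, point))
    return shape_factor ** 2 * node_size ** 2 < dist2
```

Raising `shape_factor` makes fewer nodes count as far. That gives more accuracy, but more nodes have to be opened.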

For reasonably uniform particle distributions, the maximum depth of the tree is O(log N) thanks to the tree's sparsity (strongly clustered particles can make it deeper).
So inserting a particle into the tree, which depends on the tree depth, takes O(log N), and building the whole tree takes O(N log N).
Computing the gravity at a specific point, which also depends on the tree depth, takes roughly O(log N) for a fixed \lambda, so computing the gravity at all particles' positions takes O(N log N).
To sum up, our algorithm is O(N log N).
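For a sense of scale at the benchmark size used below (N = 8192 = 2^{13}), here is the leading-term arithmetic only. It uses log base 2 and ignores all constant factors:

```latex
N^2 = 8192^2 = 67\,108\,864, \qquad
N \log_2 N = 8192 \times 13 = 106\,496, \qquad
\frac{N^2}{N \log_2 N} = \frac{8192}{13} \approx 630
```

The measured speedup in the Benchmarking section is only about 2.7x (5.7 fps vs 2.1 fps), so constant factors clearly matter at this N, and rendering is included in the fps figures. One likely contributor is that substep_tree walks the particles in a serial while loop, whereas substep_raw uses a parallel range-for.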

Further improvements to the O could come from adopting the Fast Multipole Method, which is O(N), but I didn't manage to write that code in 1 Hu-day and 10 Peng-days.

Algorithm description

First, we place particles into tree nodes, 1 particle per leaf node.

When computing gravity at a specific point, start from the root node:
If the node is far (with \lambda = 1: the distance to its CoM is larger than the node size), consider that node as a single huge particle, and compute gravity from its CoM & mass.
Otherwise, split that node into its 4 (2D) or 8 (3D) child nodes, and repeat the process on each.
Once a node is a leaf node, stop and directly compute the gravity of the single particle inside.
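The procedure above is naturally recursive. Outside Taichi, where recursion is allowed, it can be sketched in plain Python. This is an illustrative 2D sketch, not the author's code. The class and function names, the MAX_DEPTH cap for coincident particles, and the Plummer softening `eps` are my assumptions. The far test is the same as the script's: far when distance > λ × node size.

```python
import random

MAX_DEPTH = 32  # hypothetical cap so coincident particles cannot split forever


class Node:
    """Quadtree cell centred at (cx, cy) with side length `size`."""

    def __init__(self, cx, cy, size, depth=0):
        self.cx, self.cy, self.size, self.depth = cx, cy, size, depth
        self.mass = 0.0
        self.wx = self.wy = 0.0  # mass-weighted position sums; CoM = (wx, wy) / mass
        self.particles = []      # (x, y, m) tuples stored while this node is a leaf
        self.children = None     # four sub-cells once the node has been split

    def child_for(self, p):
        # bit 0: right half, bit 1: upper half
        return self.children[(p[0] > self.cx) + 2 * (p[1] > self.cy)]

    def insert(self, p):
        x, y, m = p
        self.mass += m
        self.wx += m * x
        self.wy += m * y
        if self.children is not None:
            self.child_for(p).insert(p)
            return
        self.particles.append(p)
        if len(self.particles) == 1 or self.depth >= MAX_DEPTH:
            return
        # A second particle arrived: split into four children and push both down.
        q = self.size / 4
        self.children = [Node(self.cx + (2 * (i & 1) - 1) * q,
                              self.cy + (2 * (i >> 1) - 1) * q,
                              self.size / 2, self.depth + 1) for i in range(4)]
        moved, self.particles = self.particles, []
        for old in moved:
            self.child_for(old).insert(old)


def pull(dx, dy, m, eps):
    """Softened G = 1 attraction toward an offset (dx, dy) carrying mass m."""
    inv = (dx * dx + dy * dy + eps * eps) ** -1.5
    return m * dx * inv, m * dy * inv


def acceleration(node, x, y, shape_factor=1.0, eps=1e-3):
    """Recursive tree walk returning the acceleration (ax, ay) at (x, y)."""
    if node.mass == 0.0:
        return 0.0, 0.0
    if node.children is None:  # leaf: sum its particle(s) directly
        ax = ay = 0.0
        for px, py, m in node.particles:
            dax, day = pull(px - x, py - y, m, eps)
            ax, ay = ax + dax, ay + day
        return ax, ay
    dx, dy = node.wx / node.mass - x, node.wy / node.mass - y
    if shape_factor ** 2 * node.size ** 2 < dx * dx + dy * dy:
        return pull(dx, dy, node.mass, eps)  # far: one big particle at the CoM
    ax = ay = 0.0
    for child in node.children:  # near: open the node and recurse
        dax, day = acceleration(child, x, y, shape_factor, eps)
        ax, ay = ax + dax, ay + day
    return ax, ay


random.seed(0)
pts = [(random.random(), random.random(), random.uniform(0.5, 1.5)) for _ in range(200)]
root = Node(0.5, 0.5, 1.0)
for p in pts:
    root.insert(p)
print(acceleration(root, 0.3, 0.7))
```

Unlike the Taichi script, which compares a child's CoM distance against its parent's size, this sketch tests each node against its own size. That is slightly less conservative.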

Implementation details

Because Taichi doesn't support recursion in functions, I had to hand-write an ad-hoc queue for visiting & building the tree.
That is, I allocate a fixed-capacity array trash plus a separate length counter, to serve as the queue for BFS.

When inserting a particle into a node that already holds one, the existing particle is first kicked out into trash, to be re-inserted later.
When round 1 completes, pop one element from trash and repeat the above work,
until there are no more elements in trash.
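The insertion scheme above can be re-stated in plain Python to make the queue explicit. This is a hypothetical 2D sketch, not the script itself. The class name, the Python list used as the trash queue, and the `max_depth` guard are my choices. The LEAF/TREE sentinels, the kick-out step, and the mass bookkeeping follow alloc_a_node_for_particle and build_tree.

```python
LEAF, TREE = -1, -2  # same sentinels as the Taichi script


class FlatQuadtree:
    """2D flat-array quadtree over [0, 1]^2, built without recursion."""

    def __init__(self, pos, mass, max_depth=64):
        self.pid = []       # per node: LEAF (empty), TREE (internal) or a particle index
        self.children = []  # per node: four child node indices, LEAF where absent
        self.mass = []      # per node: total mass
        self.wpos = []      # per node: [sum of m*x, sum of m*y]
        self.max_depth = max_depth
        root = self.alloc_node()
        for p in range(len(pos)):
            trash = [(p, root, (0.5, 0.5), 1.0)]  # (particle, start node, cell center, cell size)
            i = 0
            while i < len(trash):  # the list only grows, so walking it by index is a FIFO queue
                q, node, center, size = trash[i]
                self.insert(q, pos, mass, node, center, size, trash)
                i += 1

    def alloc_node(self):
        self.pid.append(LEAF)
        self.children.append([LEAF] * 4)
        self.mass.append(0.0)
        self.wpos.append([0.0, 0.0])
        return len(self.pid) - 1

    def insert(self, p, pos, mass, node, center, size, trash):
        (x, y), m = pos[p], mass[p]
        for _ in range(self.max_depth):
            occupant = self.pid[node]
            if occupant == LEAF:
                break  # empty node: store p here
            if occupant != TREE:
                # Kick the occupant out into trash; it is re-inserted from this node later.
                self.pid[node] = TREE
                trash.append((occupant, node, center, size))
                (ox, oy), om = pos[occupant], mass[occupant]
                self.wpos[node][0] -= om * ox
                self.wpos[node][1] -= om * oy
                self.mass[node] -= om
            self.wpos[node][0] += m * x
            self.wpos[node][1] += m * y
            self.mass[node] += m
            which = (x > center[0]) + 2 * (y > center[1])  # bit 0: right, bit 1: upper
            child = self.children[node][which]
            if child == LEAF:
                child = self.alloc_node()
                self.children[node][which] = child
            q = size / 4
            center = (center[0] + (2 * (which & 1) - 1) * q,
                      center[1] + (2 * (which >> 1) - 1) * q)
            size /= 2
            node = child
        else:
            raise ValueError("max_depth exceeded (coincident particles?)")
        self.pid[node] = p
        self.wpos[node] = [m * x, m * y]
        self.mass[node] = m
```

Each kicked-out particle restarts from the node it was evicted from. Its ancestors already include its mass, so nothing is counted twice.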

When visiting a node, visit all of its child nodes:
If a child is far: compute the gravity from the child's CoM and mass.
If a child is not far: push it into trash to be visited later.
If a popped node is a leaf node: compute the gravity of its particle directly.
When round 1 completes, pop one element from trash and repeat the above work,
until there are no more elements in trash.
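The visiting loop can also be re-stated without recursion. The sketch below is a hypothetical pure-Python version of get_tree_gravity_at over flat node arrays, using `collections.deque` as the trash queue. The argument names, the Plummer softening `eps`, and skipping zero-mass children (assuming non-negative masses) are my assumptions. As in the Taichi code, each queued node carries its own cell size. When that node is expanded, each of its children is compared against that parent size.

```python
from collections import deque

LEAF, TREE = -1, -2  # same sentinels as the Taichi script


def tree_gravity_at(point, node_pid, node_children, node_mass, node_wpos,
                    part_pos, part_mass, root_size=1.0, shape_factor=1.0, eps=1e-3):
    """Non-recursive tree walk with G = 1 and Plummer softening eps (assumed)."""

    def pull(d, m):
        inv = (sum(c * c for c in d) + eps * eps) ** -1.5
        return [m * c * inv for c in d]

    acc = [0.0] * len(point)
    trash = deque([(0, root_size)])  # (node index, that node's cell size); node 0 is the root
    while trash:
        node, size = trash.popleft()
        pid = node_pid[node]
        if pid >= 0:  # leaf holding one particle: exact contribution
            d = [p - q for p, q in zip(part_pos[pid], point)]
            acc = [a + b for a, b in zip(acc, pull(d, part_mass[pid]))]
            continue
        for child in node_children[node]:
            if child == LEAF or node_mass[child] == 0.0:
                continue  # absent or massless child contributes nothing
            com = [w / node_mass[child] for w in node_wpos[child]]
            d = [c - q for c, q in zip(com, point)]
            if shape_factor ** 2 * size ** 2 < sum(c * c for c in d):
                acc = [a + b for a, b in zip(acc, pull(d, node_mass[child]))]  # far
            else:
                trash.append((child, size / 2))  # not far: visit later
    return acc
```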

You may think I used ti.root.dynamic for trash? No: the performance of dynamic SNodes is so poor that I'd rather use an extra scalar, trash_table_len[None], and walk the queue with a while loop instead.

In fact, I used no sparse structures at all in this script, although they would be very useful if @yuanming would like to improve them now that we have moved Taichi into Python…

@yuanming Hi, my brain is fried by trash now; maybe we should consider:

  1. Improve ti.root.dynamic() performance.
  2. Add ti.root.tree() as a new kind of sparse structure. (IMO BFS/DFS could be straightforward with the listgen system)

But I'd also like to thank Taichi's cool metaprogramming features (such as ti.static), which enabled me to combine the 2D and 3D versions into one piece of code without pain.
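Outside Taichi, the same dimension-generic indexing can be sketched in plain Python. This is not Taichi code. `itertools.product` stands in for `ti.grouped(ti.ndrange(*([2] * kDim)))`, and the helper names are hypothetical. The child-center formula is the one used in alloc_a_node_for_particle.

```python
import itertools


def child_offsets(dim):
    """All 2**dim child indices, e.g. (0, 0), (0, 1), (1, 0), (1, 1) in 2D."""
    return list(itertools.product((0, 1), repeat=dim))


def which_child(position, center):
    """Per-axis bit: 1 if the position lies above the center on that axis."""
    return tuple(int(p > c) for p, c in zip(position, center))


def child_center(center, size, which):
    """Center of the child cell: center + (which - 0.5) * (size / 2)."""
    half = size / 2
    return tuple(c + (w - 0.5) * half for c, w in zip(center, which))
```

The same three functions work unchanged for 2D and 3D because the dimension only enters through the length of the tuples. That is the role `kDim` and `ti.static` play in the script.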

Complete code

import taichi as ti
import taichi_glsl as tl
ti.init()
if not hasattr(ti, 'jkl'):
    ti.jkl = ti.indices(1, 2, 3)

kUseTree = True
#kDisplay = 'pixels cmap tree mouse save_result'
kDisplay = 'pixels cmap'
kResolution = 800
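kShapeFactor = 1  # lambda: a child is far when its distance > lambda * parent node size
kMaxParticles = 8192
kMaxDepth = kMaxParticles * 1  # capacity of the trash queue; also caps insertion depth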
kMaxNodes = kMaxParticles * 4
kDim = 2

dt = 0.00005
LEAF = -1  # empty node, or absent child
TREE = -2  # internal node (its particle has been pushed down)

particle_mass = ti.var(ti.f32)
particle_pos = ti.Vector(kDim, ti.f32)
particle_vel = ti.Vector(kDim, ti.f32)
particle_table = ti.root.dense(ti.i, kMaxParticles)
particle_table.place(particle_pos).place(particle_vel).place(particle_mass)
particle_table_len = ti.var(ti.i32, ())

if kUseTree:
    trash_particle_id = ti.var(ti.i32)
    trash_base_parent = ti.var(ti.i32)
    trash_base_geo_center = ti.Vector(kDim, ti.f32)
    trash_base_geo_size = ti.var(ti.f32)
    trash_table = ti.root.dense(ti.i, kMaxDepth)
    trash_table.place(trash_particle_id)
    trash_table.place(trash_base_parent, trash_base_geo_size)
    trash_table.place(trash_base_geo_center)
    trash_table_len = ti.var(ti.i32, ())

    node_mass = ti.var(ti.f32)
    node_weighted_pos = ti.Vector(kDim, ti.f32)
    node_particle_id = ti.var(ti.i32)
    node_children = ti.var(ti.i32)
    node_table = ti.root.dense(ti.i, kMaxNodes)
    node_table.place(node_mass, node_particle_id, node_weighted_pos)
    node_table.dense({2: ti.jk, 3: ti.jkl}[kDim], 2).place(node_children)
    node_table_len = ti.var(ti.i32, ())

if 'mouse' in kDisplay:
    display_image = ti.Vector(3, ti.f32, (kResolution, kResolution))
elif len(kDisplay):
    display_image = ti.var(ti.f32, (kResolution, kResolution))


@ti.func
def alloc_node():
    ret = ti.atomic_add(node_table_len[None], 1)
    assert ret < kMaxNodes
    node_mass[ret] = 0
    node_weighted_pos[ret] = particle_pos[0] * 0
    node_particle_id[ret] = LEAF
    for which in ti.grouped(ti.ndrange(*([2] * kDim))):
        node_children[ret, which] = LEAF
    return ret


@ti.func
def alloc_particle():
    ret = ti.atomic_add(particle_table_len[None], 1)
    assert ret < kMaxParticles
    particle_mass[ret] = 0
    particle_pos[ret] = particle_pos[0] * 0
    particle_vel[ret] = particle_pos[0] * 0
    return ret


@ti.func
def alloc_trash():
    ret = ti.atomic_add(trash_table_len[None], 1)
    assert ret < kMaxDepth
    return ret


@ti.func
def alloc_a_node_for_particle(particle_id, parent, parent_geo_center,
                              parent_geo_size):
    position = particle_pos[particle_id]
    mass = particle_mass[particle_id]

    depth = 0
    while depth < kMaxDepth:
        already_particle_id = node_particle_id[parent]
        if already_particle_id == LEAF:
            break
        if already_particle_id != TREE:
            node_particle_id[parent] = TREE
            trash_id = alloc_trash()
            trash_particle_id[trash_id] = already_particle_id
            trash_base_parent[trash_id] = parent
            trash_base_geo_center[trash_id] = parent_geo_center
            trash_base_geo_size[trash_id] = parent_geo_size
            already_pos = particle_pos[already_particle_id]
            already_mass = particle_mass[already_particle_id]
            node_weighted_pos[parent] -= already_pos * already_mass
            node_mass[parent] -= already_mass

        node_weighted_pos[parent] += position * mass
        node_mass[parent] += mass

        which = abs(position > parent_geo_center)
        child = node_children[parent, which]
        if child == LEAF:
            child = alloc_node()
            node_children[parent, which] = child
        child_geo_size = parent_geo_size * 0.5
        child_geo_center = parent_geo_center + (which - 0.5) * child_geo_size

        parent_geo_center = child_geo_center
        parent_geo_size = child_geo_size
        parent = child

        depth = depth + 1

    node_particle_id[parent] = particle_id
    node_weighted_pos[parent] = position * mass
    node_mass[parent] = mass


@ti.kernel
def add_particle_at(mx: ti.f32, my: ti.f32, mass: ti.f32):
    mouse_pos = tl.vec(mx, my) + tl.randND(2) * (0.05 / kResolution)

    particle_id = alloc_particle()
    if ti.static(kDim == 2):
        particle_pos[particle_id] = mouse_pos
    else:
        particle_pos[particle_id] = tl.vec(mouse_pos, 0.0)
    particle_mass[particle_id] = mass


@ti.kernel
def add_random_particles(angular_velocity: ti.f32):
    particle_id = alloc_particle()
    if ti.static(kDim == 2):
        particle_pos[particle_id] = tl.randSolid2D() * 0.2 + 0.5
        velocity = (particle_pos[particle_id] - 0.5) * angular_velocity * 250
        particle_vel[particle_id] = tl.vec(-velocity.y, velocity.x)
    else:
        particle_pos[particle_id] = tl.randUnit3D() * 0.2 + 0.5
        velocity = (particle_pos[particle_id].xy - 0.5) * angular_velocity * 180
        particle_vel[particle_id] = tl.vec(-velocity.y, velocity.x, 0.0)
    particle_mass[particle_id] = tl.randRange(0.0, 1.5)


@ti.kernel
def build_tree():
    node_table_len[None] = 0
    trash_table_len[None] = 0
    alloc_node()

    particle_id = 0
    while particle_id < particle_table_len[None]:
        alloc_a_node_for_particle(particle_id, 0, particle_pos[0] * 0 + 0.5,
                                  1.0)

        trash_id = 0
        while trash_id < trash_table_len[None]:
            alloc_a_node_for_particle(trash_particle_id[trash_id],
                                      trash_base_parent[trash_id],
                                      trash_base_geo_center[trash_id],
                                      trash_base_geo_size[trash_id])
            trash_id = trash_id + 1

        trash_table_len[None] = 0
        particle_id = particle_id + 1


@ti.func
def gravity_func(distance):
    return tl.normalizePow(distance, -2, 1e-3)


@ti.func
def get_tree_gravity_at(position):
    acc = particle_pos[0] * 0

    trash_table_len[None] = 0
    trash_id = alloc_trash()
    assert trash_id == 0
    trash_base_parent[trash_id] = 0
    trash_base_geo_size[trash_id] = 1.0

    trash_id = 0
    while trash_id < trash_table_len[None]:
        parent = trash_base_parent[trash_id]
        parent_geo_size = trash_base_geo_size[trash_id]

        particle_id = node_particle_id[parent]
        if particle_id >= 0:
            distance = particle_pos[particle_id] - position
            acc += particle_mass[particle_id] * gravity_func(distance)

        else:  # TREE or LEAF
            for which in ti.grouped(ti.ndrange(*([2] * kDim))):
                child = node_children[parent, which]
                if child == LEAF:
                    continue
                node_center = node_weighted_pos[child] / node_mass[child]
                distance = node_center - position
                if distance.norm_sqr() > kShapeFactor**2 * parent_geo_size**2:
                    acc += node_mass[child] * gravity_func(distance)
                else:
                    new_trash_id = alloc_trash()
                    child_geo_size = parent_geo_size * 0.5
                    trash_base_parent[new_trash_id] = child
                    trash_base_geo_size[new_trash_id] = child_geo_size

        trash_id = trash_id + 1

    return acc


@ti.func
def get_raw_gravity_at(pos):
    acc = particle_pos[0] * 0
    for i in range(particle_table_len[None]):
        acc += particle_mass[i] * gravity_func(particle_pos[i] - pos)
    return acc


@ti.kernel
def substep_raw():
    for i in range(particle_table_len[None]):
        acceleration = get_raw_gravity_at(particle_pos[i])
        particle_vel[i] += acceleration * dt
    for i in range(particle_table_len[None]):
        particle_pos[i] += particle_vel[i] * dt


@ti.kernel
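def substep_tree():
    # A serial while loop: get_tree_gravity_at shares one global trash queue,
    # so these particles cannot be processed in parallel.
    particle_id = 0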
    while particle_id < particle_table_len[None]:
        acceleration = get_tree_gravity_at(particle_pos[particle_id])
        particle_vel[particle_id] += acceleration * dt
        # The tree only covers [0, 1]^kDim and the inserter seems to break for
        # particles outside it, so reflect them off the domain boundary:
        particle_vel[particle_id] = tl.boundReflect(particle_pos[particle_id],
                                                    particle_vel[particle_id],
                                                    0, 1)
        particle_id = particle_id + 1
    for i in range(particle_table_len[None]):
        particle_pos[i] += particle_vel[i] * dt


@ti.kernel
def render_arrows(mx: ti.f32, my: ti.f32):
    pos = tl.vec(mx, my)
    acc = get_raw_gravity_at(pos) * 0.001
    tl.paintArrow(display_image, pos, acc, tl.D.yyx)
    acc_tree = get_tree_gravity_at(pos) * 0.001
    tl.paintArrow(display_image, pos, acc_tree, tl.D.yxy)


@ti.kernel
def render_pixels():
    for i in range(particle_table_len[None]):
        position = particle_pos[i].xy
        pix = int(position * kResolution)
        display_image[tl.clamp(pix, 0, kResolution - 1)] += 0.25


def render_tree(gui,
                parent=0,
                parent_geo_center=tl.vec(0.5, 0.5),
                parent_geo_size=1.0):
    child_geo_size = parent_geo_size * 0.5
    if node_particle_id[parent] >= 0:
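        # Avoid naming these `tl`, which would shadow the taichi_glsl module.
        top_left = parent_geo_center - child_geo_size
        bottom_right = parent_geo_center + child_geo_size
        gui.rect(top_left, bottom_right, radius=1, color=0xff0000)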
    for which in map(ti.Vector, [[0, 0], [0, 1], [1, 0], [1, 1]]):
        child = node_children[(parent, which[0], which[1])]
        if child < 0:
            continue
        a = parent_geo_center + (which - 1) * child_geo_size
        b = parent_geo_center + which * child_geo_size
        child_geo_center = parent_geo_center + (which - 0.5) * child_geo_size
        gui.rect(a, b, radius=1, color=0xff0000)
        render_tree(gui, child, child_geo_center, child_geo_size)


if 'cmap' in kDisplay:
    import matplotlib.cm as cm
    cmap = cm.get_cmap('magma')

print('[Hint] Press `r` to add 512 random particles')
print('[Hint] Press `t` to add 512 random particles with angular velocity')
print('[Hint] Drag with mouse left button to add a series of particles')
print('[Hint] Drag with mouse middle button to add zero-mass particles')
print('[Hint] Click mouse right button to add a single particle')
gui = ti.GUI('Tree-code', kResolution)
while gui.running:
    for e in gui.get_events(gui.PRESS):
        if e.key == gui.ESCAPE:
            gui.running = False
        elif e.key == gui.RMB:
            add_particle_at(*gui.get_cursor_pos(), 1.0)
        elif e.key in ('r', 't'):
            if particle_table_len[None] + 512 <= kMaxParticles:
                for i in range(512):
                    add_random_particles(e.key == 't')
    if gui.is_pressed(gui.MMB, gui.LMB):
        add_particle_at(*gui.get_cursor_pos(), gui.is_pressed(gui.LMB))

    if kUseTree:
        build_tree()
        substep_tree()
    else:
        substep_raw()
    if len(kDisplay) and 'trace' not in kDisplay:
        display_image.fill(0)
    if 'mouse' in kDisplay:
        render_arrows(*gui.get_cursor_pos())
    if 'pixels' in kDisplay:
        render_pixels()
    if 'cmap' in kDisplay:
        gui.set_image(cmap(display_image.to_numpy()))
    elif len(kDisplay):
        gui.set_image(display_image)
    if 'tree' in kDisplay:
        render_tree(gui)
    if 'pixels' not in kDisplay:
        gui.circles(particle_pos.to_numpy()[:particle_table_len[None]])
    if 'save_result' in kDisplay:
        gui.show(f'{gui.frame:06d}.png')
    else:
        gui.show()

Benchmarking

Testing 8192 particles, 2D:
Raw N-body O(N^2): ~2.1fps
Tree method O(N\log N): ~5.7fps


Beautiful

It reminds me of the three-body problem.