Homework 0: Spores(孢子)| 发现了 taichi 编译器的性能瓶颈

完整的程序从开始执行到画出第一帧,用了 20 分钟……

感觉可以当 taichi 编译器的编译耗时测试项目了。


注意事项

  • 以下代码画出第一帧耗时应该 < 5s。
    因为我注释了很多代码。
  • 以下代码只保留了一行给孢子加突起的语句,即
    第128行的:d = min(d, placedBarrel(pos, 0., 0.))
  • 20 min 的耗时是把所有注释都取消掉,即有 22 句 placedBarrel。
  • 对自己的 CPU 有信心的可以多取消点注释
  • 默认只渲染 1000 帧:这个 shader 的动画循环周期比较短,转成 GIF 不需要那么多帧。
    如果不输出图片,可以把帧数调大一点。

Spores.py

import taichi as ti
import time

# Alternative: uncomment to run on the CPU backend with Taichi's debug mode.
# ti.init(debug=True, arch=ti.cpu)
ti.init(arch=ti.gpu)

"""
spores - by: mprice
https://www.shadertoy.com/view/4lsXWj
"""
GUI_TITLE = "Spores"
# Canvas size in pixels (landscape); the commented line below is portrait.
w, h = wh = (640, 360)
# w, h = wh = (360, 640)
# Output image: one 3-component f32 vector (RGB) per pixel, indexed [x, y].
# Filled by render() and copied to the GUI every frame.
pixels = ti.Vector(3, dt=ti.f32, shape=wh)
# Canvas size as a vector, standing in for Shadertoy's iResolution uniform.
iResolution = ti.Vector([w, h])

# Ray-marching parameters used by mainImage():
#   MAX_ITER - maximum march steps per pixel
#   MAX_DIST - once the ray has travelled farther than this it is a miss
#   EPSILON  - hit threshold, also the finite-difference step for normals
MAX_ITER = 100
MAX_DIST = 20.0
EPSILON = 0.001
PI = 3.14159265

# Indices into `flag`.
HIT_HOLE = 0
HIT_BARREL = 1
# Hit-type flags: reset/set by distfunc() and barrel(), read by mainImage()
# to choose the surface colour.
# NOTE(review): this is ONE global field shared by every pixel. Taichi runs
# the outermost loop of a kernel (the pixel loop in render) in parallel, so
# concurrently shaded pixels can overwrite each other's flags. GLSL globals
# in the Shadertoy original are private to each fragment; per-pixel flags
# (or returning them from distfunc) would match that. TODO fix.
flag = ti.var(ti.i32, shape=2)
flag[HIT_HOLE] = False
flag[HIT_BARREL] = False

## Shader helper functions (equivalents of GLSL built-ins)
@ti.func
def mix(x, y, a: ti.f32):
    """
    Linearly interpolate from ``x`` to ``y`` by weight ``a`` (GLSL ``mix``).

    Returns ``x`` when ``a == 0`` and ``y`` when ``a == 1``; ``x``/``y`` may
    be scalars or vectors.

    [The Book of Shaders: mix](https://thebookofshaders.com/glossary/?search=mix)
    """
    weight_x = 1 - a
    return x * weight_x + y * a

@ti.func
def clamp(x, minVal, maxVal):
    """
    Constrain ``x`` to the range [minVal, maxVal] (GLSL ``clamp``).

    [The Book of Shaders: clamp](https://thebookofshaders.com/glossary/?search=clamp)
    """
    at_least_min = max(x, minVal)
    return min(at_least_min, maxVal)

@ti.func
def smoothstep(edge0, edge1, x):
    """
    Hermite interpolation between 0 and 1 as ``x`` goes from ``edge0`` to
    ``edge1`` (GLSL ``smoothstep``); 0 below edge0, 1 above edge1.

    [The Book of Shaders: smoothstep](https://thebookofshaders.com/glossary/?search=smoothstep)
    """
    span = edge1 - edge0
    ratio = clamp((x - edge0) / span, 0.0, 1.0)
    return ratio * ratio * (3.0 - 2.0 * ratio)


@ti.func
def rotateX(p: ti.Vector, ang: ti.f32) -> ti.Vector:
    """Rotate the 3-vector ``p`` by ``ang`` radians about the X axis."""
    c = ti.cos(ang)
    s = ti.sin(ang)
    rot = ti.Matrix([
        [1.0, 0.0, 0.0],
        [0.0, c, -s],
        [0.0, s, c],
    ])
    return rot @ p


@ti.func
def rotateY(p: ti.Vector, ang: ti.f32) -> ti.Vector:
    """Rotate the 3-vector ``p`` by ``ang`` radians about the Y axis."""
    c = ti.cos(ang)
    s = ti.sin(ang)
    rot = ti.Matrix([
        [c, 0.0, s],
        [0.0, 1.0, 0.0],
        [-s, 0.0, c],
    ])
    return rot @ p


@ti.func
def rotateZ(p: ti.Vector, ang: ti.f32) -> ti.Vector:
    """Rotate the 3-vector ``p`` by ``ang`` radians about the Z axis."""
    c = ti.cos(ang)
    s = ti.sin(ang)
    rot = ti.Matrix([
        [c, -s, 0.0],
        [s, c, 0.0],
        [0.0, 0.0, 1.0],
    ])
    return rot @ p


@ti.func
def sphere(pos: ti.Vector, r: ti.f32) -> ti.f32:
    """Signed distance from ``pos`` to a radius-``r`` sphere at the origin (negative inside)."""
    dist_to_center = pos.norm()
    return dist_to_center - r


@ti.func
def barrel(pos: ti.Vector) -> ti.f32:
    """
    Signed distance to a single "barrel": a radius-0.5 sphere at the origin
    with a radius-0.25 spherical hole centred at (0, -0.5, 0) carved out.

    Side effect: sets ``flag[HIT_HOLE]`` to True when the hole term is the
    one that determines the distance; it never clears the flag (distfunc()
    resets it before evaluating the scene).
    NOTE(review): ``flag`` is a single global field shared by every pixel of
    the parallel loop in render(), so this write can race between pixels.
    """
    d: ti.f32 = sphere(pos, 0.5)
    # Shift so the small sphere below is centred at y = -0.5.
    pos[1] += 0.5
    # Negated sphere distance = distance to the region outside the small
    # sphere; max() intersects that region with the barrel, carving the hole.
    holed: ti.f32 = -sphere(pos, 0.25)
    d = max(d, holed)
    # Set-only update: once a hole hit is recorded it stays True. The former
    # `else: flag[HIT_HOLE] = flag[HIT_HOLE]` was a no-op self-assignment that
    # only added a redundant global load/store.
    if holed == d:
        flag[HIT_HOLE] = True
    return d


@ti.func
def placedBarrel(
    pos: ti.Vector,
    rx: ti.f32,
    ry: ti.f32
) -> ti.f32:
    """
    Signed distance to a barrel placed on the central sphere.

    The query point is rotated by ``ry`` about Y, then by ``rx`` about X
    (radians), and shifted +2 along Y before calling barrel(). In the rotated
    frame the barrel therefore sits at (0, -2, 0), on the surface of the
    radius-2 sphere that distfunc() builds.
    """
    spun = rotateY(pos, ry)
    tilted = rotateX(spun, rx)
    tilted[1] += 2.0
    return barrel(tilted)


@ti.func
def distfunc(iTime: ti.f32, pos: ti.Vector) -> ti.f32:
    """
    Scene signed-distance estimate at ``pos`` for shader time ``iTime``.

    Space is shifted by ``iTime`` on every axis, folded into repeating
    10 x 10 x 10 cells, then rotated about X by ``iTime``. Each cell holds a
    radius-2 sphere with barrels (placedBarrel) unioned onto it via ``min``.

    Side effects: clears both entries of the global ``flag`` field, lets
    barrel() set ``flag[HIT_HOLE]``, and finally sets ``flag[HIT_BARREL]``
    when the result differs from the bare sphere distance (i.e. a barrel is
    the nearest surface).
    NOTE(review): ``flag`` is shared by all pixels of the parallel loop in
    render(), so these writes can race between pixels.
    """
    pos += ti.Vector([iTime, iTime, iTime])
    c = ti.Vector([10.0, 10.0, 10.0])
    # Domain repetition: map pos into the cell centred on the origin.
    pos = (pos % c) - 0.5*c
    pos = rotateX(pos, iTime)

    flag[HIT_HOLE] = False
    flag[HIT_BARREL] = False

    # Any of you smart people have a domain transformation way to
    # do a rotational tiling effect instead of this? :)
    sphered: ti.f32 = sphere(pos, 2.0)
    d: ti.f32 = sphered
    ## Below: add the protrusions (barrels) onto the sphere, one per line.
    ## The measured time to the first frame (kernel compilation) grows
    ## steeply with the number of enabled lines: ~20 min with all 22.
    d = min(d, placedBarrel(pos, 0., 0.))
    # d = min(d, placedBarrel(pos, 0.8, 0.))
    # d = min(d, placedBarrel(pos, 1.6, 0.))
    # d = min(d, placedBarrel(pos, 2.4, 0.))
    # d = min(d, placedBarrel(pos, 3.2, 0.))
    # d = min(d, placedBarrel(pos, 4.0, 0.))
    # d = min(d, placedBarrel(pos, 4.8, 0.))
    # d = min(d, placedBarrel(pos, 5.6, 0.))

    # d = min(d, placedBarrel(pos, 0.8, PI / 2.0))
    # d = min(d, placedBarrel(pos, 1.6, PI / 2.0))
    # d = min(d, placedBarrel(pos, 2.4, PI / 2.0))
    # d = min(d, placedBarrel(pos, 4.0, PI / 2.0))
    # d = min(d, placedBarrel(pos, 4.8, PI / 2.0))
    # d = min(d, placedBarrel(pos, 5.6, PI / 2.0))
    # d = min(d, placedBarrel(pos, 1.2, PI / 4.0))
    # d = min(d, placedBarrel(pos, 2.0, PI / 4.0))

    # d = min(d, placedBarrel(pos, 1.2, 3.0 * PI / 4.0))
    # d = min(d, placedBarrel(pos, 2.0, 3.0 * PI / 4.0))
    # d = min(d, placedBarrel(pos, 1.2, 5.0 * PI / 4.0))
    # d = min(d, placedBarrel(pos, 2.0, 5.0 * PI / 4.0))
    # d = min(d, placedBarrel(pos, 1.2, 7.0 * PI / 4.0))
    # d = min(d, placedBarrel(pos, 2.0, 7.0 * PI / 4.0))

    # A barrel is the nearest surface iff it changed the sphere's distance.
    flag[HIT_BARREL] = (d != sphered)

    return d


@ti.func
def mainImage(
    iMouse: ti.Vector, 
    iTime: ti.f32,
    i: ti.i32, 
    j: ti.i32
) -> ti.Vector:
    """
    Shade pixel (i, j) at shader time ``iTime`` (port of Shadertoy's mainImage).

    ``iMouse`` is treated as a pixel position (it is divided by
    ``iResolution``) and steers a camera that orbits the origin at radius 5.
    The ray through the pixel is marched with ``distfunc`` for at most
    MAX_ITER steps. Hits are shaded from the surface normal and the hit-type
    flags; misses get a scrolling striped background.

    Returns the RGB colour of the pixel.
    """
    fragCoord = ti.Vector([i, j])
    fragColor = ti.Vector([0.0, 0.0, 0.0])

    # Cursor offset from the canvas centre, as a fraction of the canvas size.
    # These are fractional values, so they are declared f32. The former
    # ti.i32 annotation described an integer; where Taichi casts annotated
    # assignments it truncates both to 0, freezing the camera.
    m_x: ti.f32 = (iMouse[0] / iResolution[0]) - 0.5
    m_y: ti.f32 = (iMouse[1] / iResolution[1]) - 0.5

    # Camera orbiting the origin: m_x sets the azimuth, m_y the height.
    cameraOrigin = ti.Vector([
        5.0 * ti.sin(m_x * PI * 2.),
        m_y * 15.0,
        5.0 * ti.cos(m_x * PI * 2.)
    ])
    cameraTarget = ti.Vector([0.0, 0.0, 0.0])
    upDirection  = ti.Vector([0.0, 1.0, 0.0])
    cameraDir    = (cameraTarget - cameraOrigin).normalized()
    cameraRight  = upDirection.cross(cameraOrigin).normalized()
    cameraUp     = cameraDir.cross(cameraRight)
    # Pixel -> [-1, 1] screen coordinates, x scaled by the aspect ratio.
    # TODO: check (gl_FragCoord.xy == fragCoord)
    screenPos    = -1.0 + 2.0 * fragCoord / iResolution
    screenPos[0] *= iResolution[0] / iResolution[1]
    rayDir = (
        cameraRight * screenPos[0]
        + cameraUp  * screenPos[1]
        + cameraDir
    ).normalized()

    # Sphere tracing: advance along the ray by the distance estimate until
    # within EPSILON of a surface, past MAX_DIST, or out of iterations.
    pos: ti.Vector = cameraOrigin
    totalDist = 0.0
    dist = EPSILON

    for _ in range(MAX_ITER):
        if (dist < EPSILON) or (totalDist > MAX_DIST):
            break

        dist = distfunc(iTime, pos)
        totalDist += dist
        pos += dist * rayDir

    if (dist < EPSILON):
        # Hit: surface normal from central differences of distfunc.
        eps_yxx = ti.Vector([EPSILON, 0.0, 0.0])
        eps_xyx = ti.Vector([0.0, EPSILON, 0.0])
        eps_xxy = ti.Vector([0.0, 0.0, EPSILON])
        normal = ti.Vector([
            distfunc(iTime, pos + eps_yxx) - distfunc(iTime, pos - eps_yxx),
            distfunc(iTime, pos + eps_xyx) - distfunc(iTime, pos - eps_xyx),
            distfunc(iTime, pos + eps_xxy) - distfunc(iTime, pos - eps_xxy)
        ]).normalized()
        # (The original shader sampled a texture here; that part is not ported.)

        lightcol = ti.Vector([1.0, 1.0, 1.0])
        darkcol  = ti.Vector([0.4, 0.8, 0.9])
        sma = 0.4
        smb = 0.6

        # NOTE(review): `flag` now holds what the LAST distfunc call wrote
        # (the pos - eps_xxy normal sample), and being one global field it
        # may also have been overwritten by other pixels shaded concurrently.
        if (flag[HIT_HOLE]):
            lightcol = ti.Vector([1.0, 1.0, 0.8])
        elif flag[HIT_BARREL]:
            lightcol[0] = 0.95
        else:
            sma = 0.2
            smb = 0.3

        # Blend toward darkcol where the surface is seen at a grazing angle
        # and toward lightcol where it faces the ray.
        facingRatio = smoothstep(sma, smb, abs(normal.dot(rayDir)))
        illumcol    = mix(lightcol, darkcol, 1.0 - facingRatio)
        fragColor   = illumcol

    else:
        # Miss: horizontal stripes that scroll with time.
        strp: ti.f32 = smoothstep(
            0.8, 0.9,
            (screenPos[1] * 10. + iTime) % 1
        )
        fragColor = mix(
            ti.Vector([1.0, 1.0, 1.0]),
            ti.Vector([0.4, 0.8, 0.9]),
            strp
        )

    return fragColor


@ti.kernel
def _render_kernel(t: ti.f32, mouse_x: ti.f32, mouse_y: ti.f32):
    """Fill ``pixels`` for shader time ``t``; (mouse_x, mouse_y) is the cursor in pixels."""
    for i, j in pixels:
        pixels[i, j] = mainImage(ti.Vector([mouse_x, mouse_y]), t, i, j)


def render(t):
    """
    Render one frame at shader time ``t`` into ``pixels``.

    The cursor is sampled here, in Python scope, on every call. Calling
    ``gui.get_cursor_pos()`` inside a ``@ti.kernel`` body only runs while the
    kernel is being compiled, so the mouse position was frozen after the
    first frame. ``get_cursor_pos()`` reports coordinates normalised to
    [0, 1]. They are scaled to pixels here because mainImage divides
    ``iMouse`` by ``iResolution``, like Shadertoy's pixel-space ``iMouse``.
    """
    cursor_x, cursor_y = gui.get_cursor_pos()
    _render_kernel(t, cursor_x * w, cursor_y * h)


gui = ti.GUI(GUI_TITLE, res=wh)


def main(output_img=False, num_frames=1000):
    """
    Run the animation loop for ``num_frames`` frames (default 1000).

    Prints the wall-clock time before the loop and again once the first
    frame has been shown. The gap is the time to the first frame, which
    includes compiling the render kernel. Press ESC to stop early.

    output_img: also save every frame as ``frame/NNNN.png`` (e.g. to make a
    GIF). The ``frame`` directory is created if it does not exist.
    """
    import os

    if output_img:
        # gui.show(path) writes each frame into this folder.
        os.makedirs('frame', exist_ok=True)

    print(time.strftime("%H:%M:%S, ", time.localtime()), end='')
    for ts in range(num_frames):
        if gui.get_event(ti.GUI.ESCAPE):
            # Leave the loop instead of calling the interactive-only exit()
            # builtin (it is not available e.g. under `python -S`).
            return

        render(ts * 0.03)
        gui.set_image(pixels.to_numpy())
        if output_img:
            gui.show(f'frame/{ts:04d}.png')
        else:
            gui.show()
        if ts == 0:
            print(time.strftime("%H:%M:%S", time.localtime()))


if __name__ == '__main__':
    # main(output_img=True)
    main()

效果
out


在我的电脑上 placedBarrel 语句个数与第一帧出现耗时:

image

原始统计数据

taichi compiler.csv

CPU(debug=True), Compiling kernel render,,, Total compiler time
Barrel num, start, end,       dt(s), start,    end,     dt(s), 
 0, 15:55:22.294, 15:55:23.474,   1, 15:55:22, 15:55:24, 2
 1, 15:56:16.334, 15:56:18.891,   3, 15:56:16, 15:56:19, 3
 2, 16:00:14.492, 16:00:21.029,   7, 16:00:14, 16:00:21, 7
 3, 16:00:45.970, 16:00:58.003,  13, 16:00:45, 16:00:59, 14
 4, 16:01:31.231, 16:01:58.038,  27, 16:01:31, 16:01:59, 28
 5, 16:28:22.676, 16:28:51.803,  28, 16:28:22, 16:28:53, 31
 6, 16:30:05.657, 16:30:50.708,  45, 16:30:05, 16:30:53, 48
 7, 16:31:50.737, 16:32:51.709,  61, 16:31:50, 16:32:54, 64
 8, 16:33:48.472, 16:35:13.123,  85, 16:33:48, 16:35:16, 88
 9, 16:19:07.563, 16:21:02.287, 114, 16:19:07, 16:21:06, 119
10, 16:21:44.158, 16:24:12.995, 149, 16:21:44, 16:24:18, 154
11, 16:24:40.723, 16:27:47.523, 187, 16:24:40, 16:27:53, 193

GPU,, Total compiler time
Barrel num, avg. fps, start, end, dt(s)
 0, 35fps, 15:20:35, 15:20:41,  6
 1, 30fps, 15:21:29, 15:21:38,  9
 2, 31fps, 15:24:20, 15:24:36,  16
 3, 29fps, 15:25:18, 15:25:43,  25
 4, 31fps, 15:27:22, 15:27:50,  28
 5, 30fps, 15:28:34, 15:29:19,  45
 6, 30fps, 15:30:12, 15:31:16,  64
 7, 29fps, 15:32:09, 15:33:32,  83
 8, 28fps, 15:34:52, 15:36:51, 119
 9, 29fps, 15:37:50, 15:40:33, 163
10, 27fps, 15:41:53, 15:46:03, 250
11, 25fps, 15:47:58, 15:52:30, 272

22, fps, 16:38:07, 16:57:06, 1139
3 个赞

看那个对数图,有点近似直线了。

还没去翻源代码,盲猜一波:

是不是编译器的缓存没做好?这几条语句每一句都依赖前一句的结果,如果每次都重复计算,
最后复杂度就成指数了。


根据群里 dalao 的指导

在开头加一句 ti.core.toggle_advanced_optimization(False) 可以减少编译时间

行数, start, end, dt(s)
 8, 19:26:37, 19:27:21, 44
16, 19:35:40, 19:37:53, 133
22, 19:51:24, 19:55:48, 264
1 个赞

这个画风好可爱…

你说得对,现在 advanced optimizations 的性能还没有优化好,会导致一些程序编译时间比较长。mingkuan 同学正在辛苦地优化,例如 https://github.com/taichi-dev/taichi/issues/1059

1 个赞

这个我在着色器网站上看见过

就是 shadertoy 上的