DIY一个Taichi里的点乘函数

学习了一下文档以后发现ti.field是没有点乘运算的,但是在实际编写模拟的时候,经常可能会用到2D Field量A (nxn) 点乘 向量x (nx1)的情况,所以想把它做成一个函数。(如果这种做法是多余的请大佬指出)

但是着手编写的时候才发现,由于Taichi必须要在初始化的时候定义所有变量,不能在函数scope中临时声明变量,于是就只能写出了这样的版本。。:

import taichi as ti

ti.init()

n = 10

A = ti.field(dtype=ti.f64, shape=(n,n))
x = ti.field(dtype=ti.f64, shape=n)
Ax = ti.field(dtype=ti.f64, shape=n)

# 想象中的版本
@ti.func
def dot(A,x):
    Ax = ti.field(dtype=ti.f64, shape=n)
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            Ax[i] += A[i,j] * x[j]
    return Ax
    

# 实际的版本
@ti.func
def dot(A,x,Ax):
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            Ax[i] += A[i,j] * x[j]

@ti.kernel
def main():
    init()
    dot(A,x,Ax)
    for i in Ax:
        print(Ax[i])

main()

如上所示,因为必须要事先声明存储点乘结果的field Ax,所以只能被迫写成了把Ax作为一个参数,通过修改这个参数来保存结果的目的。

这样的写法虽然不是不能用,但是每次遇到做点乘都要预先声明一个用来保存结果的field,次数多了以后就会有一大堆仅仅用来保存点乘结果的field堵在代码的最上方感觉太笨重了。况且还不能这样写:

# 假设r和b也是n维向量, 且假装Taichi可以直接做向量加减
r = b - dot(A,x)

所以请大佬不吝指点。谢谢!

如果矩阵/向量比较小的话可以考虑用 Matrix/Vector。
(大矩阵操作也可以考虑传到 python scope 下用别的库?

1 个赞

转到numpy的话肯定就可以很轻松做了是没错,但是我猜一般不太应该频繁进行大量这种运算把。而且点乘一个field和一个向量应该是模拟计算典型会遇到的场景,不知道开发者有没有想过把这种运算集合到语言里去。@yuanming @archibate

确实,不能动态创建field有时会很不方便。开了个issue,欢迎大家讨论:

1 个赞