Get/Set matrix column inside Taichi scope

What is the best way to access and modify a column of a matrix inside a taichi kernel?

I wrote some helper functions for this task:

    import taichi as ti
    ti.init(arch=ti.cpu)

    @ti.func
    def getCol(mat, idx):
        ret = ti.Vector([0.,0.,0.])
        for i in ti.static(range(3)):
            # idx_by_value__ is the "magic" name described below; plain idx fails here
            ret[i] = mat[i, idx_by_value__]
        return ret

    @ti.func
    def setCol(mat, idx, vec):
        assert mat.n == len(vec)
        for i in ti.static(range(3)):
            mat[i, idx_by_value__] = vec[i]
        return mat

Here is the test script:

    M = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(1,1))
    M[0,0] = ti.Matrix([
        [0,1,2],
        [3,4,5],
        [6,7,8]
    ])

    @ti.kernel
    def test():
        print(getCol(M[0,0], 1))
        print(setCol(M[0,0], 2, [8,5,2]))
    test()

If I use the parameter `idx` directly, Taichi raises this error: “The 1-th index of a Matrix/Vector must be a compile-time constant integer, got <class ‘taichi.lang.expr.Expr’>.”

However, with the magic variable `{paramName}_by_value__`, both functions work as expected.
I know this can’t be a long-term solution, but are there other ways to do this?

Thank you!

Welcome to the Taichi community, @Amos.

This is a really good question.

The problem is that Taichi only supports dynamic indexing (indexing a matrix with a value that is only known at runtime) in very limited cases. Some developers are currently working on extending this feature. For more information about dynamic indexing, please see here.

In the near future, you will be able to write code like this:

    import taichi as ti
    # With dynamic_index enabled, you no longer need ti.static for the loops below
    ti.init(arch=ti.cpu, dynamic_index=True)

    @ti.func
    def getCol(mat, idx):
        ret = ti.Vector([0., 0., 0.])
        for i in range(3):
            ret[i] = mat[i, idx]
        return ret

    @ti.func
    def setCol(mat, idx, vec):
        assert mat.n == len(vec)
        for i in range(3):
            mat[i, idx] = vec[i]
        return mat

    M = ti.Matrix.field(3, 3, dtype=ti.f32, shape=(1, 1))
    M[0, 0] = ti.Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]])

    @ti.kernel
    def test():
        print(getCol(M[0, 0], 1))
        # Pass a ti.Vector: a plain Python list cannot be indexed by the runtime loop variable i
        print(setCol(M[0, 0], 2, ti.Vector([8, 5, 2])))

    test()