Breaking the kernel simplicity rule

Hi, I am trying to put this kernel in an nn.Module class together with a set of custom torch.autograd.Function classes. I am aware that this may not work with autodiff.

```python
@ti.kernel
def class_kernel(self):
    for i, j in ti.ndrange(self.grid_size[0]-2, self.grid_size[1]-2):
        self.a[i,j] = ...

    for i, j in ti.ndrange(self.grid_size[0]-2, self.grid_size[1]-2):
        self.b[i,j] = ...
        ...
```

RuntimeError: [reverse_segments.cpp:reverse_segments@64] Invalid program input for autodiff. Please check the documentation for the “Kernel Simplicity Rule”.

Then I split these into two separate class kernels and called them in the autograd forward and backward passes, but the same error persists. Could you give me a hint on how to fix this? Thanks!
Complete version: https://codeshare.io/5OVBy0

Hi @143, could you move the _eps... statement on line 81 into the loop body? Sorry about the confusion.
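For context on why moving `_eps...` helps: the Kernel Simplicity Rule in Taichi's autodiff documentation says a differentiable kernel body should consist only of simply nested for-loops, so a plain statement such as an `_eps` assignment sitting at the top level next to the loops is likely what the `reverse_segments` pass rejects. The snippet below is a hypothetical, stdlib-only helper, `top_level_violations`, that approximates just the top-level part of that check with Python's `ast` module. It is not Taichi's implementation, it does not check the nesting part of the rule, and `_eps = 1e-6` is a placeholder because the real line 81 is not shown in this thread.

```python
import ast

# Hypothetical kernel sources; only parsed, never executed, so `ti` need not exist.
# `_eps = 1e-6` is a placeholder, not the actual line 81 from the linked code.
BAD = """
@ti.kernel
def class_kernel(self):
    _eps = 1e-6  # top-level statement next to a for-loop
    for i, j in ti.ndrange(self.grid_size[0] - 2, self.grid_size[1] - 2):
        self.b[i, j] = self.a[i, j] + _eps
"""

GOOD = """
@ti.kernel
def class_kernel(self):
    for i, j in ti.ndrange(self.grid_size[0] - 2, self.grid_size[1] - 2):
        _eps = 1e-6  # moved into the loop body
        self.b[i, j] = self.a[i, j] + _eps
"""


def is_docstring(stmt):
    """True for a bare string expression such as a docstring."""
    return (isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and isinstance(stmt.value.value, str))


def top_level_violations(kernel_src):
    """Return (line number, statement type) for each offending top-level statement.

    Rough approximation of the top-level part of the Kernel Simplicity Rule:
    if the kernel body contains any for-loop, every other top-level
    statement (docstrings aside) is reported. Nesting is not checked.
    """
    tree = ast.parse(kernel_src)
    func = next(n for n in tree.body if isinstance(n, ast.FunctionDef))
    body = func.body
    if not any(isinstance(stmt, ast.For) for stmt in body):
        return []  # no loops at the top level: nothing for this check to flag
    return [
        (stmt.lineno, type(stmt).__name__)
        for stmt in body
        if not isinstance(stmt, ast.For) and not is_docstring(stmt)
    ]


print(top_level_violations(BAD))   # [(4, 'Assign')]
print(top_level_violations(GOOD))  # []
```

The only difference between the two sources is where the `_eps` assignment lives: inside the loop body it becomes part of the loop's statement group, so the top level is left with nothing but for-loops.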
