Gradient mismatch with PyTorch for atomic division

There may be a bug in the gradient computation for atomic division. When I run the code below, the first assertion passes (the forward result matches PyTorch), but the second one, a NumPy array comparison, fails (so the backward pass of atomic division may be wrong). Or am I computing the gradients the wrong way?

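```python
import torch
import taichi as ti
import numpy as np

ti.init(arch=ti.cpu)

b = 5
n = 10
x = ti.field(dtype=ti.f32, shape=(b, n), needs_grad=True)
y = ti.field(dtype=ti.f32, shape=(b, n), needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)


@ti.kernel
def divide():
    # In-place division: y <- y / x
    for IDX in ti.grouped(x):
        y[IDX] /= x[IDX]


@ti.kernel
def reduce():
    # loss = sum(y)
    for IDX in ti.grouped(y):
        loss[None] += y[IDX]


def main():
    x_torch = torch.randn((b, n)).requires_grad_()
    y_torch = torch.randn((b, n))

    # Copy the same inputs into the Taichi fields
    x.from_torch(x_torch.detach())
    y.from_torch(y_torch)

    with ti.Tape(loss):
        divide()
        reduce()

    # Reference result from PyTorch
    z_torch = y_torch / x_torch
    loss_torch = (z_torch).sum()
    loss_torch.backward()

    print(loss.to_torch().item(), loss_torch.item())
    # Forward check: passes
    assert abs(loss.to_torch().item() - loss_torch.item()) < 1e-5
    # Backward check: fails
    np.testing.assert_array_almost_equal(
        x_torch.grad.squeeze(-1).numpy(), x.grad.to_numpy(), decimal=5
    )


if __name__ == "__main__":
    main()
```
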
Example Output:

```
[Taichi] mode=release
[Taichi] version 0.7.10, llvm 10.0.0, commit 0f0205fc, win, python 3.7.4
[I 12/29/20 19:53:08.900] [] Graphical python shell detected, using wrapped sys.stdout
[Taichi] Starting on arch=x64
[Taichi] materializing...
-50.8487663269043 -50.8487663269043
Traceback (most recent call last):
  File "", line 48, in <module>
  File "", line 43, in main
    x_torch.grad.squeeze(-1).numpy(), x.grad.to_numpy(), decimal=5
  File "C:\Users\jitang\AppData\Local\Programs\Python\Python37\lib\site-packages\numpy\testing\_private\", line 1047, in assert_array_almost_equal
  File "C:\Users\jitang\AppData\Local\Programs\Python\Python37\lib\site-packages\numpy\testing\_private\", line 846, in assert_array_compare
    raise AssertionError(msg)
Arrays are not almost equal to 5 decimals

Mismatched elements: 50 / 50 (100%)
Max absolute difference: 77790.13
Max relative difference: 2.5147126
 x: array([[ 3.76420e+00, -5.43730e-01,  6.64118e-01,  1.16813e+00,
         4.57189e+00,  3.06695e+00, -7.56317e+00,  2.69717e+00,
         6.72860e-01,  1.03529e+01],...
 y: array([[-1.21419e+01, -3.10191e-01,  9.97835e-01, -8.31697e-01,
         9.12188e+00, -4.85009e+00, -2.99018e+01,  4.39251e+00,
        -5.45994e-01, -3.62938e+01],...
```
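For reference, the gradient both sides should report has a closed form: with loss = Σ y/x, the element-wise derivative is ∂loss/∂x = -y/x². The NumPy-only sketch below checks that formula against central finite differences. The helper names (`loss_fn`, `analytic_grad_x`, `finite_diff_grad_x`) are hypothetical, and the inputs are fresh random arrays rather than the ones from the run above; `x` is kept away from zero so the finite differences stay well conditioned.

```python
import numpy as np


def loss_fn(x, y):
    # Same quantity as divide() followed by reduce(): loss = sum(y / x)
    return np.sum(y / x)


def analytic_grad_x(x, y):
    # d/dx_i sum_j(y_j / x_j) = -y_i / x_i**2
    return -y / x**2


def finite_diff_grad_x(x, y, eps=1e-6):
    # Central differences, perturbing one element of x at a time
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        xp = x.copy()
        xm = x.copy()
        xp[idx] += eps
        xm[idx] -= eps
        grad[idx] = (loss_fn(xp, y) - loss_fn(xm, y)) / (2 * eps)
    return grad


rng = np.random.default_rng(0)
b, n = 5, 10
# Magnitudes in [0.5, 2.0] with random signs, so no entry of x is near zero
x = rng.uniform(0.5, 2.0, size=(b, n)) * rng.choice([-1.0, 1.0], size=(b, n))
y = rng.standard_normal((b, n))

np.testing.assert_allclose(
    analytic_grad_x(x, y), finite_diff_grad_x(x, y), rtol=1e-5, atol=1e-6
)
```

Feeding the same `x` and `y` values used in the repro into `analytic_grad_x` and comparing the result with both `x_torch.grad` and `x.grad.to_numpy()` would show which of the two arrays deviates from -y/x².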