How to create OOP codes when ti.var is required in torch.autograd.Function?

I am trying to put ti_data, it_weight, ti_bias, ti_output, torch_kernel into the class Linear. However, I cannot find a way to do this because LinearFunction needs to call torch_kernel.

import torch
from torch.autograd import gradcheck
import torch.nn.functional as F
from torch.autograd import gradcheck
import math
import torch.nn as nn
import taichi as ti

batch_size = 32
real = ti.f32
ti_data = ti.Vector(6, dt=real, shape=batch_size, needs_grad=True)
it_weight = ti.Vector(6, dt=real, shape=(), needs_grad=True)
it_bias = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)
it_output = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)


@ti.kernel
def torch_kernel():
    for i in range(batch_size):
        ti_output[i] = ti.Matrix([ti.dot(it_data[i], ti_weight[None])]) + ti_bias[i]


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ti_data.from_torch(input)
        ti_weight.from_torch(weight)
        ti_bias.from_torch(bias)
        torch_kernel()
        return ti_output.to_torch()

    @staticmethod
    def backward(ctx, grad_output):
        ti.clear_all_gradients()
        grad_input = grad_weight = grad_bias = None
        ti_output.grad.from_torch(grad_output)
        torch_kernel.grad()

        if ctx.needs_input_grad[0]:
            grad_input = ti_data.grad.to_torch(as_vector=True)
        if ctx.needs_input_grad[1]:
            grad_weight = ti_weight.grad.to_torch()
        if ctx.needs_input_grad[2]:
            grad_bias = ti_bias.grad.to_torch(as_vector=True)

        return grad_input, grad_weight, grad_bias


class Linear(nn.Module):
    def __init__(self, input_features, output_features):
        super(Linear, self).__init__()
        self.weight = nn.Parameter(torch.Tensor(input_features, output_features))
        self.bias = nn.Parameter(torch.Tensor(output_features))
        self.weight.data.normal_(0, math.sqrt(2. / output_features / input_features))

    def forward(self, input):
        bias = self.bias.unsqueeze(0).expand(batch_size, 1)
        return LinearFunction.apply(input, self.weight, bias)


data = torch.rand(batch_size, 6, dtype=torch.float32, requires_grad=True)
linear = Linear(6, 1)

test = gradcheck(linear, data, eps=1e-3, atol=1e-4)
print(test)

I tried to put all things into ctx, but it seemed not to remember what the ti.var is.

import torch
from torch.autograd import gradcheck
import torch.nn.functional as F
from torch.autograd import gradcheck
import math
import torch.nn as nn
import taichi as ti

batch_size = 32
real = ti.f32


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data, weight, bias, ti_data, ti_weight, ti_bias, ti_output, kernel):
        ctx.output = ti_output
        ctx.kernel = kernel
        ctx.data = input
        ctx.weight = weight
        ctx.bias = bias
        ti_data.from_torch(data)
        ti_weight.from_torch(weight)
        ti_bias.from_torch(bias)
        kernel()
        return ti_output.to_torch()

    @staticmethod
    def backward(ctx, grad_output):
        ti.clear_all_gradients()
        grad_data = grad_weight = grad_bias = None
        ctx.ti_output.grad.from_torch(grad_output)
        ctx.kernel.grad()

        if ctx.needs_input_grad[0]:
            grad_data = ctx.ti_data.grad.to_torch(as_vector=True)
        if ctx.needs_input_grad[1]:
            grad_weight = ctx.ti_weight.grad.to_torch()
        if ctx.needs_input_grad[2]:
            grad_bias = ctx.ti_bias.grad.to_torch(as_vector=True)

        return grad_data, grad_weight, grad_bias, None, None, None, None, None


class Linear(nn.Module):
    def __init__(self, input_features, output_features):
        super(Linear, self).__init__()
        self.ti_data = ti.Vector(6, dt=real, shape=batch_size, needs_grad=True)
        self.ti_weight = ti.Vector(6, dt=real, shape=(), needs_grad=True)
        self.ti_bias = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)
        self.ti_output = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)
        self.weight = nn.Parameter(torch.Tensor(input_features, output_features))
        self.bias = nn.Parameter(torch.Tensor(output_features))
        self.weight.data.normal_(0, math.sqrt(2. / output_features / input_features))

    @ti.classkernel
    def torch_kernel(self):
        for i in range(batch_size):
            self.ti_output[i] = ti.Matrix([ti.dot(self.ti_data[i], self.ti_weight[None])]) + self.ti_bias[i]

    def forward(self, data):
        bias = self.bias.unsqueeze(0).expand(batch_size, 1)
        return LinearFunction.apply(data, self.weight, bias,
                                    self.ti_data, self.ti_weight, self.ti_bias, self.ti_output, self.torch_kernel)


data = torch.rand(batch_size, 6, dtype=torch.float32, requires_grad=True)
linear = Linear(6, 1)

test = gradcheck(linear, data, eps=1e-3, atol=1e-4)
print(test)

This code gives error:

Traceback (most recent call last):
  File "/Users/zhe/PycharmProjects/mpm_learning/linear.py", line 69, in <module>
    test = gradcheck(linear, data, eps=1e-3, atol=1e-4)
  File "/Users/zhe/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 195, in gradcheck
    analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o)
  File "/Users/zhe/anaconda3/lib/python3.7/site-packages/torch/autograd/gradcheck.py", line 96, in get_analytical_jacobian
    retain_graph=True, allow_unused=True)
  File "/Users/zhe/anaconda3/lib/python3.7/site-packages/torch/autograd/__init__.py", line 145, in grad
    inputs, allow_unused)
  File "/Users/zhe/anaconda3/lib/python3.7/site-packages/torch/autograd/function.py", line 76, in apply
    return self._forward_cls.backward(self, *args)
  File "/Users/zhe/PycharmProjects/mpm_learning/linear.py", line 31, in backward
    ctx.ti_output.grad.from_torch(grad_output)
AttributeError: 'LinearFunctionBackward' object has no attribute 'ti_output'

Process finished with exit code 1

Please keep the identifiers consistent. In LinearFunction.forward

ctx.output = ti_output

should be

ctx.ti_output = ti_output

Here’s a working script. There are quite a few similar bugs, please do a diff and see what’s wrong in the original script.

import torch
from torch.autograd import gradcheck
import torch.nn.functional as F
from torch.autograd import gradcheck
import math
import torch.nn as nn
import taichi as ti

batch_size = 32
real = ti.f32


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, data, weight, bias, ti_data, ti_weight, ti_bias, ti_output, kernel):
        ctx.ti_output = ti_output
        ctx.ti_weight = ti_weight
        ctx.ti_bias = ti_bias
        ctx.kernel = kernel
        ctx.data = input
        ctx.weight = weight
        ctx.bias = bias
        ctx.ti_data = ti_data
        ti_data.from_torch(data)
        ti_weight.from_torch(weight)
        ti_bias.from_torch(bias)
        kernel()
        return ti_output.to_torch()

    @staticmethod
    def backward(ctx, grad_output):
        ti.clear_all_gradients()
        grad_data = grad_weight = grad_bias = None
        ctx.ti_output.grad.from_torch(grad_output)
        ctx.kernel(__gradient=True)

        if ctx.needs_input_grad[0]:
            grad_data = ctx.ti_data.grad.to_torch(as_vector=True)
        if ctx.needs_input_grad[1]:
            grad_weight = ctx.ti_weight.grad.to_torch()
        if ctx.needs_input_grad[2]:
            grad_bias = ctx.ti_bias.grad.to_torch(as_vector=True)

        return grad_data, grad_weight, grad_bias, None, None, None, None, None


class Linear(nn.Module):
    def __init__(self, input_features, output_features):
        super(Linear, self).__init__()
        self.ti_data = ti.Vector(6, dt=real, shape=batch_size, needs_grad=True)
        self.ti_weight = ti.Vector(6, dt=real, shape=(), needs_grad=True)
        self.ti_bias = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)
        self.ti_output = ti.Vector(1, dt=real, shape=batch_size, needs_grad=True)
        self.weight = nn.Parameter(torch.Tensor(input_features, output_features))
        self.bias = nn.Parameter(torch.Tensor(output_features))
        self.weight.data.normal_(0, math.sqrt(2. / output_features / input_features))

    @ti.classkernel
    def torch_kernel(self):
        for i in range(batch_size):
            self.ti_output[i] = ti.Matrix([ti.dot(self.ti_data[i], self.ti_weight[None])]) + self.ti_bias[i]

    def forward(self, data):
        bias = self.bias.unsqueeze(0).expand(batch_size, 1)
        return LinearFunction.apply(data, self.weight, bias,
                                    self.ti_data, self.ti_weight, self.ti_bias, self.ti_output, self.torch_kernel)


data = torch.rand(batch_size, 6, dtype=torch.float32, requires_grad=True)
linear = Linear(6, 1)

test = gradcheck(linear, data, eps=1e-3, atol=1e-4)
print(test)

You may also want to check out https://taichi.readthedocs.io/en/latest/odop.html to learn about “__gradient=True”.