2023-03-09 02:05:04 +08:00
|
|
|
# this is an example of how you can write terrible DSP compute breaking ops like warpPerspective
|
|
|
|
# here we use a CUSTOM op to write atan2
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
import numpy as np
|
2023-03-09 04:22:11 +08:00
|
|
|
from typing import Optional, Tuple
|
2024-01-02 06:58:48 +08:00
|
|
|
from tinygrad.helpers import prod
|
|
|
|
from tinygrad.dtype import dtypes
|
2023-03-09 02:05:04 +08:00
|
|
|
|
|
|
|
# *** first, we implement the atan2 op at the lowest level ***
|
2023-03-09 04:22:11 +08:00
|
|
|
# `atan2_gpu` for GPU buffers and `atan2_cpu` for CPU buffers
|
2023-12-01 09:07:16 +08:00
|
|
|
from tinygrad.lazy import Buffer, create_lazybuffer
|
2023-11-28 03:34:37 +08:00
|
|
|
from tinygrad.device import CompiledASTRunner, Device
|
2023-06-27 04:55:42 +08:00
|
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
2023-03-09 02:05:04 +08:00
|
|
|
|
2023-03-09 04:22:11 +08:00
|
|
|
# we don't always have GPU support, so the type signature uses the generic `Buffer` rather than a GPU-specific buffer type
|
2023-12-01 09:07:16 +08:00
|
|
|
def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
  """Fill `ret` with the elementwise atan2(a, b) using a hand-written OpenCL C kernel.

  One work-item is launched per output element (global size = ret.size), and work-item
  `idx` reads a[idx] and b[idx] and writes c[idx]. This presumably assumes a, b and ret
  are contiguous and hold at least ret.size float32 elements -- TODO confirm at call sites.

  Raises AssertionError (when assertions are enabled) unless a and b are both float32.
  """
  assert a.dtype == dtypes.float32 and b.dtype == a.dtype, "gpu function only supports float32"
  kernel_src = """
__kernel void atan2_gpu(global float *c, global float *a, global float *b) {
  int idx = get_global_id(0);
  c[idx] = atan2(a[idx], b[idx]);
}"""
  # build a runner for the kernel on ret's device, then launch it with (ret, a, b) bound to (c, a, b)
  runner = CompiledASTRunner(None, "atan2_gpu", kernel_src, Device[ret.device], global_size=[ret.size])
  runner.exec([ret, a, b])
|
2023-03-09 02:05:04 +08:00
|
|
|
|
2023-12-01 09:07:16 +08:00
|
|
|
def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer):
  """Fill `ret` with the elementwise atan2(a, b) computed by numpy.

  Reads the backing `_buf` of each input directly; presumably these are numpy
  arrays on the CPU backend -- TODO confirm.
  """
  angles = np.arctan2(a._buf, b._buf)
  # copyin receives the array's memoryview, so force a C-contiguous layout first
  ret.copyin(np.require(angles, requirements='C').data)
|
2023-03-09 02:05:04 +08:00
|
|
|
|
|
|
|
# *** second, we write the ATan2 mlop ***
|
|
|
|
# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
|
|
|
|
# In general, writing a backward function is optional, but without one your backward pass won't work
|
|
|
|
|
2023-12-21 06:33:21 +08:00
|
|
|
from tinygrad.ops import LoadOps, BinaryOps
|
2023-03-09 02:05:04 +08:00
|
|
|
from tinygrad.lazy import LazyBuffer
|
|
|
|
from tinygrad.tensor import Function
|
|
|
|
|
|
|
|
class ATan2(Function):
  """Elementwise atan2(a, b) as a tinygrad Function.

  The forward pass emits a LoadOps.CUSTOM lazybuffer whose `arg` is the backend-specific
  realization function (atan2_gpu or atan2_cpu). The backward pass needs no custom op and
  is built from ordinary BinaryOps on the saved inputs.
  """

  def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
    """Return a lazy CUSTOM buffer computing atan2(a, b).

    Only the element counts of a and b are compared, not their exact shapes; the output
    takes a's shape and the larger of the two dtypes. Both sources are made contiguous
    because the custom kernels index them as flat arrays.
    Raises AssertionError on a size or device mismatch, and KeyError for devices other than CPU/GPU.
    """
    assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
    # save the inputs: backward computes the gradient from them, not from the forward output
    self.a, self.b = a, b
    # strip any device index suffix (e.g. "GPU:1" -> "GPU") so non-default devices of a supported
    # backend resolve too; atan2_gpu already dispatches on the full device name via Device[ret.device]
    impl = {"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device.split(":")[0]]
    return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), max(a.dtype, b.dtype), LoadOps.CUSTOM,
                             arg=impl, srcs=(a.contiguous(), b.contiguous()))

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    """Return (grad_a, grad_b), with None for inputs that don't need a gradient.

    d/da atan2(a, b) = b / (a^2 + b^2)   and   d/db atan2(a, b) = -a / (a^2 + b^2)
    """
    denom = (self.a.e(BinaryOps.MUL, self.a)).e(BinaryOps.ADD, self.b.e(BinaryOps.MUL, self.b))
    grad_a = grad_output.e(BinaryOps.MUL, self.b.e(BinaryOps.DIV, denom)) if self.needs_input_grad[0] else None
    # -a is written as (0 - a)
    neg_a = self.a.const(0).e(BinaryOps.SUB, self.a)
    grad_b = grad_output.e(BinaryOps.MUL, neg_a.e(BinaryOps.DIV, denom)) if self.needs_input_grad[1] else None
    return grad_a, grad_b
|
2023-03-09 02:05:04 +08:00
|
|
|
|
|
|
|
# *** third, we use our lovely new mlop in some tests ***
|
|
|
|
|
2023-07-24 23:19:58 +08:00
|
|
|
from tinygrad.tensor import Tensor
|
2023-03-09 02:05:04 +08:00
|
|
|
|
|
|
|
@unittest.skipUnless(Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU")
class TestCustomFunction(unittest.TestCase):
  """Exercise the custom ATan2 Function: forward, backward, and inside TinyJit."""

  @staticmethod
  def _permuted_randn() -> Tensor:
    # a random 4x4 tensor that requires grad, transposed just because we can
    return Tensor.randn(4,4,requires_grad=True).permute(1,0)

  def test_atan2_forward(self):
    x, y = self._permuted_randn(), self._permuted_randn()

    # everything stays lazy until .numpy() realizes it
    out = ATan2.apply(x, y)
    print(out.numpy())

    # compare against numpy's reference atan2
    np.testing.assert_allclose(out.numpy(), np.arctan2(x.numpy(), y.numpy()), atol=1e-5)

  # fun fact: ATan2.backward only uses the saved inputs, so the custom forward kernel should never actually run here
  def test_atan2_backward(self):
    # we still have to go forward before we can go backward
    x, y = self._permuted_randn(), self._permuted_randn()
    out = ATan2.apply(x, y)

    out.mean().backward()
    assert x.grad is not None and y.grad is not None, "tinygrad didn't compute gradients"
    print(x.grad.numpy())
    print(y.grad.numpy())

    # reference gradients from torch
    import torch
    tx = torch.tensor(x.numpy(), requires_grad=True)
    ty = torch.tensor(y.numpy(), requires_grad=True)
    torch.atan2(tx, ty).mean().backward()
    assert tx.grad is not None and ty.grad is not None, "torch didn't compute gradients"
    for ours, ref in ((x.grad, tx.grad), (y.grad, ty.grad)):
      np.testing.assert_allclose(ours.numpy(), ref.numpy(), atol=1e-5)

  @unittest.skipIf(Device.DEFAULT == "CPU", "atan2_cpu not jittable")
  def test_atan2_jit(self):
    # custom ops work inside the JIT too
    from tinygrad.features.jit import TinyJit

    @TinyJit
    def jitted_atan2(a:Tensor, b:Tensor) -> Tensor:
      return ATan2.apply(a, b).realize()

    for _ in range(5):
      x, y = self._permuted_randn(), self._permuted_randn()
      out = jitted_atan2(x, y)
      np.testing.assert_allclose(out.numpy(), np.arctan2(x.numpy(), y.numpy()), atol=1e-5)
|
|
|
|
|
|
|
|
# run this file's tests when it is executed directly as a script
if __name__ == "__main__": unittest.main()
|