From e441794c4b95405efa54d5bef5adf1d5eb2636f2 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 11 Oct 2024 14:31:09 +0800
Subject: [PATCH] remove custom op support, we waste time maintaining this
 (#6991)

* remove custom op support, we waste time maintaining this

* customop is over
---
 .pre-commit-config.yaml        |   4 +-
 test/external/fuzz_schedule.py |   3 +-
 test/test_custom_function.py   | 105 ---------------------------------
 tinygrad/engine/realize.py     |   7 ---
 tinygrad/engine/schedule.py    |   2 +-
 tinygrad/ops.py                |   3 +-
 6 files changed, 5 insertions(+), 119 deletions(-)
 delete mode 100644 test/test_custom_function.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c6e967dd..7a01a3f7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,13 +27,13 @@ repos:
     pass_filenames: false
   - id: devicetests
     name: select GPU tests
-    entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_custom_function.py test/test_search.py
+    entry: env GPU=1 PYTHONPATH="." pytest test/test_uops.py test/test_search.py
     language: system
     always_run: true
    pass_filenames: false
   - id: tests
     name: subset of tests
-    entry: env PYTHONPATH="." python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_custom_function.py test/test_assign.py test/test_symbolic_shapetracker.py
+    entry: env PYTHONPATH="." python3 -m pytest -n=4 test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py test/test_symbolic_shapetracker.py
     language: system
     always_run: true
     pass_filenames: false
diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py
index 776f2672..b5a7f709 100644
--- a/test/external/fuzz_schedule.py
+++ b/test/external/fuzz_schedule.py
@@ -2,7 +2,7 @@ import itertools
 import numpy as np
 from typing import DefaultDict, Dict, List, Set, Tuple, TypeVar, Union
 from tinygrad.device import Buffer
-from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
+from tinygrad.engine.realize import capturing, lower_schedule_item
 from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
@@ -72,7 +72,6 @@ def fuzz_schedule(outs:List[LazyBuffer]):
 def _exec_si(si:ScheduleItem, seed:int):
   ei = lower_schedule_item(si)
   if len(capturing): capturing[0].add(ei)
-  if isinstance(ei.prg, CustomOp): Tensor._seed = seed
   ei.run()
 
 T = TypeVar("T")
diff --git a/test/test_custom_function.py b/test/test_custom_function.py
deleted file mode 100644
index e64d544d..00000000
--- a/test/test_custom_function.py
+++ /dev/null
@@ -1,105 +0,0 @@
-# this is an example of how you can write terrible DSP compute breaking ops like warpPerspective
-# here we use a CUSTOM op to write atan2
-
-import unittest
-import numpy as np
-from typing import Optional, Tuple
-from tinygrad.helpers import prod
-from tinygrad.dtype import dtypes
-
-# *** first, we implement the atan2 op at the lowest level ***
-# `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
-from tinygrad.engine.lazy import Buffer, create_lazybuffer
-from tinygrad.device import Device
-from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.engine.realize import CompiledRunner
-from tinygrad.renderer import Program
-
-# we don't always have GPU support, so the type signature is the abstract CompiledBuffer instead of GPUBuffer
-def atan2_gpu(ret:Buffer, a:Buffer, b:Buffer):
-  assert a.dtype == b.dtype and a.dtype == dtypes.float32, "gpu function only supports float32"
-  src = """
-  __kernel void atan2_gpu(global float *c, global float *a, global float *b) {
-    int idx = get_global_id(0);
-    c[idx] = atan2(a[idx], b[idx]);
-  }"""
-  CompiledRunner(Program("atan2_gpu", src, ret.device, global_size=[ret.size,1,1])).exec([ret, a, b])
-
-def atan2_cpu(ret:Buffer, a:Buffer, b:Buffer): ret.copyin(np.require(np.arctan2(a._buf, b._buf), requirements='C').data)
-
-# *** second, we write the ATan2 mlop ***
-# NOTE: The derivative of atan2 doesn't need a custom op! https://www.liquisearch.com/atan2/derivative
-# In general, it is also optional to write a backward function, just your backward pass won't work without it
-
-from tinygrad.ops import MetaOps
-from tinygrad.engine.lazy import LazyBuffer
-from tinygrad.tensor import Function
-
-class ATan2(Function):
-  def forward(self, a:LazyBuffer, b:LazyBuffer) -> LazyBuffer:
-    assert prod(a.shape) == prod(b.shape) and a.device == b.device, "shape or device mismatch"
-    self.a, self.b = a, b
-    return create_lazybuffer(a.device, ShapeTracker.from_shape(a.shape), max(a.dtype, b.dtype), MetaOps.CUSTOM,
-                             arg={"GPU": atan2_gpu, "CPU": atan2_cpu}[a.device], srcs=(a.contiguous(), b.contiguous()))
-  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
-    recip = (self.a * self.a + self.b * self.b).recip()
-    return (grad_output * self.b * recip) if self.needs_input_grad[0] else None, \
-           (grad_output * -self.a * recip) if self.needs_input_grad[1] else None
-
-# *** third, we use our lovely new mlop in some tests ***
-
-from tinygrad.tensor import Tensor
-
-@unittest.skipUnless(Device.DEFAULT in ["CPU", "GPU"], "atan2 is only implemented for CPU and GPU")
-class TestCustomFunction(unittest.TestCase):
-  def test_atan2_forward(self):
-    # create some random Tensors, permute them just because we can
-    a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-    b = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-
-    # run the forward pass. note: up until the .numpy(), it's all lazy
-    c = ATan2.apply(a, b)
-    print(c.numpy())
-
-    # check the forward pass (in numpy)
-    np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5)
-
-  # fun fact, this never actually calls forward, so it works in all the backends
-  def test_atan2_backward(self):
-    # have to go forward before we can go backward
-    a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-    b = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-    c = ATan2.apply(a, b)
-
-    # run the backward pass
-    c.mean().backward()
-    assert a.grad is not None and b.grad is not None, "tinygrad didn't compute gradients"
-    print(a.grad.numpy())
-    print(b.grad.numpy())
-
-    # check the backward pass (in torch)
-    import torch
-    ta, tb = torch.tensor(a.numpy(), requires_grad=True), torch.tensor(b.numpy(), requires_grad=True)
-    tc = torch.atan2(ta, tb)
-    tc.mean().backward()
-    assert ta.grad is not None and tb.grad is not None, "torch didn't compute gradients"
-    np.testing.assert_allclose(a.grad.numpy(), ta.grad.numpy(), atol=1e-5)
-    np.testing.assert_allclose(b.grad.numpy(), tb.grad.numpy(), atol=1e-5)
-
-  @unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
-  def test_atan2_jit(self):
-    # custom ops even work in the JIT!
-    from tinygrad.engine.jit import TinyJit
-
-    @TinyJit
-    def jitted_atan2(a:Tensor, b:Tensor) -> Tensor:
-      return ATan2.apply(a, b).realize()
-
-    for _ in range(5):
-      a = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-      b = Tensor.randn(4,4,requires_grad=True).permute(1,0)
-      c = jitted_atan2(a, b)
-      np.testing.assert_allclose(c.numpy(), np.arctan2(a.numpy(), b.numpy()), atol=1e-5)
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 40b87daa..156c5a47 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -104,12 +104,6 @@ class CompiledRunner(Runner):
       assert len(local_size) == 3, "local size must have len 3"
     return self.clprg(*[x._buf for x in rawbufs], **lra, vals=tuple(var_vals[k] for k in self.p.vars), wait=wait)
 
-class CustomOp(Runner):
-  def __init__(self, fxn):
-    self.fxn = fxn
-    super().__init__(self.fxn.__name__, "CUSTOM", 0, 0)
-  def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): self.fxn(*rawbufs)
-
 class EmptyOp(Runner):
   def __init__(self, buf:Buffer): super().__init__(colored(f"empty {buf.size:10d} {buf.dtype}", "yellow"), buf.device)
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
@@ -199,7 +193,6 @@ def lower_schedule_item(si:ScheduleItem) -> ExecItem:
     if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]:
       kernel_type = BufferXfer
     return ExecItem(kernel_type(arg, out.device, si.inputs[0].device), list(si.bufs))
-  if si.ast.op is UOps.CUSTOM: return ExecItem(CustomOp(arg), list(si.bufs))
   if si.ast.op is UOps.EMPTY: return ExecItem(EmptyOp(out), list(si.bufs))
   if si.ast.op is UOps.BUFFER_VIEW: return ExecItem(ViewOp(out), list(si.bufs))
   raise RuntimeError(f"don't know how to lower {si.ast}")
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 7e866cf5..996d90e8 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -17,7 +17,7 @@ from tinygrad.shape.view import View, strides_for_shape
 sys.setrecursionlimit(10000)
 
 BUF_LIMIT = {"METAL": 32}
-METAOPS = {MetaOps.CUSTOM:UOps.CUSTOM, MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
+METAOPS = {MetaOps.COPY:UOps.COPY, MetaOps.EMPTY:UOps.EMPTY, MetaOps.VIEW:UOps.BUFFER_VIEW}
 
 # *** ScheduleItem return type ***
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 6ce289d8..44b12fd0 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -33,7 +33,7 @@ class ReduceOps(FastEnum):
   """A -> B (reduce)"""
   SUM = auto(); PROD = auto(); MAX = auto() # noqa: E702
 class MetaOps(FastEnum):
-  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
+  EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); ASSIGN = auto(); VIEW = auto() # noqa: E702
 Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps]
 
 T = TypeVar("T")
@@ -102,7 +102,6 @@ class UOps(FastEnum):
   CONTIGUOUS = auto()
 
   # metaops
-  CUSTOM = auto()
   COPY = auto()
   EMPTY = auto()
   BUFFER_VIEW = auto()