From a8de233e12e23f81c92979a19936270ce0a037b8 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 25 Feb 2023 09:35:03 -0800
Subject: [PATCH] only div, no reciprocal (#601)

* only div, no reciprocal

* remove reciprocal

* fix pad shuffling
---
 README.md                  |  2 +-
 extra/onnx.py              |  7 ++++++-
 tinygrad/lazy.py           |  7 +++----
 tinygrad/llops/ops_cpu.py  |  2 +-
 tinygrad/llops/ops_gpu.py  |  1 -
 tinygrad/llops/ops_llvm.py |  1 -
 tinygrad/mlops.py          | 23 ++++++++++-------------
 tinygrad/ops.py            |  2 +-
 tinygrad/tensor.py         |  4 ++--
 9 files changed, 24 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 9d53d348..22d7243b 100644
--- a/README.md
+++ b/README.md
@@ -146,7 +146,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 
 ```
 Buffer # class of memory on this device
-unary_op (NOOP, NEG, NOT, EXP, LOG, RECIPROCAL) # A -> A
+unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK) # A -> B (different size)
diff --git a/extra/onnx.py b/extra/onnx.py
index eee89ff2..72c2fcd6 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -110,7 +110,12 @@ def get_run_onnx(onnx_model):
       elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
       elif n.op_type == "Transpose": ret = inp[0].permute(order=opt.get('perm', list(range(len(inp[0].shape))[::-1])))
       elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
-      elif n.op_type == "Div": ret = inp[0].div(inp[1])
+      elif n.op_type == "Div":
+        if prod(inp[1].shape) == 1:
+          # due to SHUFFLE_PAD_OPS issues, this saves a kernel by taking the reciprocal of constants first, then using mul
+          ret = inp[0] * (1.0/inp[1])
+        else:
+          ret = inp[0].div(inp[1])
       elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
       elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
       elif n.op_type == "Resize":
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 2c1cfc5f..52cf64d8 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -4,7 +4,7 @@ import sys, weakref, importlib, inspect
 from weakref import WeakValueDictionary
 from tinygrad.helpers import ConvArgs, prod, DEBUG
 from tinygrad.shape import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, GenericExecAST
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers, GenericExecAST
 from tinygrad.graph import log_op
 from tinygrad.helpers import getenv
 
@@ -37,7 +37,6 @@ Device = _Device()
 REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
 MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2
 PUSH_PERMUTES = OPT>=3 # fairly untested, but gets kernels back to 200 for openpilot
-SHUFFLE_PAD_OPS = OPT>=4 # NOTE: 0/0 is NaN if you pad, so this can change the output
 
 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
@@ -259,8 +258,8 @@ class LazyBuffer:
 
     if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous: return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg))
 
-    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
+    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
+    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and op not in [MovementOps.EXPAND, MovementOps.STRIDED] and (op != MovementOps.PAD or all(x.op != BinaryOps.DIV for x in get_lazyops(self.op))):
       return replace_with_movement_op(self.op, op, arg)
 
     # create the buffer
diff --git a/tinygrad/llops/ops_cpu.py b/tinygrad/llops/ops_cpu.py
index 2c593335..b7097029 100644
--- a/tinygrad/llops/ops_cpu.py
+++ b/tinygrad/llops/ops_cpu.py
@@ -5,7 +5,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, Processing
 from tinygrad.helpers import shape_to_axis
 
 base_fxn_for_op : Dict[Op, Callable] = {
-  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RECIPROCAL: lambda x: 1.0/x, UnaryOps.NOT: lambda x: (1.0 - x),
+  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py
index 6cd86c60..d1c7f6ce 100644
--- a/tinygrad/llops/ops_gpu.py
+++ b/tinygrad/llops/ops_gpu.py
@@ -42,7 +42,6 @@ class CLASTKernel(ASTKernel):
     UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "((float)1.0-A)",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
-    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
     BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)", BinaryOps.MAX: "max(A,B)",
     ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"
diff --git a/tinygrad/llops/ops_llvm.py b/tinygrad/llops/ops_llvm.py
index 1e35cfa2..13f69b8c 100644
--- a/tinygrad/llops/ops_llvm.py
+++ b/tinygrad/llops/ops_llvm.py
@@ -32,7 +32,6 @@ class LLVMBuffer(ExplicitExecAST):
     UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
     UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
     BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
diff --git a/tinygrad/mlops.py b/tinygrad/mlops.py
index 006802a5..99569658 100644
--- a/tinygrad/mlops.py
+++ b/tinygrad/mlops.py
@@ -25,17 +25,6 @@ class Exp(Function):
   def backward(self, grad_output):
     return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
 
-class Reciprocal(Function):
-  def forward(self, x):
-    ret = x.unary_op(UnaryOps.RECIPROCAL)
-    self.save_for_backward(ret)
-    return ret
-
-  def backward(self, grad_output):
-    return grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.saved_tensors[0]).binary_op(BinaryOps.MUL, self.saved_tensors[0])
-
-# TODO: add Neg? confirm the optimizer on Sub good enough
-
 # ************* reduce ops *************
 
 class Sum(Function):
@@ -111,11 +100,19 @@ class Pow(Function):
 
   def backward(self, grad_output):
     x,y,powxy = self.saved_tensors
-    # grad_x = grad_output * y * (pow(x,y)/x)
-    # grad_y = grad_output * log(x) * pow(x,y)
     return grad_output.binary_op(BinaryOps.MUL, y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x))) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy)) if self.needs_input_grad[1] else None
 
+class Div(Function):
+  def forward(self, x, y):
+    self.save_for_backward(x, y)
+    return x.binary_op(BinaryOps.DIV, y)
+
+  def backward(self, grad_output):
+    x, y = self.saved_tensors
+    return grad_output.binary_op(BinaryOps.DIV, y) if self.needs_input_grad[0] else None, \
+           grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, x).binary_op(BinaryOps.DIV, y.binary_op(BinaryOps.MUL, y)) if self.needs_input_grad[1] else None
+
 # ************* movement ops *************
 
 # NOTE: this is sum in reverse
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 7b96e24d..a1ff7427 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -8,7 +8,7 @@ from tinygrad.shape import ShapeTracker
 
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto(); RECIPROCAL = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); FLIP = auto(); STRIDED = auto(); PAD = auto(); SHRINK = auto() # noqa: E702
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 244edbf0..4c6b1056 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -347,7 +347,6 @@ class Tensor:
   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
   def exp(self): return mlops.Exp.apply(self)
-  def reciprocal(self): return mlops.Reciprocal.apply(self)
 
   # ***** math functions (unary) *****
 
@@ -358,6 +357,7 @@ class Tensor:
   def abs(self): return self.relu() + (-self).relu()
   def sign(self): return self / (self.abs() + 1e-10)
   def relu(self): return self.maximum(0)
+  def reciprocal(self): return 1.0/self
 
   # ***** activation functions (unary) *****
 
@@ -386,7 +386,7 @@ class Tensor:
   def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse)
   def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse)
   def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse)
-  def div(self, x, reverse=False): return (self.reciprocal() * x) if reverse else (self * (x.reciprocal() if isinstance(x, Tensor) else (1/x)))
+  def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse)
   def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x)
 
   def maximum(self, x): return self._broadcasted(mlops.Maximum, x)
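
Note (outside the patch): the new mlops.Div.backward implements the quotient rule, grad_x = grad_output / y and grad_y = -grad_output * x / (y * y), which is what Tensor.div now routes through via _broadcasted(mlops.Div, ...), with reciprocal() reduced to the sugar 1.0/self. Below is a minimal standalone NumPy sketch of that rule, checked against central finite differences; the shapes, epsilon, and the div_backward helper name are illustrative choices, not part of tinygrad.

import numpy as np

def div_backward(grad_output, x, y):
  # mirrors Div.backward: d(x/y)/dx = 1/y, d(x/y)/dy = -x/y**2
  return grad_output / y, -grad_output * x / (y * y)

# uniform values in [0.5, 1.5) keep y safely away from zero
x = np.random.rand(3, 3) + 0.5
y = np.random.rand(3, 3) + 0.5
g = np.ones_like(x)  # upstream gradient, as if backpropagating from (x / y).sum()

gx, gy = div_backward(g, x, y)

eps = 1e-6
num_gx = ((x + eps) / y - (x - eps) / y) / (2 * eps)  # central difference w.r.t. x
num_gy = (x / (y + eps) - x / (y - eps)) / (2 * eps)  # central difference w.r.t. y
assert np.allclose(gx, num_gx, atol=1e-4)
assert np.allclose(gy, num_gy, atol=1e-4)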