only div, no reciprocal (#601)

* only div, no reciprocal

* remove reciprocal

* fix pad shuffling
Author: George Hotz, 2023-02-25 09:35:03 -08:00 (committed via GitHub)
Parent: d581a99d90
Commit: a8de233e12
9 changed files with 24 additions and 25 deletions

@@ -146,7 +146,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer # class of memory on this device
-unary_op (NOOP, NEG, NOT, EXP, LOG, RECIPROCAL) # A -> A
+unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK) # A -> B (different size)

@@ -110,7 +110,12 @@ def get_run_onnx(onnx_model):
 elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
 elif n.op_type == "Transpose": ret = inp[0].permute(order=opt.get('perm', list(range(len(inp[0].shape))[::-1])))
 elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
-elif n.op_type == "Div": ret = inp[0].div(inp[1])
+elif n.op_type == "Div":
+  if prod(inp[1].shape) == 1:
+    # due to SHUFFLE_PAD_OPS issues, this saves a kernel by taking the reciprocal of constants first, then using mul
+    ret = inp[0] * (1.0/inp[1])
+  else:
+    ret = inp[0].div(inp[1])
 elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
 elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
 elif n.op_type == "Resize":

@@ -4,7 +4,7 @@ import sys, weakref, importlib, inspect
 from weakref import WeakValueDictionary
 from tinygrad.helpers import ConvArgs, prod, DEBUG
 from tinygrad.shape import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, GenericExecAST
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers, GenericExecAST
 from tinygrad.graph import log_op
 from tinygrad.helpers import getenv
@@ -37,7 +37,6 @@ Device = _Device()
 REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
 MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2
 PUSH_PERMUTES = OPT>=3 # fairly untested, but gets kernels back to 200 for openpilot
-SHUFFLE_PAD_OPS = OPT>=4 # NOTE: 0/0 is NaN if you pad, so this can change the output
 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
@@ -259,8 +258,8 @@ class LazyBuffer:
     if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous:
       return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg))
-    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
+    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
+    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and op not in [MovementOps.EXPAND, MovementOps.STRIDED] and (op != MovementOps.PAD or all(x.op != BinaryOps.DIV for x in get_lazyops(self.op))):
       return replace_with_movement_op(self.op, op, arg)
     # create the buffer
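The rewritten condition above replaces the old SHUFFLE_PAD_OPS escape hatch: PAD is now shuffled through a BinaryOp only when no DIV appears anywhere in that op's AST, because padded regions are zero-filled and moving the PAD below a division turns them into 0/0 = NaN. A quick numpy illustration (not tinygrad code):

```python
import numpy as np

a = np.array([1.0, 2.0], dtype=np.float32)
b = np.array([2.0, 4.0], dtype=np.float32)

# divide first, then pad: padding stays 0
div_then_pad = np.pad(a / b, (0, 2))                     # [0.5, 0.5, 0. , 0. ]

# pad first, then divide: the padded region becomes 0/0 = NaN
with np.errstate(invalid="ignore"):
  pad_then_div = np.pad(a, (0, 2)) / np.pad(b, (0, 2))   # [0.5, 0.5, nan, nan]

assert not np.array_equal(div_then_pad, pad_then_div)    # results differ
```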

@@ -5,7 +5,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, Processing
 from tinygrad.helpers import shape_to_axis
 base_fxn_for_op : Dict[Op, Callable] = {
-  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RECIPROCAL: lambda x: 1.0/x, UnaryOps.NOT: lambda x: (1.0 - x),
+  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
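For context, the CPU backend is just a lookup from op enum to a callable over numpy arrays, so dropping RECIPROCAL here only removes one table entry; division keeps going through the BinaryOps.DIV entry. A self-contained sketch of that dispatch pattern (illustration only, with a local enum, not the tinygrad module):

```python
import operator
from enum import Enum, auto
import numpy as np

class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto()

fxn_for_op = {
  BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
  BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv,
}

def binary_op(op, a, b):
  # with RECIPROCAL gone, division is only ever expressed through this entry
  return fxn_for_op[op](a, b)

x = np.array([1.0, 2.0, 3.0], dtype=np.float32)
y = np.array([2.0, 2.0, 2.0], dtype=np.float32)
print(binary_op(BinaryOps.DIV, x, y))  # [0.5 1.  1.5]
```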

@@ -42,7 +42,6 @@ class CLASTKernel(ASTKernel):
     UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "((float)1.0-A)",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
-    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
     BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     BinaryOps.MAX: "max(A,B)", ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"

@@ -32,7 +32,6 @@ class LLVMBuffer(ExplicitExecAST):
     UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
     UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
     BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),

@@ -25,17 +25,6 @@ class Exp(Function):
   def backward(self, grad_output):
     return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)
-class Reciprocal(Function):
-  def forward(self, x):
-    ret = x.unary_op(UnaryOps.RECIPROCAL)
-    self.save_for_backward(ret)
-    return ret
-  def backward(self, grad_output):
-    return grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.saved_tensors[0]).binary_op(BinaryOps.MUL, self.saved_tensors[0])
 # TODO: add Neg? confirm the optimizer on Sub good enough
 # ************* reduce ops *************
 class Sum(Function):
@@ -111,11 +100,19 @@ class Pow(Function):
   def backward(self, grad_output):
     x,y,powxy = self.saved_tensors
     # grad_x = grad_output * y * (pow(x,y)/x)
     # grad_y = grad_output * log(x) * pow(x,y)
     return grad_output.binary_op(BinaryOps.MUL, y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x))) if self.needs_input_grad[0] else None, \
       grad_output.binary_op(BinaryOps.MUL, x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy)) if self.needs_input_grad[1] else None
+class Div(Function):
+  def forward(self, x, y):
+    self.save_for_backward(x, y)
+    return x.binary_op(BinaryOps.DIV, y)
+  def backward(self, grad_output):
+    x, y = self.saved_tensors
+    return grad_output.binary_op(BinaryOps.DIV, y) if self.needs_input_grad[0] else None, \
+      grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, x).binary_op(BinaryOps.DIV, y.binary_op(BinaryOps.MUL, y)) if self.needs_input_grad[1] else None
 # ************* movement ops *************
 # NOTE: this is sum in reverse
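The new Div.backward composes its gradients from the remaining primitives: d(x/y)/dx = 1/y gives grad_x = grad_output / y, and d(x/y)/dy = -x/y^2 gives grad_y = -grad_output * x / y^2, which is exactly the NEG, MUL, DIV chain above. A quick scalar finite-difference check of those formulas (illustration only, not part of the diff):

```python
# verify d(x/y)/dx = 1/y and d(x/y)/dy = -x/y**2 against finite differences
x, y, eps = 3.0, 4.0, 1e-6
grad_output = 1.0

grad_x = grad_output / y             # what Div.backward returns for x
grad_y = -grad_output * x / (y * y)  # what Div.backward returns for y

assert abs(grad_x - ((x + eps) / y - x / y) / eps) < 1e-4
assert abs(grad_y - (x / (y + eps) - x / y) / eps) < 1e-4
```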

@@ -8,7 +8,7 @@ from tinygrad.shape import ShapeTracker
 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto(); RECIPROCAL = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); FLIP = auto(); STRIDED = auto(); PAD = auto(); SHRINK = auto() # noqa: E702

@@ -347,7 +347,6 @@ class Tensor:
   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
   def exp(self): return mlops.Exp.apply(self)
-  def reciprocal(self): return mlops.Reciprocal.apply(self)
   # ***** math functions (unary) *****
@@ -358,6 +357,7 @@ class Tensor:
   def abs(self): return self.relu() + (-self).relu()
   def sign(self): return self / (self.abs() + 1e-10)
   def relu(self): return self.maximum(0)
+  def reciprocal(self): return 1.0/self
   # ***** activation functions (unary) *****
@@ -386,7 +386,7 @@ class Tensor:
   def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse)
   def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse)
   def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse)
-  def div(self, x, reverse=False): return (self.reciprocal() * x) if reverse else (self * (x.reciprocal() if isinstance(x, Tensor) else (1/x)))
+  def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse)
   def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x)
   def maximum(self, x): return self._broadcasted(mlops.Maximum, x)
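At the Tensor level the change means `a / b` lowers to a single BinaryOps.DIV through mlops.Div instead of a RECIPROCAL followed by a MUL, while `Tensor.reciprocal()` survives as sugar for `1.0/self`, which reaches the same Div mlop through the reverse-broadcast path. A usage sketch (assumes a tinygrad checkout at this commit; Tensor construction from lists and .numpy() behave as in the repo at that time):

```python
from tinygrad.tensor import Tensor

a = Tensor([[1.0, 2.0], [3.0, 4.0]])
b = Tensor([[2.0, 2.0], [2.0, 2.0]])

c = a / b            # single BinaryOps.DIV via mlops.Div, no RECIPROCAL+MUL pair
d = a.reciprocal()   # now just 1.0/a, i.e. a reverse div against a constant
print(c.numpy())
print(d.numpy())
```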