mirror of https://github.com/commaai/tinygrad.git
only div, no reciprocal (#601)
* only div, no reciprocal
* remove reciprocal
* fix pad shuffling
This commit is contained in:
parent d581a99d90
commit a8de233e12
@@ -146,7 +146,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer # class of memory on this device
-unary_op (NOOP, NEG, NOT, EXP, LOG, RECIPROCAL) # A -> A
+unary_op (NOOP, NEG, NOT, EXP, LOG) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ, MAX) # A + A -> A (all the same size)
 movement_op (RESHAPE, PERMUTE, EXPAND, FLIP, PAD, SHRINK) # A -> B (different size)
@@ -110,7 +110,12 @@ def get_run_onnx(onnx_model):
     elif n.op_type == "Concat": ret = inp[0].cat(*inp[1:], dim=opt['axis'])
     elif n.op_type == "Transpose": ret = inp[0].permute(order=opt.get('perm', list(range(len(inp[0].shape))[::-1])))
     elif n.op_type == "Squeeze": ret = inp[0].reshape([s for i,s in enumerate(inp[0].shape) if i not in opt['axes']])
-    elif n.op_type == "Div": ret = inp[0].div(inp[1])
+    elif n.op_type == "Div":
+      if prod(inp[1].shape) == 1:
+        # due to SHUFFLE_PAD_OPS issues, this saves a kernel by taking the reciprocal of constants first, then using mul
+        ret = inp[0] * (1.0/inp[1])
+      else:
+        ret = inp[0].div(inp[1])
     elif n.op_type == "Constant": ret = opt['value'] if 'value' in opt else opt['value_float']
     elif n.op_type == "Reshape": ret = inp[0].reshape([int(x) if x != 0 else inp[0].shape[i] for i,x in enumerate(safe_numpy(inp[1]))])
     elif n.op_type == "Resize":
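Why the importer special-cases a single-element divisor: folding the reciprocal of the constant turns the division into a plain MUL, which the lazy scheduler can still shuffle past PAD (see the LazyBuffer change below), saving a kernel. A minimal NumPy sketch of the equivalence, illustrative only and not tinygrad code:

```python
import numpy as np

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
c = np.array([4.0], dtype=np.float32)   # single-element divisor, prod(c.shape) == 1

direct    = x / c                        # what inp[0].div(inp[1]) would compute
rewritten = x * (np.float32(1.0) / c)    # reciprocal of the constant first, then a MUL

# Equal up to floating-point rounding; the rewrite keeps a DIV node out of the
# lazy graph, which (per the LazyBuffer change below) would block PAD shuffling.
assert np.allclose(direct, rewritten)
```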
@@ -4,7 +4,7 @@ import sys, weakref, importlib, inspect
 from weakref import WeakValueDictionary
 from tinygrad.helpers import ConvArgs, prod, DEBUG
 from tinygrad.shape import ShapeTracker
-from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, map_buffers, GenericExecAST
+from tinygrad.ops import DeviceBuffer, UnaryOps, BinaryOps, ReduceOps, MovementOps, ProcessingOps, LoadOps, OpType, LazyOp, get_buffers, get_lazyops, map_buffers, GenericExecAST
 from tinygrad.graph import log_op
 from tinygrad.helpers import getenv
@@ -37,7 +37,6 @@ Device = _Device()
 REMOVE_MOVEMENT_NOPS, MERGE_UNARY_OPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS = OPT>=1, OPT>=1, OPT>=1, OPT>=1
 MERGE_ELEMENTWISE_OPS, MERGE_ONE_REDUCE_INTO_ELEMENTWISE = OPT>=2, OPT>=2
 PUSH_PERMUTES = OPT>=3 # fairly untested, but gets kernels back to 200 for openpilot
-SHUFFLE_PAD_OPS = OPT>=4 # NOTE: 0/0 is NaN if you pad, so this can change the output

 # **** realize functions ****
 def _ast_reduceops(self:LazyBuffer) -> LazyOp:
@@ -259,8 +258,8 @@ class LazyBuffer:
     if op == MovementOps.STRIDED and local_st.contiguous and self.st.contiguous:
       return self.movement_op(MovementOps.RESHAPE, tuple(i for i,_ in arg))

-    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead
-    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and (SHUFFLE_PAD_OPS or op != MovementOps.PAD) and op not in [MovementOps.EXPAND, MovementOps.STRIDED]:
+    # if this MovementOp is being applied to a BinaryOp, apply the MovementOp to all the BinaryOp inputs instead. NOTE: UnaryOps is never an OpType
+    if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and self.realized is None and len(self.children) == 0 and op not in [MovementOps.EXPAND, MovementOps.STRIDED] and (op != MovementOps.PAD or all(x.op != BinaryOps.DIV for x in get_lazyops(self.op))):
       return replace_with_movement_op(self.op, op, arg)

     # create the buffer
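The new per-AST guard replaces the global SHUFFLE_PAD_OPS flag: PAD is only shuffled into a binary op when no DIV appears in the fused AST, because padding fills with zeros and 0/0 is NaN, so pad-then-divide and divide-then-pad disagree. A small NumPy sketch of the hazard, not tinygrad code:

```python
import numpy as np

a = np.array([2.0, 4.0], dtype=np.float32)
b = np.array([1.0, 2.0], dtype=np.float32)

div_then_pad = np.pad(a / b, (1, 1))                      # [0. 2. 2. 0.]
with np.errstate(invalid='ignore'):
  pad_then_div = np.pad(a, (1, 1)) / np.pad(b, (1, 1))    # [nan 2. 2. nan] -- 0/0 in the padding

# MUL (like ADD) commutes with zero padding, so shuffling PAD past it is safe:
assert np.array_equal(np.pad(a * b, (1, 1)), np.pad(a, (1, 1)) * np.pad(b, (1, 1)))
```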
@@ -5,7 +5,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, Processing
 from tinygrad.helpers import shape_to_axis

 base_fxn_for_op : Dict[Op, Callable] = {
-  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RECIPROCAL: lambda x: 1.0/x, UnaryOps.NOT: lambda x: (1.0 - x),
+  UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),
   BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow,
   ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
   ReduceOps.MAX: lambda x, new_shape: (x.amax if hasattr(x, 'amax') else x.max)(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],
@@ -42,7 +42,6 @@ class CLASTKernel(ASTKernel):
     UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.NOT: "((float)1.0-A)",
     UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
     UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
-    UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",
     BinaryOps.ADD: "(A+B)", BinaryOps.SUB: "(A-B)", BinaryOps.MUL: "(A*B)",
     BinaryOps.DIV: "(A/B)", BinaryOps.POW: "pow(A,B)", BinaryOps.CMPEQ: "(A==B)",
     BinaryOps.MAX: "max(A,B)", ReduceOps.SUM: "A+=B", ReduceOps.MAX: "A=max(A,B)"
@@ -32,7 +32,6 @@ class LLVMBuffer(ExplicitExecAST):
     UnaryOps.NEG: lambda builder,x: builder.fneg(x, flags=('fast',)),
     UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
     UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-    UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     UnaryOps.NOT: lambda builder,x: builder.fsub(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
     BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
     BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),
@@ -25,17 +25,6 @@ class Exp(Function):
   def backward(self, grad_output):
     return self.saved_tensors[0].binary_op(BinaryOps.MUL, grad_output)

-class Reciprocal(Function):
-  def forward(self, x):
-    ret = x.unary_op(UnaryOps.RECIPROCAL)
-    self.save_for_backward(ret)
-    return ret
-
-  def backward(self, grad_output):
-    return grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, self.saved_tensors[0]).binary_op(BinaryOps.MUL, self.saved_tensors[0])
-
 # TODO: add Neg? confirm the optimizer on Sub good enough

 # ************* reduce ops *************

 class Sum(Function):
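Dropping Reciprocal does not lose a gradient rule: its backward computed d(1/x)/dx = -1/x², and the same value falls out of the new Div backward (added below) once reciprocal is expressed as 1.0/x. A quick check, illustrative only and not tinygrad code:

```python
import numpy as np

x, g = np.float32(3.0), np.float32(0.5)   # arbitrary input and upstream gradient

ret = np.float32(1.0) / x
reciprocal_grad = -g * ret * ret                    # removed path: NEG(g) * ret * ret
div_grad        = -g * np.float32(1.0) / (x * x)    # new path: grad of y in 1.0 / y

assert np.isclose(reciprocal_grad, div_grad)
```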
@@ -111,11 +100,19 @@ class Pow(Function):

   def backward(self, grad_output):
     x,y,powxy = self.saved_tensors
     # grad_x = grad_output * y * (pow(x,y)/x)
     # grad_y = grad_output * log(x) * pow(x,y)
     return grad_output.binary_op(BinaryOps.MUL, y.binary_op(BinaryOps.MUL, powxy.binary_op(BinaryOps.DIV, x))) if self.needs_input_grad[0] else None, \
            grad_output.binary_op(BinaryOps.MUL, x.unary_op(UnaryOps.LOG).binary_op(BinaryOps.MUL, powxy)) if self.needs_input_grad[1] else None

+class Div(Function):
+  def forward(self, x, y):
+    self.save_for_backward(x, y)
+    return x.binary_op(BinaryOps.DIV, y)
+
+  def backward(self, grad_output):
+    x, y = self.saved_tensors
+    return grad_output.binary_op(BinaryOps.DIV, y) if self.needs_input_grad[0] else None, \
+           grad_output.unary_op(UnaryOps.NEG).binary_op(BinaryOps.MUL, x).binary_op(BinaryOps.DIV, y.binary_op(BinaryOps.MUL, y)) if self.needs_input_grad[1] else None
+
 # ************* movement ops *************

 # NOTE: this is sum in reverse
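The new Div.backward encodes d(x/y)/dx = 1/y and d(x/y)/dy = -x/y² (the second line negates grad_output, multiplies by x, and divides by y*y). A finite-difference sanity check of those two formulas in plain Python, illustrative only:

```python
x, y, g, eps = 3.0, 2.0, 1.0, 1e-6

grad_x = g / y              # grad_output.binary_op(DIV, y)      -> 0.5
grad_y = -g * x / (y * y)   # NEG(grad_output) * x / (y * y)     -> -0.75

num_grad_x = ((x + eps) / y - (x - eps) / y) / (2 * eps)
num_grad_y = (x / (y + eps) - x / (y - eps)) / (2 * eps)

assert abs(grad_x - num_grad_x) < 1e-4 and abs(grad_y - num_grad_y) < 1e-4
```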
@@ -8,7 +8,7 @@ from tinygrad.shape import ShapeTracker

 # these are the llops your accelerator must implement, along with toCpu
 # the Enum class doesn't work with mypy, this is static. sorry it's ugly
-class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto(); RECIPROCAL = auto() # noqa: E702
+class UnaryOps(Enum): NOOP = auto(); NEG = auto(); EXP = auto(); LOG = auto(); NOT = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); POW = auto(); CMPEQ = auto(); MAX = auto() # noqa: E702
 class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); FLIP = auto(); STRIDED = auto(); PAD = auto(); SHRINK = auto() # noqa: E702
@@ -347,7 +347,6 @@ class Tensor:
   def contiguous(self): return mlops.Contiguous.apply(self)
   def log(self): return mlops.Log.apply(self)
   def exp(self): return mlops.Exp.apply(self)
-  def reciprocal(self): return mlops.Reciprocal.apply(self)

   # ***** math functions (unary) *****
@@ -358,6 +357,7 @@ class Tensor:
   def abs(self): return self.relu() + (-self).relu()
   def sign(self): return self / (self.abs() + 1e-10)
   def relu(self): return self.maximum(0)
+  def reciprocal(self): return 1.0/self

   # ***** activation functions (unary) *****
@@ -386,7 +386,7 @@ class Tensor:
   def sub(self, x, reverse=False): return self._broadcasted(mlops.Sub, x, reverse)
   def mul(self, x, reverse=False): return self._broadcasted(mlops.Mul, x, reverse)
   def pow(self, x, reverse=False): return self._broadcasted(mlops.Pow, x, reverse)
-  def div(self, x, reverse=False): return (self.reciprocal() * x) if reverse else (self * (x.reciprocal() if isinstance(x, Tensor) else (1/x)))
+  def div(self, x, reverse=False): return self._broadcasted(mlops.Div, x, reverse)
   def matmul(self, x:Tensor, reverse=False): return x.dot(self) if reverse else self.dot(x)

   def maximum(self, x): return self._broadcasted(mlops.Maximum, x)
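After this change both the / operator and Tensor.div route through mlops.Div (a real BinaryOps.DIV in the graph), while reciprocal() is sugar for 1.0/self, i.e. div with reverse=True. A usage sketch; it assumes the Tensor list constructor and .numpy() behave as in tinygrad of this era:

```python
from tinygrad.tensor import Tensor

a = Tensor([[6.0, 8.0]])
b = Tensor([[2.0, 4.0]])

print((a / b).numpy())         # [[3. 2.]]    -- mlops.Div via _broadcasted
print(a.div(b).numpy())        # same path as the operator
print(b.reciprocal().numpy())  # [[0.5 0.25]] -- 1.0/b, i.e. div with reverse=True
```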