mirror of https://github.com/commaai/tinygrad.git
Replace SIGN with GT0 (#511)
* Replace sign with gt0
* GT0 works on GPU
* Fix brackets

Co-authored-by: Tom Finet <tom.codeninja@gmail.com>

This commit is contained in:
parent 799b3f185a
commit 54c68defc7
@@ -145,7 +145,7 @@ The autodiff stuff is all in mlops now so you can focus on the raw operations
 ```
 Buffer # class of memory on this device
-unary_op (RELU, EXP, LOG, NEG, SIGN) # A -> A
+unary_op (RELU, EXP, LOG, NEG, GT0) # A -> A
 reduce_op (SUM, MAX) # A -> B (smaller size, B has 1 in shape)
 binary_op (ADD, SUB, MUL, DIV, POW, CMPEQ) # A + B -> C (all the same size)
 movement_op (RESHAPE, PERMUTE, PAD, SHRINK, EXPAND, FLIP) # A -> B (different size)

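For intuition on the op swap: SIGN maps a value to -1, 0, or +1, while GT0 maps it to 1.0 where the value is strictly positive and 0.0 everywhere else. A minimal NumPy sketch of the two semantics (illustrative only, not tinygrad code):

import numpy as np

def sign(x):
    # old op: -1.0, 0.0, or +1.0
    return np.sign(x)

def gt0(x):
    # new op: 1.0 where x > 0, else 0.0
    return (x > 0).astype(np.float32)

x = np.array([-2.0, 0.0, 3.0], dtype=np.float32)
print(sign(x))  # [-1.  0.  1.]
print(gt0(x))   # [0. 0. 1.]

ReLU's gradient only ever needs the 0/1 mask, which is why the cheaper op suffices (see the mlops hunk below).
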
@@ -150,8 +150,7 @@ class LLVMBuffer(ExplicitExecAST):
 UnaryOps.RELU: lambda builder,x: builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), x, ir.Constant(ir.FloatType(), 0)),
 UnaryOps.EXP: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp', [ir.FloatType()]), [x], fastmath=('fast',)),
 UnaryOps.LOG: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log', [ir.FloatType()]), [x], fastmath=('fast',)),
-UnaryOps.SIGN: lambda builder,x: builder.select(builder.fcmp_ordered("==", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), ir.Constant(ir.FloatType(), 0),
-  builder.select(builder.fcmp_ordered("<=", ir.Constant(ir.FloatType(), 0), x, flags=('fast',)), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), -1))),
+UnaryOps.GT0: lambda builder,x: builder.select(builder.fcmp_ordered(">", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), ir.Constant(ir.FloatType(), 1), ir.Constant(ir.FloatType(), 0)),
 UnaryOps.RECIPROCAL: lambda builder,x: builder.fdiv(ir.Constant(ir.FloatType(), 1), x, flags=('fast',)),
 BinaryOps.ADD: lambda builder,x,y: builder.fadd(x,y, flags=('fast',)),
 BinaryOps.SUB: lambda builder,x,y: builder.fsub(x,y, flags=('fast',)),

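The LLVM lowering gets simpler here: SIGN needed a nested select (an equality check against zero wrapped around a <= compare), while GT0 is a single ordered > compare feeding one select. A scalar Python sketch of the two selection trees (illustrative, not the llvmlite code itself):

def sign_select(x: float) -> float:
    # old SIGN lowering: outer select handles x == 0, inner select picks +1 / -1
    return 0.0 if x == 0.0 else (1.0 if 0.0 <= x else -1.0)

def gt0_select(x: float) -> float:
    # new GT0 lowering: one compare (x > 0) feeding one select
    return 1.0 if x > 0.0 else 0.0

assert [sign_select(v) for v in (-2.0, 0.0, 3.0)] == [-1.0, 0.0, 1.0]
assert [gt0_select(v) for v in (-2.0, 0.0, 3.0)] == [0.0, 0.0, 1.0]
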
@@ -7,7 +7,7 @@ from tinygrad.helpers import shape_to_axis
 class CPUBuffer(np.ndarray, GenericExecAST):
 fxn_for_op = {
 UnaryOps.NOOP: lambda x: x[:], UnaryOps.NEG: lambda x: -x, UnaryOps.RELU: lambda x: x.relu(),
-UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.SIGN: lambda x: x.sign(), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
+UnaryOps.EXP: lambda x: x.exp(), UnaryOps.LOG: lambda x: x.log(), UnaryOps.GT0: lambda x: operator.gt(x, 0.0), UnaryOps.RECIPROCAL: lambda x: 1.0/x,
 BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul,
 BinaryOps.DIV: operator.truediv, BinaryOps.POW: operator.pow, BinaryOps.CMPEQ: lambda x,y: (x==y).float(),
 ReduceOps.SUM: lambda x, new_shape: x.sum(shape_to_axis(x.shape, new_shape), keepdims=True) if tuple(x.shape) != tuple(new_shape) else x[:],

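On the CPU backend the new op is just an elementwise comparison. One detail worth noting: operator.gt on a NumPy array returns a bool array, which upcasts to a float array on the next arithmetic op, so the multiply in ReLU's backward still produces floats. A quick illustrative check (not tinygrad code):

import operator
import numpy as np

x = np.array([-1.5, 0.0, 2.0], dtype=np.float32)
mask = operator.gt(x, 0.0)     # bool array: [False False  True]
grad = mask * np.float32(0.5)  # upcasts to a float array: [0.  0.  0.5]
print(mask, grad)
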
@@ -18,7 +18,6 @@ class CPUBuffer(np.ndarray, GenericExecAST):
 def relu(x): return np.maximum(x, 0)
 def exp(x): return np.exp(x)
 def log(x): return np.log(x)
-def sign(x): return np.sign(x)
 def float(x): return x.astype(np.float32)
 def flip(x, axis): return np.flip(x, axis)
 def amax(x, *args, **kwargs): return np.amax(x, *args, **kwargs)

@@ -29,7 +29,7 @@ def split_float4(x):
 class CLASTKernel(ASTKernel):
 code_for_op : Dict[Op, str] = {
-UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.SIGN: "sign(A)",
+UnaryOps.NOOP: "(A)", UnaryOps.NEG: "(-(A))", UnaryOps.RELU: "max(A, (float)0.)", UnaryOps.GT0: "((float)1.-step((float)0.,(-A)))",
 UnaryOps.EXP: "native_exp(A)" if NATIVE_EXPLOG else "exp(A)",
 UnaryOps.LOG: "native_log(A)" if NATIVE_EXPLOG else "log(A)",
 UnaryOps.RECIPROCAL: "native_recip(A)" if NATIVE_EXPLOG else "((float)1.0/A)",

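The GPU expression leans on the OpenCL builtin step(): step(edge, x) is 0.0 when x < edge and 1.0 otherwise, so ((float)1.-step((float)0.,(-A))) is 1.0 exactly when -A < 0, i.e. when A > 0. A small Python sketch of that identity (illustrative, mirroring the builtin's definition):

def step(edge: float, x: float) -> float:
    # OpenCL step(): 0.0 if x < edge, else 1.0
    return 0.0 if x < edge else 1.0

def gt0(a: float) -> float:
    # 1 - step(0, -a) is 1.0 exactly when -a < 0, i.e. when a > 0
    return 1.0 - step(0.0, -a)

assert [gt0(v) for v in (-2.0, 0.0, 3.0)] == [0.0, 0.0, 1.0]
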
@@ -92,7 +92,6 @@ class CLASTKernel(ASTKernel):
 if isinstance(x.op, ReduceOps) and not do_reduce: return acc
 values = ([acc] if isinstance(x.op, ReduceOps) else []) + [self.ast_parse(v, acc, do_reduce) for v in x.src]
 code = CLASTKernel.code_for_op[x.op] # TODO: replace this with a function
-if CUDA and x.op == UnaryOps.SIGN: self.prekernel.add("inline __device__ float sign(float x) { float val = (signbit(x) == 0.0f) ? 1.0f : -1.0f; return (x == 0.0f) ? 0.0f : val; }")
 if len(values) == 2:
 # TODO: sometimes this is split, sometimes it's multiply
 if isinstance(x.op, ReduceOps) and values[0][0].typ == Types.FLOAT4 and len(values[0])*4 == len(values[1]): values[0] = split_float4(values[0])

@@ -15,7 +15,7 @@ class ReLU(Function):
 return ret

 def backward(self, grad_output):
-return self.saved_tensors[0].unary_op(UnaryOps.SIGN).binary_op(BinaryOps.MUL, grad_output)
+return self.saved_tensors[0].unary_op(UnaryOps.GT0).binary_op(BinaryOps.MUL, grad_output)

 class Log(Function):
 def forward(self, x):

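This is where the swap pays off: the tensor saved in ReLU's forward is the ReLU output, which is never negative, so SIGN and GT0 produce the same 0/1 mask there and the cheaper op suffices. A NumPy check of that equivalence (illustrative only, not tinygrad code):

import numpy as np

x = np.array([-3.0, 0.0, 2.0], dtype=np.float32)
grad_output = np.array([0.1, 0.2, 0.3], dtype=np.float32)

ret = np.maximum(x, 0)                            # ReLU forward output (the saved tensor)
old = np.sign(ret) * grad_output                  # SIGN-based backward
new = (ret > 0).astype(np.float32) * grad_output  # GT0-based backward
assert np.allclose(old, new)                      # identical, since ret is never negative
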
@@ -10,7 +10,7 @@ from tinygrad.helpers import getenv
 DEBUG = getenv("DEBUG", 0)

 # these are the llops your accelerator must implement, along with toCpu
-UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "SIGN", "RECIPROCAL"])
+UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "GT0", "RECIPROCAL"])
 BinaryOps = Enum("BinaryOps", ["ADD", "SUB", "MUL", "DIV", "POW", "CMPEQ"])
 ReduceOps = Enum("ReduceOps", ["SUM", "MAX"])
 MovementOps = Enum("MovementOps", ["RESHAPE", "PERMUTE", "EXPAND", "FLIP", "STRIDED", "PAD", "SHRINK"])

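Since UnaryOps is built with the functional Enum API, renaming the member means any backend still keying on UnaryOps.SIGN fails loudly rather than silently. A tiny standalone sketch of that behaviour, mirroring the ops.py construction above:

from enum import Enum

# same functional-Enum construction as the ops.py line above
UnaryOps = Enum("UnaryOps", ["NOOP", "NEG", "RELU", "EXP", "LOG", "GT0", "RECIPROCAL"])

print(UnaryOps.GT0)               # UnaryOps.GT0
print(hasattr(UnaryOps, "SIGN"))  # False -- lookups of the old name raise AttributeError
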