mirror of https://github.com/commaai/tinygrad.git
Refactor code_for_op to accept a dtype (#2555)
* update cstyle renderers to take a dtype in code_for_op * implement NEG for bools in LLVM * update triton --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
82fd932921
commit
99ee2ec37a
|
@ -59,18 +59,18 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
|
|||
|
||||
def kk(s): kernel.append(" "*depth+s)
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.EXP2: lambda x,: f"tl.math.exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,: f"tl.math.log2({x})",
|
||||
UnaryOps.SIN: lambda x,: f"tl.sin({x})",
|
||||
UnaryOps.SQRT: lambda x,: f"tl.sqrt({x})",
|
||||
UnaryOps.NEG: lambda x,: f"-{x}",
|
||||
BinaryOps.ADD: lambda x,y,: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
|
||||
BinaryOps.MUL: lambda x,y,: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
|
||||
BinaryOps.MAX: lambda x,y,: f"tl.maximum({x},{y})",
|
||||
BinaryOps.CMPLT: lambda x,y,: f"({x}<{y})",
|
||||
BinaryOps.MOD: lambda x,y,: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
|
||||
TernaryOps.MULACC: lambda x,y,z,: f"(({x}*{y})+{z})",
|
||||
TernaryOps.WHERE: lambda x,y,z,: f"tl.where({x},{y},{z})",
|
||||
UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
|
||||
UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
|
||||
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
|
||||
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
|
||||
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
|
||||
BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
|
||||
BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
|
||||
TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
|
||||
TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
|
||||
}
|
||||
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
|
||||
for u in uops:
|
||||
|
|
|
@ -29,16 +29,16 @@ class CStyleLanguage(NamedTuple):
|
|||
uses_ptr_arithmetic: bool = False
|
||||
launch_bounds: bool = False
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.NEG: lambda x: f"(-{x})",
|
||||
UnaryOps.EXP2: lambda x: f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x: f"log2({x})",
|
||||
UnaryOps.SIN: lambda x: f"sin({x})",
|
||||
UnaryOps.SQRT: lambda x: f"sqrt({x})",
|
||||
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
|
||||
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
|
||||
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
|
||||
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
|
||||
UnaryOps.NEG: lambda x,dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"log2({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"sin({x})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})",
|
||||
BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})",
|
||||
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
|
||||
}
|
||||
|
||||
# returns a str expression of the casted xs with the given type
|
||||
|
@ -157,11 +157,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
|
|||
assert dtype is not None
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
|
||||
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
|
||||
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
|
||||
elif args == BinaryOps.MAX:
|
||||
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin])
|
||||
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin] + [dtype])
|
||||
else:
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin])
|
||||
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
|
||||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX: # fix index rendering issue. fix clang nested max macro issue
|
||||
r[u] = val
|
||||
|
|
|
@ -7,7 +7,7 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
|||
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.not_(x),
|
||||
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
|
||||
|
|
|
@ -18,6 +18,6 @@ class OpenCLLanguage(CStyleLanguage):
|
|||
xid = [f'get_global_id({i})' for i in range(3)]
|
||||
uses_vload = True
|
||||
# NOTE: mad is used so the loads aren't reordered into the math on 845
|
||||
code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c: f"mad({a},{b},{c})"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
|
||||
|
||||
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
|
||||
|
|
|
@ -13,7 +13,7 @@ class WGSLLanguage(CStyleLanguage):
|
|||
barrier="workgroupBarrier();"
|
||||
generic_var_prefix = "var "
|
||||
external_local_bufs = True
|
||||
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }
|
||||
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" }
|
||||
|
||||
def render_local(self, name: str, size: int):
|
||||
return f"var<workgroup> {name}: array<f32,{size}>;"
|
||||
|
|
Loading…
Reference in New Issue