Refactor code_for_op to accept a dtype (#2555)

* update cstyle renderers to take a dtype in code_for_op

* implement NEG for bools in LLVM

* update triton

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
qazal 2023-12-02 01:05:28 -05:00 committed by GitHub
parent 82fd932921
commit 99ee2ec37a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 28 additions and 28 deletions

View File

@@ -59,18 +59,18 @@ def uops_to_triton(function_name:str, uops:List[UOp]):
def kk(s): kernel.append(" "*depth+s)
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x,: f"tl.math.exp2({x})",
UnaryOps.LOG2: lambda x,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,: f"-{x}",
BinaryOps.ADD: lambda x,y,: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,: f"tl.maximum({x},{y})",
BinaryOps.CMPLT: lambda x,y,: f"({x}<{y})",
BinaryOps.MOD: lambda x,y,: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
TernaryOps.MULACC: lambda x,y,z,: f"(({x}*{y})+{z})",
TernaryOps.WHERE: lambda x,y,z,: f"tl.where({x},{y},{z})",
UnaryOps.EXP2: lambda x,dtype,: f"tl.math.exp2({x})",
UnaryOps.LOG2: lambda x,dtype,: f"tl.math.log2({x})",
UnaryOps.SIN: lambda x,dtype: f"tl.sin({x})",
UnaryOps.SQRT: lambda x,dtype: f"tl.sqrt({x})",
UnaryOps.NEG: lambda x,dtype: f"-{x}" if dtype != dtypes.bool else f"tl.where({x}, 0, 1)",
BinaryOps.ADD: lambda x,y,dtype: f"({x}+{y})", BinaryOps.SUB: lambda x,y,: f"({x}-{y})",
BinaryOps.MUL: lambda x,y,dtype: f"({x}*{y})", BinaryOps.DIV: lambda x,y,: f"({x}/{y})" if y != '0.0' else f"{x}*tl.where({x}==0.0, float('nan'), float('inf'))",
BinaryOps.MAX: lambda x,y,dtype: f"tl.maximum({x},{y})",
BinaryOps.CMPLT: lambda x,y,dtype: f"({x}<{y})",
BinaryOps.MOD: lambda x,y,dtype: f"tl.abs({x})%tl.abs({y})*tl.where({x}<0,-1,1)",
TernaryOps.MULACC: lambda x,y,z,dtype: f"(({x}*{y})+{z})",
TernaryOps.WHERE: lambda x,y,z,dtype: f"tl.where({x},{y},{z})",
}
def int_div(x,y): return f"({x}//{y})" if y != '0' else f"{x}*tl.where({x}==0, float('nan'), float('inf'))"
for u in uops:

View File

@@ -29,16 +29,16 @@ class CStyleLanguage(NamedTuple):
uses_ptr_arithmetic: bool = False
launch_bounds: bool = False
code_for_op: Dict = {
UnaryOps.NEG: lambda x: f"(-{x})",
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
UnaryOps.NEG: lambda x,dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})",
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})",
UnaryOps.LOG2: lambda x,dtype: f"log2({x})",
UnaryOps.SIN: lambda x,dtype: f"sin({x})",
UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})",
BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})",
BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c,dtype: f"(({a}*{b})+{c})",
TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}!=0?{b}:{c})"
}
# returns a str expression of the casted xs with the given type
@@ -157,11 +157,11 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> Tu
assert dtype is not None
# remove parens if ALU types are the same. TODO: can do more here
if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]])
val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
elif args == BinaryOps.MAX:
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin])
val = lang.code_for_op[args](*[lang.render_cast([r[x]], dtype) if x.dtype != dtype else r[x] for x in vin] + [dtype])
else:
val = lang.code_for_op[args](*[r[x] for x in vin])
val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
assert child_count[u] != 0, f"childless ALU op found {u}"
if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX: # fix index rendering issue. fix clang nested max macro issue
r[u] = val

View File

@@ -7,7 +7,7 @@ from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
LLVM_FAST_MATH_FLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS),
UnaryOps.NEG: lambda builder,x: builder.neg(x) if isinstance(x.type, ir.IntType) else builder.fneg(x, flags=LLVM_FAST_MATH_FLAGS) if isinstance(x.type, ir.FloatType) else builder.not_(x),
UnaryOps.EXP2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.LOG2: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),
UnaryOps.SIN: lambda builder,x: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [ir.FloatType()]), [x], fastmath=LLVM_FAST_MATH_FLAGS),

View File

@@ -18,6 +18,6 @@ class OpenCLLanguage(CStyleLanguage):
xid = [f'get_global_id({i})' for i in range(3)]
uses_vload = True
# NOTE: mad is used so the loads aren't reordered into the math on 845
code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c: f"mad({a},{b},{c})"}
code_for_op = {**CStyleLanguage().code_for_op, TernaryOps.MULACC: lambda a,b,c,dtype: f"mad({a},{b},{c})"}
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())

View File

@@ -13,7 +13,7 @@ class WGSLLanguage(CStyleLanguage):
barrier="workgroupBarrier();"
generic_var_prefix = "var "
external_local_bufs = True
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},{a}!=0.)" }
def render_local(self, name: str, size: int):
return f"var<workgroup> {name}: array<f32,{size}>;"