Update ssa input order and annotate types in cstyle and assembly (#4117)

The variable prefix is never optional (removed the default "t"), and the UOp can be optional (added the default None).
This commit is contained in:
chenyu 2024-04-09 13:10:29 -04:00 committed by GitHub
parent 15f2f39658
commit 1ef9c50fd7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 31 additions and 30 deletions

View File

@ -1,4 +1,4 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
@ -97,16 +97,16 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, Union[List[str], str]] = {}
def ssa(u, prefix="t", dtype=None) -> str:
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
nonlocal c, r
prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
c[prefix] += 1
if u: r[u] = f"%{prefix}{c[prefix]-1}"
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
return f"%{prefix}{c[prefix]-1}"
c_label: DefaultDict[str, int] = defaultdict(int)
r_label: Dict[UOp, str] = {}
def ssa_label(u, prefix):
def ssa_label(prefix:str, u:UOp):
nonlocal c_label, r_label
c_label[prefix] += 1
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
@ -114,26 +114,26 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
def const(x:ConstType, dtype:DType, mov=False):
if mov or dtype in lang.const_requires_mov:
kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
return out
return lang.render_const(x, dtype)
def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
if atype == dtype:
if u: r[u] = a
return a
kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
return ret
for u in uops:
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
if uop is UOps.IF:
assert vin[0].dtype is not None
kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
elif uop is UOps.ENDLOOP:
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
elif uop is UOps.ENDIF:
kk(f"{r_label[vin[0]]}:")
@ -146,19 +146,19 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
elif uop is UOps.ALU:
assert vin[0].dtype is not None
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
# pass in the other dtype here
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
else:
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
elif uop is UOps.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
elif uop is UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
@ -171,13 +171,13 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
elif uop is UOps.LOAD:
assert vin[1].dtype is not None
if dtype.count > 1:
r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
if(len(vin)>3):
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
else:
kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
elif uop is UOps.PHI:
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
@ -185,26 +185,26 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
elif uop in {UOps.CAST, UOps.BITCAST}:
assert vin[0].dtype is not None
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
elif uop is UOps.DEFINE_LOCAL:
# TODO: we should sum these, and fetch 0xC000 from somewhere
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
elif uop is UOps.DEFINE_VAR:
bufs.append((args.expr, dtype))
r[u] = f"%{args.expr}"
if lang.load_global: kk(*lang.render_load(args.expr, ssa(u, 'dat', dtype=lang.types[dtype]), dtype, ss=".param"))
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
elif uop is UOps.DEFINE_GLOBAL:
bufs.append((args[1], dtype))
r[u] = f"%{args[1]}"
if lang.load_global:
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
kk(*lang.render_load(args[1], ssa('dat', u, lang.types[dt]), dt, ss=".param"))
elif uop is UOps.WMMA:
wmma = []
for vv in vin[:2]:
for i in range(0, len(r[vv]), 2):
wmma.append(ssa(None, "wmma", "b32"))
wmma.append(ssa("wmma", dtype="b32"))
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
r[u] = r[vin[2]]
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\

View File

@ -95,7 +95,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
c: DefaultDict[str, int] = defaultdict(int)
r: Dict[UOp, str] = {}
def ssa(u, prefix="t"):
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
@ -121,7 +122,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.LOOP:
kk(f"for (int {(expr := ssa(u,'ridx'))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
@ -131,7 +132,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'alu')} = {val};")
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
r[u] = args[1]
@ -139,20 +140,20 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
kk(f"{lang.render_dtype(dtype)} {ssa(u,'val')} = {val};")
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.PHI:
kk(f"{r[vin[0]]} = {r[vin[1]]};")
r[u] = r[vin[0]]
elif uop in {UOps.CAST, UOps.BITCAST}:
if uop is UOps.BITCAST:
assert len(vin) == 1
precast = ssa(None,'precast')
precast = ssa('precast')
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
val = lang.render_cast([precast], dtype, bitcast=True)
else:
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
if child_count[u] <= 1: r[u] = val
else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'cast')} = {val};")
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(lang.render_local(args[0], dtype, args[1]))
r[u] = args[0]
@ -163,8 +164,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
bufs.append((args[1], (dtype,args[2])))
r[u] = args[1]
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert vin[0].dtype is not None