mirror of https://github.com/commaai/tinygrad.git
Update ssa input order and annotate types in cstyle and assembly (#4117)
variable prefix is never optional (removed the default "t") and UOp can be optional (added the default None).
This commit is contained in:
parent
15f2f39658
commit
1ef9c50fd7
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
|
||||
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple, Optional, cast
|
||||
import functools, struct
|
||||
from collections import defaultdict
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
|
@ -97,16 +97,16 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
|||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, Union[List[str], str]] = {}
|
||||
def ssa(u, prefix="t", dtype=None) -> str:
|
||||
def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str:
|
||||
nonlocal c, r
|
||||
prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
|
||||
prefix += f"_{dtype if dtype is not None else lang.types[cast(DType, cast(UOp, u).dtype)]}_"
|
||||
c[prefix] += 1
|
||||
if u: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
if u is not None: r[u] = f"%{prefix}{c[prefix]-1}"
|
||||
return f"%{prefix}{c[prefix]-1}"
|
||||
|
||||
c_label: DefaultDict[str, int] = defaultdict(int)
|
||||
r_label: Dict[UOp, str] = {}
|
||||
def ssa_label(u, prefix):
|
||||
def ssa_label(prefix:str, u:UOp):
|
||||
nonlocal c_label, r_label
|
||||
c_label[prefix] += 1
|
||||
r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
|
||||
|
@ -114,26 +114,26 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
|||
|
||||
def const(x:ConstType, dtype:DType, mov=False):
|
||||
if mov or dtype in lang.const_requires_mov:
|
||||
kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
|
||||
kk(*lang.render_const(x, dtype, mov=(out:=ssa('const', dtype=lang.types[dtype]))))
|
||||
return out
|
||||
return lang.render_const(x, dtype)
|
||||
|
||||
def cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
|
||||
def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
|
||||
if atype == dtype:
|
||||
if u: r[u] = a
|
||||
return a
|
||||
kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
|
||||
kk(*lang.render_cast((ret:=ssa('cast', u, lang.types[dtype])), a, dtype, atype, bitcast))
|
||||
return ret
|
||||
|
||||
for u in uops:
|
||||
uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
|
||||
if uop is UOps.IF:
|
||||
assert vin[0].dtype is not None
|
||||
kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
kk(*lang.render_bra(lb:=ssa_label('if', u), _cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop is UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop is UOps.ENDLOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop is UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
|
@ -146,19 +146,19 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
|||
kk(*lang.render_store(r[vin[0]], r[vin[2]], vin[2].dtype, gate=r[vin[3]] if len(vin)>3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
|
||||
if uop is UOps.LOOP: kk(*lang.render_loop(ssa('ridx', u), r[vin[0]], ssa_label('loop', u)))
|
||||
elif uop is UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args is BinaryOps.CMPLT or args is BinaryOps.CMPEQ:
|
||||
# pass in the other dtype here
|
||||
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
kk(lang.asm_for_op[args](ssa("alu", u), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa(None, 'acc', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
r[u] = [ssa('acc', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{lang.types[dtype.scalar()][1:]} {uu}, {const(args, dtype.scalar())};")
|
||||
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
|
||||
else: kk(f"mov.b{lang.types[dtype][1:]} {ssa('acc', u)}, {const(args, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
|
||||
|
@ -171,13 +171,13 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
|||
elif uop is UOps.LOAD:
|
||||
assert vin[1].dtype is not None
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa(None, 'val', lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
r[u] = [ssa('val', dtype=lang.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
if(len(vin)>3):
|
||||
for v in r[u]: kk(f"mov.{lang.mem_type(dtype.scalar())} {v}, {render_val(0, dtype.scalar())};")
|
||||
kk((f"@{r[vin[2]]}"if len(vin) > 3 else "")
|
||||
+ f" ld{u.arg}.v{dtype.count}.{lang.mem_type(dtype.scalar())} {{{', '.join(r[u])}}}, [{r[vin[0]]}+{vin[1].arg}];")
|
||||
else:
|
||||
kk(*lang.render_load(r[vin[0]], ssa(u, 'val'), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
kk(*lang.render_load(r[vin[0]], ssa('val', u), dtype, gate=r[vin[2]] if len(vin) > 3 else None,
|
||||
alt=r[vin[3]] if len(vin) > 3 else None, ss=u.arg, offset=vin[1].arg))
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
|
||||
|
@ -185,26 +185,26 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
|||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
assert vin[0].dtype is not None
|
||||
if dtype.count>1: r[u] = [r[x] for x in vin] # type: ignore
|
||||
else: cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
else: _cast(r[vin[0]], dtype, vin[0].dtype, bitcast=uop is UOps.BITCAST, u=u)
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
# TODO: we should sum these, and fetch 0xC000 from somewhere
|
||||
assert args[1]*dtype.itemsize <= 0xC000, "too large local"
|
||||
kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
kk(*lang.render_local(ssa('local', u, lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
if lang.load_global: kk(*lang.render_load(args.expr, ssa(u, 'dat', dtype=lang.types[dtype]), dtype, ss=".param"))
|
||||
if lang.load_global: kk(*lang.render_load(args.expr, ssa('dat', u, lang.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((args[1], dtype))
|
||||
r[u] = f"%{args[1]}"
|
||||
if lang.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
|
||||
kk(*lang.render_load(args[1], ssa('dat', u, lang.types[dt]), dt, ss=".param"))
|
||||
elif uop is UOps.WMMA:
|
||||
wmma = []
|
||||
for vv in vin[:2]:
|
||||
for i in range(0, len(r[vv]), 2):
|
||||
wmma.append(ssa(None, "wmma", "b32"))
|
||||
wmma.append(ssa("wmma", dtype="b32"))
|
||||
kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};')
|
||||
r[u] = r[vin[2]]
|
||||
kk(f'mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32\
|
||||
|
|
|
@ -95,7 +95,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
|||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
r: Dict[UOp, str] = {}
|
||||
def ssa(u, prefix="t"):
|
||||
|
||||
def ssa(prefix:str, u:Optional[UOp]=None):
|
||||
nonlocal c, r
|
||||
ret = f"{prefix}{c[prefix]}"
|
||||
if u is not None: r[u] = ret
|
||||
|
@ -121,7 +122,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
|||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.LOOP:
|
||||
kk(f"for (int {(expr := ssa(u,'ridx'))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
kk(f"for (int {(expr := ssa('ridx',u))} = {r[vin[0]]}; {expr} < {r[vin[1]]}; {expr}++) {{")
|
||||
depth += 1
|
||||
elif uop is UOps.ALU:
|
||||
# remove parens if ALU types are the same. TODO: can do more here
|
||||
|
@ -131,7 +132,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
|||
assert child_count[u] != 0, f"childless ALU op found {u}"
|
||||
# TODO: fix index rendering issue. fix clang nested max macro issue
|
||||
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'alu')} = {val};")
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('alu',u)} = {val};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
kk(f"int {args[1]} = {lang.code_for_workitem[args[1][0]](args[0])}; /* {args[2]} */")
|
||||
r[u] = args[1]
|
||||
|
@ -139,20 +140,20 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
|||
val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
if len(vin) > 3: val = lang.code_for_op[TernaryOps.WHERE](r[vin[2]], val, r[vin[3]], dtype)
|
||||
kk(f"{lang.render_dtype(dtype)} {ssa(u,'val')} = {val};")
|
||||
kk(f"{lang.render_dtype(dtype)} {ssa('val',u)} = {val};")
|
||||
elif uop is UOps.PHI:
|
||||
kk(f"{r[vin[0]]} = {r[vin[1]]};")
|
||||
r[u] = r[vin[0]]
|
||||
elif uop in {UOps.CAST, UOps.BITCAST}:
|
||||
if uop is UOps.BITCAST:
|
||||
assert len(vin) == 1
|
||||
precast = ssa(None,'precast')
|
||||
precast = ssa('precast')
|
||||
kk(f"{lang.render_dtype(cast(DType, vin[0].dtype))} {precast} = {r[vin[0]]};")
|
||||
val = lang.render_cast([precast], dtype, bitcast=True)
|
||||
else:
|
||||
val = lang.render_cast([r[x] for x in vin], dtype, bitcast=False)
|
||||
if child_count[u] <= 1: r[u] = val
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa(u,'cast')} = {val};")
|
||||
else: kk(f"{lang.render_dtype(dtype)} {ssa('cast',u)} = {val};")
|
||||
elif uop is UOps.DEFINE_LOCAL:
|
||||
kk(lang.render_local(args[0], dtype, args[1]))
|
||||
r[u] = args[0]
|
||||
|
@ -163,8 +164,8 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str
|
|||
assert len(bufs) == args[0], f"missed a global buffer {len(bufs)} {args}"
|
||||
bufs.append((args[1], (dtype,args[2])))
|
||||
r[u] = args[1]
|
||||
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa(u, 'wmma')} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.WMMA: kk(f"{lang.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{lang.render_dtype(dtype)} {ssa('acc',u)} = {lang.render_const(args, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert vin[0].dtype is not None
|
||||
|
|
Loading…
Reference in New Issue