mirror of https://github.com/commaai/tinygrad.git
real minimum cstyle change (#6709)
* real minimum cstyle change
* make it match
* bring back DEFINE_GLOBAL store marking writable
* bump line count to 9800
* closer
* precompute don't render
* cast/bitcast too
* smem_align
* vectorize
* more pr match
* remove that test
* less PR diff
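The change below swaps much of the per-UOp if/elif rendering in CStyleLanguage.render for a rule table (base_pm = PatternMatcher([...])) plus a precomputed dont_render map that keeps single-use expressions inline. As a rough sketch of that table-driven dispatch only: ToyUOp, ToyMatcher and ToyRenderer are hypothetical stand-ins, not tinygrad's real UOp/UPat/PatternMatcher, and the rules here render a toy two-op IR rather than the full UOps set.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

# Hypothetical stand-ins (NOT tinygrad's real UOp/UPat/PatternMatcher), only to
# illustrate the op -> render-function dispatch style that base_pm uses below.
@dataclass(frozen=True)
class ToyUOp:
  op: str                           # e.g. "CONST", "ALU"
  src: Tuple["ToyUOp", ...] = ()    # operands, already-defined uops
  arg: Any = None                   # constant value, alu kind, etc.

class ToyMatcher:
  """Maps an op name to a render callback, loosely like PatternMatcher([(UPat(...), fn), ...])."""
  def __init__(self, rules: List[Tuple[str, Callable]]): self.rules: Dict[str, Callable] = dict(rules)
  def rewrite(self, u: ToyUOp, ctx) -> Optional[str]:
    fn = self.rules.get(u.op)
    return fn(ctx, u) if fn is not None else None

class ToyRenderer:
  def __init__(self): self.r: Dict[ToyUOp, str] = {}             # uop -> rendered C expression
  def __getitem__(self, key: ToyUOp) -> str: return self.r[key]  # same "hacky helper" idea as in the diff

toy_pm = ToyMatcher([
  ("CONST", lambda r, x: str(x.arg)),
  ("ALU",   lambda r, x: f"({r[x.src[0]]}+{r[x.src[1]]})"),   # pretend every ALU is an ADD
])

if __name__ == "__main__":
  ctx = ToyRenderer()
  a, b = ToyUOp("CONST", arg=2), ToyUOp("CONST", arg=3)
  for u in (a, b): ctx.r[u] = toy_pm.rewrite(u, ctx)          # leaves are rendered first, like the uops loop
  print(toy_pm.rewrite(ToyUOp("ALU", src=(a, b)), ctx))       # -> (2+3)
```

The dont_render idea in the diff plays a similar role to the r dict above: an expression with exactly one consumer keeps its rendered string in r instead of being assigned to an SSA temporary (unless EXPAND_SSA is set), which keeps the generated C small while the dispatch stays table-driven.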
This commit is contained in:
parent e6a1b5aa8f
commit dd575da7ee
@@ -534,9 +534,6 @@ jobs:
      run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
    - name: Run process replay tests
      run: |
        if [ "${{ matrix.backend }}" == "amd" ] && [ "${GITHUB_REF_NAME}" != "master" ]; then
          MAX_DIFF_PCT=1 RUN_PROCESS_REPLAY=0 test/external/process_replay/test_process_replay.sh
        fi
        export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
        export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
        cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py
@@ -1,11 +1,57 @@
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer, TensorCore

def render_load(r:CStyleLanguage, x:UOp):
  val = r.render_load(x.dtype, r[x.src[0]], x.src[0].dtype, strip_parens(r[x.src[1]]))
  # NOTE: this relies on the load not happening if it's in the unselected branch
  if len(x.src) > 3 and x.src[3].op is UOps.ALU: val = r.code_for_op[TernaryOps.WHERE](r[x.src[3]], val, r[x.src[2]], x.dtype)
  return val

def render_store(r:CStyleLanguage, x:UOp):
  assert isinstance(x.src[0].dtype, (ImageDType, PtrDType))
  rendered_store = r.render_store(r[x.src[0]], x.src[0].dtype, r[x.src[2]], x.src[2].dtype, strip_parens(r[x.src[1]]))
  return f"if ({r[x.src[3]]}) {{ {rendered_store} }}" if len(x.src) > 3 and x.src[3].op is not UOps.IF else rendered_store

def render_alu(r:CStyleLanguage, x:UOp):
  if x.arg in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == x.arg else r[v] for v in x.src]
  elif x.arg is BinaryOps.MAX: operands = [r.render_cast(r[v], v.dtype) if v.op is UOps.CONST else r[v] for v in x.src]
  else: operands = [r[v] for v in x.src]
  return r.code_for_op[x.arg](*operands, x.dtype)

def render_gep(r:CStyleLanguage, x:UOp):
  from_ssa = x.src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
  return (r[x.src[0]] if from_ssa else f"{(r[x.src[0]])}") + \
    (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if r.device in {"CUDA", "NV"} else 4) \
     or r.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")

base_pm = PatternMatcher([
  (UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
  (UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
  (UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
  (UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
  (UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
  # r method accesses
  (UPat(UOps.CONST, name="x"), lambda r,x: r.render_const(x.arg, x.dtype) if x.arg >= 0 else f"({r.render_const(x.arg, x.dtype)})"),
  (UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
  (UPat(UOps.VECTORIZE, name="x"), lambda r,x: r.render_vectorize([r[y] for y in x.src], x.dtype)),
  (UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
  (UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
  (UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
  (UPat(UOps.BARRIER), lambda r: r.barrier),
  (UPat(UOps.SPECIAL, name="x"), lambda r,x: f"{r.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
  # function calls
  (UPat(UOps.LOAD, name="x"), render_load),
  (UPat(UOps.STORE, name="x"), render_store),
  (UPat(UOps.ALU, name="x"), render_alu),
  (UPat(UOps.GEP, name="x"), render_gep),
])

class CStyleLanguage(Renderer):
  kernel_prefix: str = ""
  buffer_prefix: str = ""
@@ -89,101 +135,66 @@ class CStyleLanguage(Renderer):
      return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
    return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"

  def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{self.render_dtype(dtype)} {name}[{size}];"
  def render_dtype(self, var_dtype:DType) -> str:
    return self.type_map.get(scalar:=var_dtype.scalar(), scalar.name) + (str(var_dtype.count) if (var_dtype.count) > 1 else "")

  def __getitem__(self, key): return self.r[key] # hacky helper
  def render(self, name:str, uops:List[UOp]) -> str:
    kernel = []
    bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
    depth = 1
    def kk(s): kernel.append(" "*depth+s)

    c: DefaultDict[str, int] = defaultdict(int)
    r: Dict[UOp, str] = {}
    self.r = r

    def ssa(prefix:str, u:Optional[UOp]=None):
      nonlocal c, r
      ret = f"{prefix}{c[prefix]}"
      if u is not None: r[u] = ret
      c[prefix] += 1
      return ret

    # get should render
    child_count = Counter(v for ru in uops for v in ru.src)

    seen_vars = set()
    dont_render: Dict[UOp, bool] = {}
    for u in uops:
      uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
      # these four uops don't have output dtypes
      if uop is UOps.IF:
        kk(f"if ({r[src[0]]}) {{")
        depth += 1
      elif uop is UOps.BARRIER: kk(self.barrier)
      elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
        depth -= 1
        kk("}")
      elif uop is UOps.STORE:
        # mark DEFINE_GLOBAL buf as writable
        assert isinstance(src[0].dtype, (ImageDType, PtrDType))
        if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
        rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
        kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
      # bitcast src must be rendered (always earlier, so this is safe)
      if u.op is UOps.BITCAST: dont_render[u.src[0]] = False
      dont_render[u] = u.op in {UOps.CONST, UOps.GEP} or \
        (u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST} and child_count[u] == 1 \
        and u.arg is not BinaryOps.MAX and not getenv("EXPAND_SSA"))

    bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
    kernel = []
    depth = 1
    c: DefaultDict[str, int] = defaultdict(int)
    c['temp'] += 1 # hack for process replay
    for u in uops:
      if u.op is UOps.DEFINE_GLOBAL:
        r[u] = f"data{u.arg}"
        bufs[u] = (r[u], (u.dtype, False))
        continue
      if u.op is UOps.DEFINE_VAR:
        r[u] = u.arg[0]
        bufs[u] = (r[u], (u.dtype, False))
        continue

      # mark buffers that we store to writable
      if u.op is UOps.STORE and u.src[0].op is UOps.DEFINE_GLOBAL: bufs[u.src[0]] = (bufs[u.src[0]][0], (bufs[u.src[0]][1][0], True))

      # naming
      prefix = None
      if u.op is UOps.SPECIAL:
        r[u] = u.arg[0]
      else:
        if uop is UOps.RANGE:
          kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
          depth += 1
        elif uop is UOps.ALU:
          # remove parens if ALU types are the same. TODO: can do more here
          if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v] for v in src]
          elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], v.dtype) if v.op is UOps.CONST else r[v] for v in src]
          else: operands = [r[v] for v in src]
          val = self.code_for_op[args](*operands, dtype)
          assert child_count[u] != 0, f"childless ALU op found {u}"
          # TODO: fix index rendering issue. fix clang nested max macro issue
          if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
        elif uop is UOps.SPECIAL:
          kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
          r[u] = args[0]
        elif uop is UOps.DEFINE_VAR:
          assert args[0] not in seen_vars, f"duplicate variable {args[0]}"
          seen_vars.add(args[0])
          bufs[u] = (args[0], (dtype,False))
          r[u] = args[0]
        elif uop is UOps.LOAD:
          val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
          # NOTE: this relies on the load not happening if it's in the unselected branch
          if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
          kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
        elif uop is UOps.ASSIGN:
          kk(f"{r[src[0]]} = {r[src[1]]};")
          r[u] = r[src[0]]
        elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
          assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
          if uop is UOps.BITCAST:
            precast = ssa('precast')
            kk(f"{self.render_dtype(src[0].dtype)} {precast} = {r[src[0]]};")
            val = self.render_cast(precast, dtype, bitcast=True)
          elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
          else: val = self.render_vectorize([r[x] for x in src], dtype)
          if child_count[u] <= 1: r[u] = val
          else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
        elif uop is UOps.DEFINE_LOCAL:
          kk(self.render_local(args[0], dtype, args[1]))
          r[u] = args[0]
        elif uop is UOps.DEFINE_GLOBAL:
          bufs[u] = (nm:=f"data{args}", (dtype, False))
          r[u] = nm
        elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
        elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
        elif uop is UOps.GEP:
          assert len(args) == 1
          from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
          r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
            (f"[{args[0]}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) \
             or self.device == 'CLANG' else f".{'xyzwabcd'[args[0]]}")
        else: raise RuntimeError(f"failed to render {u}")
        prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
                  UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast",
                  UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
        r[u] = f"{prefix}{c[prefix]}"

      l = cast(str, base_pm.rewrite(u, ctx=self))
      assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

      if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
      if dont_render[u]: r[u] = l
      else:
        if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
          if u.op is UOps.ASSIGN: r[u] = r[u.src[0]]
        else:
          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "")
        kernel.append(" "*depth + l)
        if prefix: c[prefix] += 1 # if it was used, increment
      if u.op in {UOps.IF, UOps.RANGE}: depth += 1
    del self.r

    # NOTE: this relies on bufs dict preserving order
    return self.render_kernel(name, kernel, list(bufs.values()), uops)