real minimum cstyle change (#6709)

* real minimum cstyle change

* make it match

* bring back DEFINE_GLOBAL store marking writable

* bump line count to 9800

* closer

* precompute don't render

* cast/bitcast too

* smem_align

* vectorize

* more pr match

* remove that test

* less PR diff
This commit is contained in:
George Hotz 2024-09-25 12:40:46 +08:00 committed by GitHub
parent e6a1b5aa8f
commit dd575da7ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 99 additions and 91 deletions

View File

@ -534,9 +534,6 @@ jobs:
run: TRANSCENDENTAL=2 python -m pytest -n=auto test/test_ops.py::TestOps::test_sin test/test_ops.py::TestOps::test_cos test/test_ops.py::TestOps::test_tan test/test_ops.py::TestOps::test_exp test/test_ops.py::TestOps::test_log --durations=20
- name: Run process replay tests
run: |
if [ "${{ matrix.backend }}" == "amd" ] && [ "${GITHUB_REF_NAME}" != "master" ]; then
MAX_DIFF_PCT=1 RUN_PROCESS_REPLAY=0 test/external/process_replay/test_process_replay.sh
fi
export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH")
export COMMIT_MESSAGE=$(git show -s --format=%B ${{ github.event.pull_request.head.sha }})
cp test/external/process_replay/process_replay.py ./process_replay.py && git fetch origin master && git -c advice.detachedHead=false checkout origin/master && PYTHONPATH=. python3 process_replay.py

View File

@ -1,11 +1,57 @@
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable
from __future__ import annotations
from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast
import os, math
from collections import defaultdict, Counter
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOps, UOp, PatternMatcher, UPat
from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.renderer import Renderer, TensorCore
def render_load(r:CStyleLanguage, x:UOp):
  """Build the expression string for the LOAD uop `x`.

  `r` is the renderer: `r[u]` gives the already-rendered string of uop `u`, and `r.render_load` /
  `r.code_for_op` produce the actual source text. src[0] is the buffer and src[1] the index; when a fourth
  source exists and is an ALU op, it is used as the condition of a WHERE with src[2] as the alternative value.
  """
  buf, idx = x.src[0], x.src[1]
  loaded = r.render_load(x.dtype, r[buf], buf.dtype, strip_parens(r[idx]))
  is_gated = len(x.src) > 3 and x.src[3].op is UOps.ALU
  if not is_gated: return loaded
  # NOTE: this relies on the load not happening if it's in the unselected branch
  return r.code_for_op[TernaryOps.WHERE](r[x.src[3]], loaded, r[x.src[2]], x.dtype)
def render_store(r:CStyleLanguage, x:UOp):
  """Build the statement string for the STORE uop `x`.

  src[0] is the destination (its dtype must be an ImageDType or PtrDType), src[1] the index and src[2] the
  value. If a fourth source is present and is not an IF uop, the store is wrapped in `if (<src[3]>) { ... }`.
  """
  target, value = x.src[0], x.src[2]
  assert isinstance(target.dtype, (ImageDType, PtrDType))
  store_stmt = r.render_store(r[target], target.dtype, r[value], value.dtype, strip_parens(r[x.src[1]]))
  # no gate, or the gate is an IF uop: emit the bare store
  if len(x.src) <= 3 or x.src[3].op is UOps.IF: return store_stmt
  return f"if ({r[x.src[3]]}) {{ {store_stmt} }}"
def render_alu(r:CStyleLanguage, x:UOp):
  """Build the expression string for the ALU uop `x` by feeding its rendered sources to `r.code_for_op[x.arg]`."""
  op = x.arg
  strips_same_op = op in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}
  def operand(s:UOp) -> str:
    # for ADD/MUL/XOR, a source carrying the same arg is rendered without its outer parentheses
    if strips_same_op: return strip_parens(r[s]) if s.arg == op else r[s]
    # for MAX, constant sources are passed through render_cast with their own dtype
    if op is BinaryOps.MAX and s.op is UOps.CONST: return r.render_cast(r[s], s.dtype)
    return r[s]
  return r.code_for_op[op](*[operand(s) for s in x.src], x.dtype)
def render_gep(r:CStyleLanguage, x:UOp):
  """Build the expression string that selects element `x.arg[0]` of the vector `x.src[0]`.

  The element is selected with `[i]` when the source has more than 8 elements on CUDA/NV (more than 4 on any
  other device) or when the device is CLANG; otherwise with a `.x`/`.y`/... member taken from 'xyzwabcd'.
  """
  vec, lane = x.src[0], x.arg[0]
  # NOTE(review): the previous code computed `from_ssa` (source is LOAD/WMMA/DEFINE_ACC) and chose between
  # r[vec] and f"{(r[vec])}". The parentheses sat inside the replacement field, so both branches produced the
  # same text and the flag was dead. Presumably non-SSA sources were meant to be parenthesized -- confirm
  # before changing, since that would alter the generated source. Output here is unchanged.
  base = r[vec]
  index_limit = 8 if r.device in {"CUDA", "NV"} else 4
  if vec.dtype.count > index_limit or r.device == 'CLANG': return base + f"[{lane}]"
  return base + f".{'xyzwabcd'[lane]}"
# Table mapping a UOp pattern to the callback that produces its rendered text.
# Callbacks take `r` (CStyleLanguage.render passes the renderer itself as ctx; r[u] is the rendered string of
# uop u) and, where the pattern names it, the matched uop as `x`.
base_pm = PatternMatcher([
# an accumulator renders as its first source
(UPat(UOps.DEFINE_ACC, name="x"), lambda r,x: r[x.src[0]]),
(UPat(UOps.ASSIGN, name="x"), lambda r,x: f"{r[x.src[0]]} = {r[x.src[1]]};"),
(UPat(UOps.IF, name="x"), lambda r,x: f"if ({r[x.src[0]]}) {{"),
# both block terminators render as a closing brace
(UPat((UOps.ENDIF, UOps.ENDRANGE)), lambda r: "}"),
(UPat(UOps.WMMA, name="x"), lambda r,x: f"__{x.arg[0]}({r[x.src[0]]}, {r[x.src[1]]}, {r[x.src[2]]})"),
# r method accesses
# negative constants are wrapped in parentheses
(UPat(UOps.CONST, name="x"), lambda r,x: r.render_const(x.arg, x.dtype) if x.arg >= 0 else f"({r.render_const(x.arg, x.dtype)})"),
(UPat(UOps.RANGE, name="x"), lambda r,x: f"for ({r.render_dtype(x.dtype)} {r[x]} = {r[x.src[0]]}; {r[x]} < {r[x.src[1]]}; {r[x]}++) {{"),
(UPat(UOps.VECTORIZE, name="x"), lambda r,x: r.render_vectorize([r[y] for y in x.src], x.dtype)),
# CAST and BITCAST share render_cast; the third argument selects the bitcast form
(UPat(UOps.CAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, False)),
(UPat(UOps.BITCAST, name="x"), lambda r,x: r.render_cast(r[x.src[0]], x.dtype, True)),
(UPat(UOps.DEFINE_LOCAL, name="x"), lambda r,x: f"{r.smem_align}{r.smem_prefix}{r.render_dtype(x.dtype.base)} {r[x]}[{x.arg[1]}];"),
(UPat(UOps.BARRIER), lambda r: r.barrier),
# x.arg[1] is emitted only inside a trailing /* */ comment
(UPat(UOps.SPECIAL, name="x"), lambda r,x: f"{r.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# function calls
(UPat(UOps.LOAD, name="x"), render_load),
(UPat(UOps.STORE, name="x"), render_store),
(UPat(UOps.ALU, name="x"), render_alu),
(UPat(UOps.GEP, name="x"), render_gep),
])
class CStyleLanguage(Renderer):
kernel_prefix: str = ""
buffer_prefix: str = ""
@ -89,101 +135,66 @@ class CStyleLanguage(Renderer):
return f"*(({prefix}{self.render_dtype(var_dtype)}*)({buf_name}+{idx})) = {var_name};"
return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def render_local(self, name:str, dtype:DType, size:int):
  """Return the declaration `<smem_align><smem_prefix><type> name[size];` for an array of `dtype` elements."""
  qualifiers = self.smem_align + self.smem_prefix
  return f"{qualifiers}{self.render_dtype(dtype)} {name}[{size}];"
def render_dtype(self, var_dtype:DType) -> str:
  """Return the type name for `var_dtype`.

  The scalar type is looked up in `self.type_map`, falling back to the scalar's own `name`; when the dtype has
  more than one element, the element count is appended (e.g. a count of 4 adds the suffix "4").
  """
  base = var_dtype.scalar()
  type_name = self.type_map.get(base, base.name)
  if var_dtype.count > 1: type_name += str(var_dtype.count)
  return type_name
def __getitem__(self, key):
  """Return the rendered string stored for `key` in `self.r`.

  hacky helper: `self.r` is assigned at the start of `render` and deleted at its end, so indexing the renderer
  only works while a render is in progress.
  """
  return self.r[key]
def render(self, name:str, uops:List[UOp]) -> str:
# Render the list `uops` into kernel source lines and pass them, with the collected buffer/variable entries,
# to self.render_kernel, whose result is returned.
# NOTE(review): this span appears to be a flattened diff: statements of the if/elif-based renderer and of the
# base_pm-based renderer are interleaved (kernel, bufs, depth and c are each declared twice) and indentation
# has been lost. Read the two variants separately; do not treat the text below as one runnable body.
kernel = []
bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
depth = 1
def kk(s): kernel.append(" "*depth+s)
c: DefaultDict[str, int] = defaultdict(int)
# r maps each uop to its rendered string; it is exposed as self.r so that self[u] works during rendering
r: Dict[UOp, str] = {}
self.r = r
# ssa hands out the next unused "<prefix><n>" name and, when u is given, records it as the rendering of u
def ssa(prefix:str, u:Optional[UOp]=None):
nonlocal c, r
ret = f"{prefix}{c[prefix]}"
if u is not None: r[u] = ret
c[prefix] += 1
return ret
# get should render
# child_count[v] is the number of times uop v appears as a source of a uop in uops
child_count = Counter(v for ru in uops for v in ru.src)
seen_vars = set()
dont_render: Dict[UOp, bool] = {}
for u in uops:
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
# these four uops don't have output dtypes
if uop is UOps.IF:
kk(f"if ({r[src[0]]}) {{")
depth += 1
elif uop is UOps.BARRIER: kk(self.barrier)
elif uop in {UOps.ENDRANGE, UOps.ENDIF}:
depth -= 1
kk("}")
elif uop is UOps.STORE:
# mark DEFINE_GLOBAL buf as writable
assert isinstance(src[0].dtype, (ImageDType, PtrDType))
if src[0].op is UOps.DEFINE_GLOBAL: bufs[src[0]] = (bufs[src[0]][0], (bufs[src[0]][1][0], True))
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]))
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 and src[3].op is not UOps.IF else rendered_store)
# bitcast src must be rendered (always earlier, so this is safe)
if u.op is UOps.BITCAST: dont_render[u.src[0]] = False
# dont_render[u] is True when u's text is stored in r[u] for substitution into its users rather than being
# emitted as its own kernel line (see the `if dont_render[u]` branch further down)
dont_render[u] = u.op in {UOps.CONST, UOps.GEP} or \
(u.op in {UOps.VECTORIZE, UOps.ALU, UOps.CAST, UOps.BITCAST} and child_count[u] == 1 \
and u.arg is not BinaryOps.MAX and not getenv("EXPAND_SSA"))
bufs: Dict[UOp, Tuple[str, Tuple[DType, bool]]] = {}
kernel = []
depth = 1
c: DefaultDict[str, int] = defaultdict(int)
c['temp'] += 1 # hack for process replay
for u in uops:
# DEFINE_GLOBAL / DEFINE_VAR emit no kernel line: they are named and recorded in bufs as (name, (dtype, False))
if u.op is UOps.DEFINE_GLOBAL:
r[u] = f"data{u.arg}"
bufs[u] = (r[u], (u.dtype, False))
continue
if u.op is UOps.DEFINE_VAR:
r[u] = u.arg[0]
bufs[u] = (r[u], (u.dtype, False))
continue
# mark buffers that we store to writable
if u.op is UOps.STORE and u.src[0].op is UOps.DEFINE_GLOBAL: bufs[u.src[0]] = (bufs[u.src[0]][0], (bufs[u.src[0]][1][0], True))
# naming
prefix = None
if u.op is UOps.SPECIAL:
r[u] = u.arg[0]
else:
if uop is UOps.RANGE:
kk(f"for (int {(expr := ssa('ridx',u))} = {r[src[0]]}; {expr} < {r[src[1]]}; {expr}++) {{")
depth += 1
elif uop is UOps.ALU:
# remove parens if ALU types are the same. TODO: can do more here
if args in {BinaryOps.ADD,BinaryOps.MUL,BinaryOps.XOR}: operands = [strip_parens(r[v]) if v.arg == args else r[v]for v in src]
elif args is BinaryOps.MAX: operands = [self.render_cast(r[v], v.dtype) if v.op is UOps.CONST else r[v] for v in src]
else: operands = [r[v] for v in src]
val = self.code_for_op[args](*operands, dtype)
assert child_count[u] != 0, f"childless ALU op found {u}"
# TODO: fix index rendering issue. fix clang nested max macro issue
if child_count[u] <= 1 and args is not BinaryOps.MAX and not getenv("EXPAND_SSA"): r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('alu',u)} = {val};")
elif uop is UOps.SPECIAL:
kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
assert args[0] not in seen_vars, f"duplicate variable {args[0]}"
seen_vars.add(args[0])
bufs[u] = (args[0], (dtype,False))
r[u] = args[0]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]))
# NOTE: this relies on the load not happening if it's in the unselected branch
if len(src) > 3 and src[3].op is UOps.ALU: val = self.code_for_op[TernaryOps.WHERE](r[src[3]], val, r[src[2]], dtype)
kk(f"{self.render_dtype(dtype)} {ssa('val',u)} = {val};")
elif uop is UOps.ASSIGN:
kk(f"{r[src[0]]} = {r[src[1]]};")
r[u] = r[src[0]]
elif uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
assert len(src) == 1 or (uop is UOps.VECTORIZE and len(src) > 1), "Invalid source length for operation"
if uop is UOps.BITCAST:
precast = ssa('precast')
kk(f"{self.render_dtype(src[0].dtype)} {precast} = {r[src[0]]};")
val = self.render_cast(precast, dtype, bitcast=True)
elif uop is UOps.CAST: val = self.render_cast(r[src[0]], dtype, bitcast=False)
else: val = self.render_vectorize([r[x] for x in src], dtype)
if child_count[u] <= 1: r[u] = val
else: kk(f"{self.render_dtype(dtype)} {ssa('cast',u)} = {val};")
elif uop is UOps.DEFINE_LOCAL:
kk(self.render_local(args[0], dtype, args[1]))
r[u] = args[0]
elif uop is UOps.DEFINE_GLOBAL:
bufs[u] = (nm:=f"data{args}", (dtype, False))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert len(args) == 1
from_ssa = src[0].op in {UOps.LOAD, UOps.WMMA, UOps.DEFINE_ACC}
r[u] = (r[src[0]] if from_ssa else f"{(r[src[0]])}") + \
(f"[{args[0]}]" if src[0].dtype.count > (8 if self.device in {"CUDA", "NV"} else 4) \
or self.device == 'CLANG' else f".{'xyzwabcd'[args[0]]}")
else: raise RuntimeError(f"failed to render {u}")
# provisional name "<prefix><counter>"; ops missing from the table get the prefix "unk"
prefix = {UOps.RANGE: "ridx", UOps.ALU: "alu", UOps.WMMA: "wmma", UOps.DEFINE_LOCAL: "temp", UOps.CONST: "const",
UOps.CAST: "cast", UOps.BITCAST: "cast", UOps.GEP: "gep", UOps.VECTORIZE: "cast",
UOps.DEFINE_ACC: "acc", UOps.LOAD: "val"}.get(u.op, "unk")
r[u] = f"{prefix}{c[prefix]}"
# base_pm produces the text for u with this renderer as ctx; the assert below rejects a None result
l = cast(str, base_pm.rewrite(u, ctx=self))
assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"
if u.op in {UOps.ENDIF, UOps.ENDRANGE}: depth -= 1
if dont_render[u]: r[u] = l
else:
# RANGE, ASSIGN, DEFINE_LOCAL and void-typed uops are appended as rendered; all others become
# "<dtype> <name> = <text>" (with a ";" added unless the uop is SPECIAL)
if u.op in {UOps.RANGE, UOps.ASSIGN, UOps.DEFINE_LOCAL} or u.dtype == dtypes.void:
if u.op is UOps.ASSIGN: r[u] = r[u.src[0]]
else:
l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not UOps.SPECIAL else "")
kernel.append(" "*depth + l)
if prefix: c[prefix] += 1 # if it was used, increment
if u.op in {UOps.IF, UOps.RANGE}: depth += 1
# self.r is removed again here, so self[u] is unavailable once render has finished
del self.r
# NOTE: this relies on bufs dict preserving order
return self.render_kernel(name, kernel, list(bufs.values()), uops)