From dfa562dbc17e087cbec5ebda00359ec92f3fe43a Mon Sep 17 00:00:00 2001 From: Jhenner Tigreros <32320832+JhennerTigreros@users.noreply.github.com> Date: Mon, 24 Jun 2024 11:25:33 -0500 Subject: [PATCH] DEFINE_ACC takes UOps.CONST in vin instead of arg (#4975) * Change DEFINE_ACC to receive UOps.CONST in vin * Use localtype instead of acc dtype * Fix idp * Fix copy list * Fix warp * Fix error * Fix merge * Fix testing * Fix merge * Use deepcopy * Change to copy of inp * Fix lint * Move const to first place * Fix issue upat * Fix upat patterns * Change to list, to test permutations * Add condition * Change pm * Revert change pm * Remove unused rule * Fix * Change of float4 DEFINE_ACC values * Cast on PM to correct dtype * Improve assert message * Move IFs --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- tinygrad/codegen/linearizer.py | 2 +- tinygrad/codegen/uops.py | 20 +++++++++++--------- tinygrad/renderer/assembly.py | 4 ++-- tinygrad/renderer/cstyle.py | 2 +- tinygrad/renderer/llvmir.py | 2 +- tinygrad/runtime/ops_python.py | 4 ++-- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 5e4c8bed..f3fadf7e 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -145,7 +145,7 @@ class Linearizer(Kernel): key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501 if key not in self.load_cache: if acc is not None: - self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count)) + self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count)) acc_count += 1 elif this_const is not None: self.load_cache[key] = UOp.const(localtype, this_const) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 4b0e5a57..6217f1a4 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -171,9 +171,9 @@ constant_folder = PatternMatcher([ src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]), UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse), # sum collapse to mul (with possible GEP) - (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]), UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), - (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)), + (UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)), UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse), # deal with UNMUL (UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]), @@ -186,11 +186,11 @@ constant_folder = PatternMatcher([ (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)), (UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)), # a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed - (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])), - (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)), + (UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x), (UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x), # a DEFINE_ACC without inputs is a const + GEP on a const is the const - (UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])), + (UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)), (UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)), # max -2147483648 (UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x), @@ -404,8 +404,8 @@ class UOpGraph: while queue: p,x = heapq.heappop(queue) if DEBUG >= 7: print(p,x) - if x.op is UOps.DEFINE_ACC and len(x.src): - idx = min([self._uops.index(l) for l in x.src]) + if x.op is UOps.DEFINE_ACC and len(x.src) > 1: + idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE]) self._uops.insert(idx, x) else: self._uops.append(x) @@ -460,8 +460,10 @@ class UOpGraph: def type_verify(self): for u in self.uops: uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype - if uop in {UOps.CONST, UOps.DEFINE_ACC}: - if uop is UOps.DEFINE_ACC: arg = arg[0] + if uop in (UOps.CONST, UOps.DEFINE_ACC): + if uop is UOps.DEFINE_ACC: + assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}" + arg = src[0].arg assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}" if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py index e3ef6669..60bd2a7c 100644 --- a/tinygrad/renderer/assembly.py +++ b/tinygrad/renderer/assembly.py @@ -164,8 +164,8 @@ class PTXRenderer(Renderer): elif uop is UOps.DEFINE_ACC: if dtype.count > 1: r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args[0], dtype.scalar())};") - else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args[0], dtype)};") + for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};") + else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};") elif uop is UOps.SPECIAL: assert args[1][0] != "i", "idx not supported" kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};") diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 9b7cd460..db19f911 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -164,7 +164,7 @@ class CStyleLanguage(Renderer): bufs.append((nm:=f"data{args[0]}", (dtype,args[1]))) r[u] = nm elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});") - elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};") + elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};") elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})" elif uop is UOps.GEP: assert src[0].dtype is not None diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index d08e9032..1d5af348 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -132,7 +132,7 @@ class LLVMRenderer(Renderer): lvars[u].add_incoming(lvars[src[0]], bb[-2].block) loop_blocks.append((bb[-1].block, phis)) elif uop is UOps.DEFINE_ACC: - lvars[u] = const(args[0], dtype) + lvars[u] = const(src[0].arg, dtype) reduce_phis.append(u) elif uop is UOps.LOAD: if len(src) > 2: diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 5615bb99..4c032c04 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -41,7 +41,7 @@ class PythonProgram: while i < len(self.uops): uop, dtype, idp, arg = self.uops[i] void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF} - if uop is UOps.DEFINE_ACC: idp.clear() + if uop is UOps.DEFINE_ACC: idp = [idp[0]] inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops] dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops] if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp) @@ -90,7 +90,7 @@ class PythonProgram: elif uop is UOps.CONST: ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size elif uop is UOps.DEFINE_ACC: - ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size + ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size elif uop is UOps.RANGE: if i not in ul: ul[i] = [inp[0][0]] * warp_size else: