From df44a4e861ff25d9552f82d1164aaa36c3d4abd4 Mon Sep 17 00:00:00 2001
From: gswangg <152219575+greg-niemeyer@users.noreply.github.com>
Date: Thu, 8 Aug 2024 10:59:05 -0700
Subject: [PATCH] Make vectorization of CONST explicit (#5322)

* remove test_const_vectorize_fold
* remove const folding UPat for VECTORIZE
* refactor cstyle render_const
* remove calls to dtype.scalar() in render_const
* add assert
* add vectorized const to UOp.const
* add UPat GEP-VECTORIZE-CONST -> CONST
* render_vectorize for DEFINE_ACC in cstyle
* add back missing render_cast in render_const
* generate vectorized consts as UOps for DEFINE_ACC
* update asserts for DEFINE_ACC with VECTORIZE src
* add UPats for PHI with VECTORIZE src
* use prev rendered vectorize in DEFINE_ACC render
* update DEFINE_ACC in python runtime
* update vectorized DEFINE_ACC in PTXRenderer
* rebase DEFINE_ACC changes on lowerer
* verbose rewrite of bad UPats
* simplify UOps.CONST implementation in ops_python
* update sum_collapse UPats for DEFINE_ACC-VECTORIZE
* revert linearizer to TOT
* fix DEFINE_ACC implementation in ops_python
* simplify DEFINE_ACC in cstyle
* Fix linter error
* support VECTORIZE in fold gated load/store UPat
* support VECTORIZE in other fold gated load UPats
* rewrite VECTORIZE in UPat for no input DEFINE_ACC
* simplify DEFINE_ACC render in cstyle
* make VECTORIZE rules more concise
* add more vectorize fold tests
* inline VECTORIZE-CONSTs in cstyle render
* revert VECTORIZE/GEP rule refactor
* revert cstyle render_const refactor
* inline VECTORIZE-CONSTs in cstyle render
* implicitly vectorized const rendering -> explicit
* WMMA VECTORIZE CONST process replay hacks
* VECTORIZE CONST NAN process_replay hacks
* more VECTORIZE CONST NAN hacks
* cleanup process_replay hacks
* isnan() -> not isfinite() cstyle VECTORIZE CONST
* tweak isnan and isfinite checks VECTORIZE CONST
* tweak for positive vs negative infinity VECTORIZE CONST
* add assert to PTX CONST render
* process_replay VECTORIZE CONST render parity for PTX STORE
* vmin/vmax for VECTORIZE'd CONST
* update WMMA folding rules
* add tests for WMMA VECTORIZE fold
* hack for cstyle half4 CONST zero process_replay parity
* revert PTX backend changes
* add back minimal DEFINE_ACC PTX change
* remove cstyle process_replay hacks
* remove dead code in PTX CONST render
* cleanup vmin/vmax logic for VECTORIZE'd CONSTs
* update vectorize fold tests to use DEFINE_VAR
* fix long line formatting in test
* remove unwanted merge artifact
* more vmin/vmax cleanup
* remove unnecessary asserts
* yet more vmin/vmax cleanup
* get rid of explicit VECTORIZE CONST logic in _min_max
* reuse CONST instead of creating a new one
* remove unneeded cast
* handle DType correctly in sconst
* improve readability of tests
* save a line
* save another line
* tuplize pats in src
* remove GEP-VECTORIZE pats
* add vec +0 fold
* HACK: fold only vec8 +0
* remove vectorized ALU fold hack

---------

Co-authored-by: qazal
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
---
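Note for reviewers: the heart of this patch is the UOp._const change in
tinygrad/codegen/uops.py below. A constant with a vector dtype is no longer a
single CONST uop; it is built as an explicit VECTORIZE of scalar CONSTs, and
type_verify now rejects vector-dtype CONSTs outright. A minimal sketch of the
new invariant (import paths assumed from the tree at this commit):

    from tinygrad.codegen.uops import UOp, UOps
    from tinygrad.dtype import dtypes

    c = UOp.const(dtypes.half.vec(4), 0.0)
    # before this patch: UOp(UOps.CONST, dtypes.half.vec(4), arg=0.0)
    # after this patch: an explicit VECTORIZE of four scalar CONSTs
    assert c.op is UOps.VECTORIZE and c.dtype == dtypes.half.vec(4)
    assert all(s.op is UOps.CONST and s.dtype == dtypes.half for s in c.src)

The rest follows from that: the folding rules in uopgraph.py now match
VECTORIZE-of-CONST where they used to match a vector CONST, the renderers take
DEFINE_ACC's init value from the already-rendered VECTORIZE, and the new
UOp.sconst covers the spots (the vmin/vmax bounds) that genuinely want a
scalar constant.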
 test/test_uop_graph.py         | 76 ++++++++++++++++++++++++++++++----
 tinygrad/codegen/uopgraph.py   | 18 ++++----
 tinygrad/codegen/uops.py       | 34 ++++++++-------
 tinygrad/renderer/assembly.py  |  4 +-
 tinygrad/renderer/cstyle.py    |  2 +-
 tinygrad/runtime/ops_python.py |  5 +--
 6 files changed, 100 insertions(+), 39 deletions(-)

diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py
index 75438223..bbd3553f 100644
--- a/test/test_uop_graph.py
+++ b/test/test_uop_graph.py
@@ -133,15 +133,6 @@ class TestUOpGraph(TestUOps):
     self.assertEqual(out.op, UOps.CONST)
     self.assertEqual(out.arg, 0)
 
-  def test_const_vectorize_fold(self):
-    c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
-    out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
-    g = UOpGraph([out])
-    self.assertEqual(len(g.uops), 1)
-    out = g.uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
-    self.assertEqual(out.arg, 0.0)
-
   def test_noop_vectorize_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
     idx = UOp.const(dtypes.int, 0)
@@ -192,6 +183,73 @@ class TestUOpGraph(TestUOps):
     xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
     self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
 
+  def test_gep_vec_const_fold(self):
+    for vec_size in [2, 4, 8]:
+      consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
+      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
+      geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
+      g = UOpGraph(geps)
+      for uop, const in zip(g.uops, consts):
+        self.assert_equiv_uops(uop, const)
+
+  def test_wmma_vectorize_fold(self):
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+
+  def test_wmma_vectorize_no_fold(self):
+    for i in [4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
   def test_cast_alu_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
     d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py
index 5d42d7d9..ebe957c2 100644
--- a/tinygrad/codegen/uopgraph.py
+++ b/tinygrad/codegen/uopgraph.py
@@ -191,8 +191,10 @@ constant_folder = PatternMatcher([
   *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half, src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j)
     for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
   # tensor core with a 0 input is acc
-  (NOp(UOps.WMMA, src=(NOp.const(None, 0.0), NOp.var(), NOp.var('acc'))), lambda acc: acc),
-  (NOp(UOps.WMMA, src=(NOp.var(), NOp.const(None, 0.0), NOp.var('acc'))), lambda acc: acc),
+  *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
+  *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
   # tensor core cleanups
   *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
    ,name="reduce", allow_any_len=True), reduce_before_expand) for j in [2,4,8]],
@@ -332,11 +334,11 @@ constant_folder = PatternMatcher([
   (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   # fold gated LOAD/STORE
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
   lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False)), lambda var: var),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
   (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), UOp.store),
   (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
   # remove NOOPs from SINK
@@ -401,7 +403,7 @@ def do_reduce(root):
   reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
   ret = root.src[0]
   if len(reduce_parented):
-    const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
+    const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
     acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
     acc_number += 1
     ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
@@ -468,8 +470,6 @@ reducer = PatternMatcher([
   (NOp(UOps.REDUCE, name="root"), do_reduce),
   # no ALU on vectorized dtypes
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
-  # VECTORIZE a CONST is a CONST (eventually remove this rule)
-  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
   # delete_redundant_gates (after expand, is this still needed?)
   (NOp(UOps.STORE, name="root"), delete_redundant_gates),
 ])
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 23ecb1d4..9a7b9142 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -70,11 +70,15 @@ class UOp:
   def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
   def recip(self): return self.alu(UnaryOps.RECIP)
   def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
+  def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
+    return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def _const(dtype:Optional[DType], b:ConstType|Variable):
     # TODO: fix dtype of b.max after Variable is just an UOp
     if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
+    if dtype is not None and dtype != (sdtype := dtype.scalar()):
+      return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   def alu(self, arg, *src:UOp):
     return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
@@ -104,9 +108,9 @@ class UOp:
     if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
     return None # generic None if we aren't sure
   @functools.cached_property
-  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.const(dtypes.min(cast(DType, self.dtype)))
+  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
   @functools.cached_property
-  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.const(dtypes.max(cast(DType, self.dtype)))
+  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
   @functools.cached_property
   def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
     # NOTE: returned UOp is assumed to be CONST
@@ -118,21 +122,21 @@ class UOp:
     if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
       s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
       if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
-        return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
-      if self.arg is BinaryOps.ADD: return self.const(s0.vmin.arg+s1.vmin.arg), self.const(s0.vmax.arg+s1.vmax.arg)
+        return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
+      if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
       if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
         # handle at lease one is non-negative
         Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
         Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
         assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
-        return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
-      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
+        return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
+      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
       if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
-        if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
-        if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
-      if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
-      if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
-        (UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
+        if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
+        if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
+      if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
+      if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
+        (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
     return None, None
 
 @dataclass(frozen=True, repr=False) # reuse repr from UOp
@@ -209,10 +213,10 @@ def type_verify(uops):
   for u in uops:
     uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
     if uop in {UOps.CONST, UOps.DEFINE_ACC}:
-      if uop is UOps.DEFINE_ACC:
-        assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
-        arg = src[0].arg
-      assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.CONST:
+        assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
+        assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
     if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
     if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
     if uop is UOps.VECTORIZE:
diff --git a/tinygrad/renderer/assembly.py b/tinygrad/renderer/assembly.py
index ac35dedc..a9c5f161 100644
--- a/tinygrad/renderer/assembly.py
+++ b/tinygrad/renderer/assembly.py
@@ -206,7 +206,7 @@ class PTXRenderer(Renderer):
       elif uop is UOps.DEFINE_ACC:
         if dtype.count > 1:
           r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
         else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
       elif uop is UOps.SPECIAL:
         assert args[0][0] != "i", "idx not supported"
@@ -217,7 +217,7 @@ class PTXRenderer(Renderer):
         bufs.append((args.expr, dtype))
         r[u] = f"%{args.expr}"
         kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
-      elif uop is UOps.CONST: r[u] = ([const(args, dtype.scalar(), mov=True)] * dtype.count) if dtype.count > 1 else const(args, dtype, mov=True)
+      elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
       elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
       elif uop is UOps.LOAD:
         assert src[0].dtype == dtypes.int64, "load isn't int64"
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index cf940ad9..738e7329 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -174,7 +174,7 @@ class CStyleLanguage(Renderer):
           bufs[u] = (nm:=f"data{args}", (dtype, False))
          r[u] = nm
         elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
-        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+        elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
         elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
         elif uop is UOps.GEP:
           assert src[0].dtype is not None
diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py
index ce1d0783..46244a34 100644
--- a/tinygrad/runtime/ops_python.py
+++ b/tinygrad/runtime/ops_python.py
@@ -85,10 +85,9 @@ class PythonProgram:
       elif uop is UOps.SPECIAL:
         if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
         elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
-      elif uop is UOps.CONST:
-        ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+      elif uop is UOps.CONST: ul[i] = [arg] * warp_size
       elif uop is UOps.DEFINE_ACC:
-        ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
+        ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
       elif uop is UOps.RANGE:
         if i not in ul: ul[i] = [inp[0][0]] * warp_size
         else: