mirror of https://github.com/commaai/tinygrad.git
Make vectorization of CONST explicit (#5322)
* remove test_const_vectorize_fold
* remove const folding UPat for VECTORIZE
* refactor cstyle render_const
* remove calls to dtype.scalar() in render_const
* add assert
* add vectorized const to UOp.const
* add UPat GEP-VECTORIZE-CONST -> CONST
* render_vectorize for DEFINE_ACC in cstyle
* add back missing render_cast in render_const
* generate vectorized consts as UOps for DEFINE_ACC
* update asserts for DEFINE_ACC with VECTORIZE src
* add UPats for PHI with VECTORIZE src
* use prev rendered vectorize in DEFINE_ACC render
* update DEFINE_ACC in python runtime
* update vectorized DEFINE_ACC in PTXRenderer
* rebase DEFINE_ACC changes on lowerer
* verbose rewrite of bad UPats
* simplify UOps.CONST implementation in ops_python
* update sum_collapse UPats for DEFINE_ACC-VECTORIZE
* revert linearizer to TOT
* fix DEFINE_ACC implementation in ops_python
* simplify DEFINE_ACC in cstyle
* Fix linter error
* support VECTORIZE in fold gated load/store UPat
* support VECTORIZE in other fold gated load UPats
* rewrite VECTORIZE in UPat for no input DEFINE_ACC
* simplify DEFINE_ACC render in cstyle
* make VECTORIZE rules more concise
* add more vectorize fold tests
* inline VECTORIZE-CONSTs in cstyle render
* revert VECTORIZE/GEP rule refactor
* revert cstyle render_const refactor
* inline VECTORIZE-CONSTs in cstyle render
* implicitly vectorized const rendering -> explicit
* WMMA VECTORIZE CONST process replay hacks
* VECTORIZE CONST NAN process_replay hacks
* more VECTORIZE CONST NAN hacks
* cleanup process_replay hacks
* isnan() -> not isfinite() cstyle VECTORIZE CONST
* tweak isnan and isfinite checks VECTORIZE CONST
* tweak for positive vs negative infinity VECTORIZE CONST
* add assert to PTX CONST render
* process_replay VECTORIZE CONST render parity for PTX STORE
* vmin/vmax for VECTORIZE'd CONST
* update WMMA folding rules
* add tests for WMMA VECTORIZE fold
* hack for cstyle half4 CONST zero process_replay parity
* revert PTX backend changes
* add back minimal DEFINE_ACC PTX change
* remove cstyle process_replay hacks
* remove dead code in PTX CONST render
* cleanup vmin/vmax logic for VECTORIZE'd CONSTs
* update vectorize fold tests to use DEFINE_VAR
* fix long line formatting in test
* remove unwanted merge artifact
* more vmin/vmax cleanup
* remove unnecessary asserts
* yet more vmin/vmax cleanup
* get rid of explicit VECTORIZE CONST logic in _min_max
* reuse CONST instead of creating a new one
* remove unneeded cast
* handle DType correctly in sconst
* improve readability of tests
* save a line
* save another line
* tuplize pats in src
* remove GEP-VECTORIZE pats
* add vec +0 fold
* HACK: fold only vec8 +0
* remove vectorized ALU fold hack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
This commit is contained in:
parent 62c77a2831
commit df44a4e861
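In short: constructing a constant UOp with a vector dtype now builds an explicit VECTORIZE of per-lane scalar CONSTs instead of a single vector-dtype CONST, and the pattern matcher, verifier, and renderers are updated to expect that shape. A minimal sketch of the new behavior, simplified from the UOp._const hunk below (assumes UOp, UOps, and dtypes from tinygrad are in scope; make_const is an illustrative name, not part of the diff):

# Sketch only, not the verbatim implementation.
def make_const(dtype, b):
  if dtype is not None and dtype != (sdtype := dtype.scalar()):
    # vector dtype: one scalar CONST per lane, wrapped in an explicit VECTORIZE
    return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=b) for _ in range(dtype.count)))
  # scalar dtype: a plain CONST, unchanged
  return UOp(UOps.CONST, dtype, arg=b)

# e.g. make_const(dtypes.half.vec(4), 0.0) -> VECTORIZE(CONST 0.0, CONST 0.0, CONST 0.0, CONST 0.0)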
@@ -133,15 +133,6 @@ class TestUOpGraph(TestUOps):
     self.assertEqual(out.op, UOps.CONST)
     self.assertEqual(out.arg, 0)
 
-  def test_const_vectorize_fold(self):
-    c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
-    out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
-    g = UOpGraph([out])
-    self.assertEqual(len(g.uops), 1)
-    out = g.uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
-    self.assertEqual(out.arg, 0.0)
-
   def test_noop_vectorize_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
     idx = UOp.const(dtypes.int, 0)
@@ -192,6 +183,73 @@ class TestUOpGraph(TestUOps):
     xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
     self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
 
+  def test_gep_vec_const_fold(self):
+    for vec_size in [2, 4, 8]:
+      consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
+      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
+      geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
+      g = UOpGraph(geps)
+      for uop, const in zip(g.uops, consts):
+        self.assert_equiv_uops(uop, const)
+
+  def test_wmma_vectorize_fold(self):
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+
+  def test_wmma_vectorize_no_fold(self):
+    for i in [4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+
   def test_cast_alu_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
     d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
@@ -191,8 +191,10 @@ constant_folder = PatternMatcher([
   *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
     src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
   # tensor core with a 0 input is acc
-  (NOp(UOps.WMMA, src=(NOp.const(None, 0.0), NOp.var(), NOp.var('acc'))), lambda acc: acc),
-  (NOp(UOps.WMMA, src=(NOp.var(), NOp.const(None, 0.0), NOp.var('acc'))), lambda acc: acc),
+  *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
+  *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
   # tensor core cleanups
   *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
     ,name="reduce", allow_any_len=True), reduce_before_expand) for j in [2,4,8]],
@@ -332,11 +334,11 @@ constant_folder = PatternMatcher([
   (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   # fold gated LOAD/STORE
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
     lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False)), lambda var: var),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
   (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), UOp.store),
   (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
   # remove NOOPs from SINK
@@ -401,7 +403,7 @@ def do_reduce(root):
   reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
   ret = root.src[0]
   if len(reduce_parented):
-    const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
+    const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
     acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
     acc_number += 1
     ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
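Note on this hunk: the accumulator init now keeps the reduce's full (possibly vector) dtype, so UOp.const returns a VECTORIZE of scalar CONSTs and DEFINE_ACC's first source has the same dtype as the accumulator. A rough sketch of the resulting structure for a hypothetical float4 ADD reduce (illustrative values, not taken from the diff):

# Illustrative only: the DEFINE_ACC init after this change, for a float4 ADD reduce.
init = UOp.const(dtypes.float.vec(4), 0)                 # now a VECTORIZE of four scalar CONST 0.0
acc = UOp(UOps.DEFINE_ACC, dtypes.float.vec(4), (init,), (0,))
# before this commit the init was UOp.const(root.dtype.scalar(), 0), a single scalar CONST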
@@ -468,8 +470,6 @@ reducer = PatternMatcher([
   (NOp(UOps.REDUCE, name="root"), do_reduce),
   # no ALU on vectorized dtypes
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
-  # VECTORIZE a CONST is a CONST (eventually remove this rule)
-  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
   # delete_redundant_gates (after expand, is this still needed?)
   (NOp(UOps.STORE, name="root"), delete_redundant_gates),
 ])
@@ -70,11 +70,15 @@ class UOp:
   def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
   def recip(self): return self.alu(UnaryOps.RECIP)
   def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
+  def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
+    return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def _const(dtype:Optional[DType], b:ConstType|Variable):
     # TODO: fix dtype of b.max after Variable is just an UOp
     if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
+    if dtype is not None and dtype != (sdtype := dtype.scalar()):
+      return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   def alu(self, arg, *src:UOp):
     return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
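Assuming the usual tinygrad imports, the practical difference between const and the new sconst helper looks roughly like this (a sketch, not verbatim test code):

# const keeps the requested dtype: for a vector dtype it is now an explicit VECTORIZE of scalar CONSTs
v = UOp.const(dtypes.half.vec(4), 1.0)
assert v.op is UOps.VECTORIZE and all(s.op is UOps.CONST for s in v.src)
# sconst drops to the scalar dtype, so it always yields a plain scalar CONST
s = UOp.sconst(dtypes.half.vec(4), 1.0)
assert s.op is UOps.CONST and s.dtype == dtypes.half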
@@ -104,9 +108,9 @@ class UOp:
     if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
     return None # generic None if we aren't sure
   @functools.cached_property
-  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.const(dtypes.min(cast(DType, self.dtype)))
+  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
   @functools.cached_property
-  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.const(dtypes.max(cast(DType, self.dtype)))
+  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
   @functools.cached_property
   def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
     # NOTE: returned UOp is assumed to be CONST
@@ -118,21 +122,21 @@ class UOp:
     if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
       s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
       if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
-        return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
-      if self.arg is BinaryOps.ADD: return self.const(s0.vmin.arg+s1.vmin.arg), self.const(s0.vmax.arg+s1.vmax.arg)
+        return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
+      if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
       if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
         # handle at lease one is non-negative
         Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
         Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
         assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
-        return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
-      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
+        return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
+      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
       if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
-        if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
-        if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
-      if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
-      if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
-        (UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
+        if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
+        if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
+      if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
+      if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
+        (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
     return None, None
 
 @dataclass(frozen=True, repr=False) # reuse repr from UOp
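The bound helpers switch from const to sconst because _min_max assumes scalar CONST results even when the UOp itself is vectorized; the interval logic is otherwise unchanged. For example, the CMPLT rule above only folds when the operand ranges cannot overlap, roughly (a sketch on plain numbers, not the tinygrad code):

def cmplt_bounds(a_min, a_max, b_min, b_max):
  if a_max < b_min: return True, True      # a < b always holds
  if a_min >= b_max: return False, False   # a < b never holds
  return None, None                        # ranges overlap: keep the comparison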
@@ -209,10 +213,10 @@ def type_verify(uops):
   for u in uops:
     uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
     if uop in {UOps.CONST, UOps.DEFINE_ACC}:
-      if uop is UOps.DEFINE_ACC:
-        assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
-        arg = src[0].arg
-      assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.CONST:
+        assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
+        assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+      if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
     if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
     if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
     if uop is UOps.VECTORIZE:
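Restated, the verifier now enforces the two invariants this commit introduces (a paraphrase of the asserts above, assuming u is a UOp from a linearized graph):

if u.op is UOps.CONST:
  assert u.dtype == u.dtype.scalar()   # CONST is always scalar; vectors use an explicit VECTORIZE
if u.op is UOps.DEFINE_ACC:
  assert u.src[0].dtype == u.dtype     # the init value carries the accumulator's full (vector) dtype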
@@ -206,7 +206,7 @@ class PTXRenderer(Renderer):
       elif uop is UOps.DEFINE_ACC:
         if dtype.count > 1:
           r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
         else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
       elif uop is UOps.SPECIAL:
         assert args[0][0] != "i", "idx not supported"
@@ -217,7 +217,7 @@ class PTXRenderer(Renderer):
         bufs.append((args.expr, dtype))
         r[u] = f"%{args.expr}"
         kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
-      elif uop is UOps.CONST: r[u] = ([const(args, dtype.scalar(), mov=True)] * dtype.count) if dtype.count > 1 else const(args, dtype, mov=True)
+      elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
       elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
       elif uop is UOps.LOAD:
         assert src[0].dtype == dtypes.int64, "load isn't int64"
@@ -174,7 +174,7 @@ class CStyleLanguage(Renderer):
         bufs[u] = (nm:=f"data{args}", (dtype, False))
         r[u] = nm
       elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
-      elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+      elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
       elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
       elif uop is UOps.GEP:
         assert src[0].dtype is not None
@@ -85,10 +85,9 @@ class PythonProgram:
       elif uop is UOps.SPECIAL:
         if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
         elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
-      elif uop is UOps.CONST:
-        ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+      elif uop is UOps.CONST: ul[i] = [arg] * warp_size
       elif uop is UOps.DEFINE_ACC:
-        ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
+        ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
       elif uop is UOps.RANGE:
         if i not in ul: ul[i] = [inp[0][0]] * warp_size
       else: