Make vectorization of CONST explicit (#5322)

* remove test_const_vectorize_fold

* remove const folding UPat for VECTORIZE

* refactor cstyle render_const

* remove calls to dtype.scalar() in render_const

* add assert

* add vectorized const to UOp.const

* add UPat GEP-VECTORIZE-CONST -> CONST

* render_vectorize for DEFINE_ACC in cstyle

* add back missing render_cast in render_const

* generate vectorized consts as UOps for DEFINE_ACC

* update asserts for DEFINE_ACC with VECTORIZE src

* add UPats for PHI with VECTORIZE src

* use prev rendered vectorize in DEFINE_ACC render

* update DEFINE_ACC in python runtime

* update vectorized DEFINE_ACC in PTXRenderer

* rebase DEFINE_ACC changes on lowerer

* verbose rewrite of bad UPats

* simplify UOps.CONST implementation in ops_python

* update sum_collapse UPats for DEFINE_ACC-VECTORIZE

* revert linearizer to TOT

* fix DEFINE_ACC implementation in ops_python

* simplify DEFINE_ACC in cstyle

* Fix linter error

* support VECTORIZE in fold gated load/store UPat

* support VECTORIZE in other fold gated load UPats

* rewrite VECTORIZE in UPat for no input DEFINE_ACC

* simplify DEFINE_ACC render in cstyle

* make VECTORIZE rules more concise

* add more vectorize fold tests

* inline VECTORIZE-CONSTs in cstyle render

* revert VECTORIZE/GEP rule refactor

* revert cstyle render_const refactor

* inline VECTORIZE-CONSTs in cstyle render

* implicitly vectorized const rendering -> explicit

* WMMA VECTORIZE CONST process replay hacks

* VECTORIZE CONST NAN process_replay hacks

* more VECTORIZE CONST NAN hacks

* cleanup process_replay hacks

* isnan() -> not isfinite() cstyle VECTORIZE CONST

* tweak isnan and isfinite checks VECTORIZE CONST

* tweak for positive vs negative infinity VECTORIZE CONST

* add assert to PTX CONST render

* process_replay VECTORIZE CONST render parity for PTX STORE

* vmin/vmax for VECTORIZE'd CONST

* update WMMA folding rules

* add tests for WMMA VECTORIZE fold

* hack for cstyle half4 CONST zero process_replay parity

* revert PTX backend changes

* add back minimal DEFINE_ACC PTX change

* remove cstyle process_replay hacks

* remove dead code in PTX CONST render

* cleanup vmin/vmax logic for VECTORIZE'd CONSTs

* update vectorize fold tests to use DEFINE_VAR

* fix long line formatting in test

* remove unwanted merge artifact

* more vmin/vmax cleanup

* remove unnecessary asserts

* yet more vmin/vmax cleanup

* get rid of explicit VECTORIZE CONST logic in _min_max

* reuse CONST instead of creating a new one

* remove unneeded cast

* handle DType correctly in sconst

* improve readability of tests

* save a line

* save another line

* tuplize pats in src

* remove GEP-VECTORIZE pats

* add vec +0 fold

* HACK: fold only vec8 +0

* remove vectorized ALU fold hack

---------

Co-authored-by: qazal <qazal.software@gmail.com>
Co-authored-by: qazal <77887910+Qazalin@users.noreply.github.com>
Author: gswangg (committed by GitHub)
Date:   2024-08-08 10:59:05 -07:00
Parent: 62c77a2831
Commit: df44a4e861
6 changed files with 100 additions and 39 deletions
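
In short: UOp.const on a vector dtype now builds an explicit VECTORIZE of scalar CONSTs instead of a single vectorized CONST, and the VECTORIZE-of-CONST fold rule is removed. A minimal sketch of the new shape (import paths are assumptions; the authoritative code is in the diffs below):

    from tinygrad.ops import UOp, UOps      # assumed import paths
    from tinygrad.dtype import dtypes

    v = UOp.const(dtypes.float.vec(4), 0.0)
    assert v.op is UOps.VECTORIZE           # previously a single CONST with a vector dtype
    assert all(s.op is UOps.CONST and s.dtype == dtypes.float for s in v.src)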

@@ -133,15 +133,6 @@ class TestUOpGraph(TestUOps):
     self.assertEqual(out.op, UOps.CONST)
     self.assertEqual(out.arg, 0)
-  def test_const_vectorize_fold(self):
-    c0 = UOp(UOps.CONST, dtypes.half, arg=0.0)
-    out = UOp(UOps.VECTORIZE, dtypes.half.vec(2), (c0, c0))
-    g = UOpGraph([out])
-    self.assertEqual(len(g.uops), 1)
-    out = g.uops[-1]
-    self.assertEqual(out.op, UOps.CONST)
-    self.assertEqual(out.arg, 0.0)
   def test_noop_vectorize_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
     idx = UOp.const(dtypes.int, 0)
@@ -192,6 +183,73 @@ class TestUOpGraph(TestUOps):
     xy2 = tuple(UOp(UOps.GEP, dtypes.float, (val2, ), i) for i in range(2))
     self.assertIs(_test_vec(xy1+xy2).op, UOps.VECTORIZE)
+  def test_gep_vec_const_fold(self):
+    for vec_size in [2, 4, 8]:
+      consts = [UOp.const(dtypes.float, float(i)) for i in range(vec_size)]
+      vec = UOp(UOps.VECTORIZE, dtypes.float.vec(vec_size), tuple(consts))
+      geps = [UOp(UOps.GEP, dtypes.float, (vec,), i) for i in range(vec_size)]
+      g = UOpGraph(geps)
+      for uop, const in zip(g.uops, consts):
+        self.assert_equiv_uops(uop, const)
+  def test_wmma_vectorize_fold(self):
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[0], acc)
+      self.assertEqual(len(g.uops), 1)
+  def test_wmma_vectorize_no_fold(self):
+    for i in [4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+    for i in [4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
+                tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+    for i in [2, 4, 8]:
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
+    for i in [2, 4, 8]:
+      var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
+      vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
+                tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
+      acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
+      wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
+      g = UOpGraph([wmma])
+      self.assert_equiv_uops(g.uops[-1], wmma)
   def test_cast_alu_fold(self):
     d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
     d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)

@@ -191,8 +191,10 @@ constant_folder = PatternMatcher([
   *[(NOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(NOp(UOps.GEP, dtypes.half,
     src=(NOp.var('x', dtype=dtypes.half.vec(i)),), arg=j) for j in range(i))), lambda x: x) for i in [2, 4, 8, 16]],
   # tensor core with a 0 input is acc
-  (NOp(UOps.WMMA, src=(NOp.const(None, 0.0), NOp.var(), NOp.var('acc'))), lambda acc: acc),
-  (NOp(UOps.WMMA, src=(NOp.var(), NOp.const(None, 0.0), NOp.var('acc'))), lambda acc: acc),
+  *[(NOp(UOps.WMMA, src=(NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var(), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
+  *[(NOp(UOps.WMMA, src=(NOp.var(), NOp(UOps.VECTORIZE, src=tuple(NOp.const(None, 0.0) for _ in range(i))), NOp.var('acc'))),
+    lambda acc: acc) for i in [2, 4, 8]],
   # tensor core cleanups
   *[(NOp(UOps.REDUCE, src=(NOp(UOps.EXPAND, src=tuple(NOp(UOps.GEP, dtypes.float, src=(NOp.var('x'),), arg=i) for i in range(j)), name="expand"),)
     ,name="reduce", allow_any_len=True), reduce_before_expand) for j in [2,4,8]],
@@ -332,11 +334,11 @@ constant_folder = PatternMatcher([
   (NOp(UOps.CAST, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   (NOp(UOps.VECTORIZE, name="root"), lambda root: root.src[0] if str(root.dtype) == str(root.src[0].dtype) else None),
   # fold gated LOAD/STORE
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
-  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.cvar("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True)), lambda buf,idx,var: UOp.load(buf, idx, dtype=var.dtype)),
+  (NOp.load(NOp.var("buf"), NOp.var("idx"), NOp.var("var"), NOp.const(dtypes.bool, True), NOp.var("barrier")),
    lambda buf,idx,var,barrier: UOp.load(buf, idx, barrier, dtype=var.dtype)),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False)), lambda var: var),
-  (NOp.load(NOp.var(), NOp.var(), NOp.cvar("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False)), lambda var: var),
+  (NOp.load(NOp.var(), NOp.var(), NOp.var("var"), NOp.const(dtypes.bool, False), NOp.var()), lambda var: var),
   (NOp.store(NOp.var("buf"), NOp.var("idx"), NOp.var("val"), NOp.const(dtypes.bool, True)), UOp.store),
   (NOp.store(NOp.var(), NOp.var(), NOp.var(), NOp.const(dtypes.bool, False)), lambda: UOp(UOps.NOOP)),
   # remove NOOPs from SINK
@@ -401,7 +403,7 @@ def do_reduce(root):
   reduce_parented, reduce_unparented = partition(root.src[1:], lambda x: x in root.src[0].parents)
   ret = root.src[0]
   if len(reduce_parented):
-    const = UOp.const(root.dtype.scalar(), 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype))
+    const = UOp.const(root.dtype, 0 if root.arg is BinaryOps.ADD else dtypes.min(root.dtype.scalar()))
     acc = UOp(UOps.DEFINE_ACC, root.dtype, (const,) + tuple(reduce_parented), (acc_number,))
     acc_number += 1
     ret = UOp(UOps.PHI, root.dtype, (acc, acc.alu(root.arg, ret)))
@@ -468,8 +470,6 @@ reducer = PatternMatcher([
   (NOp(UOps.REDUCE, name="root"), do_reduce),
   # no ALU on vectorized dtypes
   (UPat({UOps.ALU, UOps.CAST, UOps.BITCAST}, name="alu"), no_vectorized_alu),
-  # VECTORIZE a CONST is a CONST (eventually remove this rule)
-  (UPat(UOps.VECTORIZE, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: root.const(c.arg)),
   # delete_redundant_gates (after expand, is this still needed?)
   (NOp(UOps.STORE, name="root"), delete_redundant_gates),
 ])
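
The cvar -> var switch in the gated LOAD patterns matters because a vector-valued alternative is now a VECTORIZE rather than a CONST, so NOp.cvar would no longer match it. A hedged sketch of the fold (buffer and index construction are illustrative):

    buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
    idx = UOp.const(dtypes.int, 0)
    alt = UOp.const(dtypes.float.vec(4), 0.0)   # a VECTORIZE after this PR
    gate = UOp.const(dtypes.bool, True)
    ld = UOp.load(buf, idx, alt, gate, dtype=dtypes.float.vec(4))
    # constant_folder rewrites ld to the ungated UOp.load(buf, idx, dtype=alt.dtype);
    # with a constant False gate it rewrites to alt itself.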

@@ -70,11 +70,15 @@ class UOp:
   def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y)
   def recip(self): return self.alu(UnaryOps.RECIP)
   def const(self:Union[UOp, DType, None], b:ConstType|Variable): return UOp._const(self.dtype if isinstance(self, UOp) else self, b)
+  def sconst(self:Union[UOp, DType, None], b:ConstType|Variable):
+    return UOp._const(cast(DType, self.dtype if isinstance(self, UOp) else self).scalar() if self is not None else self, b)
   @staticmethod
   @functools.lru_cache(maxsize=None)
   def _const(dtype:Optional[DType], b:ConstType|Variable):
     # TODO: fix dtype of b.max after Variable is just an UOp
     if isinstance(b, Variable): return UOp(UOps.DEFINE_VAR, dtype, (UOp.const(dtypes.int, b.min), UOp.const(dtypes.int, cast(int,b.max))), b)
+    if dtype is not None and dtype != (sdtype := dtype.scalar()):
+      return UOp(UOps.VECTORIZE, dtype, src=tuple(UOp(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
     return UOp(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
   def alu(self, arg, *src:UOp):
     return type(self)(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else (self, *src)[-1].dtype, (self,)+src, arg)
@@ -104,9 +108,9 @@ class UOp:
       if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1
     return None # generic None if we aren't sure
   @functools.cached_property
-  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.const(dtypes.min(cast(DType, self.dtype)))
+  def vmin(self) -> UOp: return x if (x:=self._min_max[0]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.min(cast(DType, self.dtype)))
   @functools.cached_property
-  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.const(dtypes.max(cast(DType, self.dtype)))
+  def vmax(self) -> UOp: return x if (x:=self._min_max[1]) is not None and not math.isnan(x.arg) else self.sconst(dtypes.max(cast(DType, self.dtype)))
   @functools.cached_property
   def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
     # NOTE: returned UOp is assumed to be CONST
@@ -118,21 +122,21 @@ class UOp:
     if self.op is UOps.ALU and cast(DType, self.dtype).count == 1:
       s0,s1 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(2)]
       if self.arg is UnaryOps.NEG and self.dtype != dtypes.bool and not dtypes.is_unsigned(cast(DType, self.dtype)):
-        return self.const(-s0.vmax.arg), self.const(-s0.vmin.arg)
-      if self.arg is BinaryOps.ADD: return self.const(s0.vmin.arg+s1.vmin.arg), self.const(s0.vmax.arg+s1.vmax.arg)
+        return self.sconst(-s0.vmax.arg), self.sconst(-s0.vmin.arg)
+      if self.arg is BinaryOps.ADD: return self.sconst(s0.vmin.arg+s1.vmin.arg), self.sconst(s0.vmax.arg+s1.vmax.arg)
       if self.arg is BinaryOps.MUL and (s0.vmin.arg >= 0 or s1.vmin.arg >= 0):
         # handle when at least one side is non-negative
         Lmin, Lmax = (s0.vmin.arg, s0.vmax.arg) if s1.vmin.arg >= 0 else (s0.vmax.arg, s0.vmin.arg)
         Rmin, Rmax = (s1.vmin.arg, s1.vmax.arg) if s0.vmin.arg >= 0 else (s1.vmax.arg, s1.vmin.arg)
         assert math.isnan(Lmax*Rmax) or math.isnan(Lmin*Rmin) or Lmax*Rmax >= Lmin*Rmin, f"{Lmax=}, {Lmin=}, {Rmax=}, {Rmin=}"
-        return self.const(Lmin*Rmin), self.const(Lmax*Rmax)
-      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.const(0), self.const(s1.arg-1)
+        return self.sconst(Lmin*Rmin), self.sconst(Lmax*Rmax)
+      if self.arg is BinaryOps.MOD and s1.op is UOps.CONST and s1.arg > 0: return self.sconst(0), self.sconst(s1.arg-1)
       if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
-        if s1.arg > 0: return self.const(s0.vmin.arg//s1.arg), self.const(s0.vmax.arg//s1.arg)
-        if s1.arg < 0: return self.const(-(s0.vmax.arg//-s1.arg)), self.const(-(s0.vmin.arg//-s1.arg))
-      if self.arg is BinaryOps.MAX: return self.const(max(s0.vmin.arg, s1.vmin.arg)), self.const(max(s0.vmax.arg, s1.vmax.arg))
-      if self.arg is BinaryOps.CMPLT: return (UOp.const(dtypes.bool, True), UOp.const(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
-        (UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
+        if s1.arg > 0: return self.sconst(s0.vmin.arg//s1.arg), self.sconst(s0.vmax.arg//s1.arg)
+        if s1.arg < 0: return self.sconst(-(s0.vmax.arg//-s1.arg)), self.sconst(-(s0.vmin.arg//-s1.arg))
+      if self.arg is BinaryOps.MAX: return self.sconst(max(s0.vmin.arg, s1.vmin.arg)), self.sconst(max(s0.vmax.arg, s1.vmax.arg))
+      if self.arg is BinaryOps.CMPLT: return (UOp.sconst(dtypes.bool, True), UOp.sconst(dtypes.bool, True)) if s0.vmax.arg < s1.vmin.arg else \
+        (UOp.sconst(dtypes.bool, False), UOp.sconst(dtypes.bool, False)) if s0.vmin.arg >= s1.vmax.arg else (None, None)
     return None, None
 @dataclass(frozen=True, repr=False) # reuse repr from UOp
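
The const/sconst split is the crux of this file's change: const honors the vector dtype, while the new sconst drops to the scalar dtype so that bound computations keep producing scalar CONSTs. A hedged sketch:

    v = UOp.const(dtypes.int.vec(4), 7)   # VECTORIZE of four scalar CONST(7)
    s = UOp.sconst(dtypes.int.vec(4), 7)  # a single scalar CONST(7) of dtype int
    # vmin/vmax go through sconst, so the bounds of a vector UOp are scalar
    # CONSTs whose .arg can be compared directly, as _min_max does above.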
@@ -209,10 +213,10 @@ def type_verify(uops):
   for u in uops:
     uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
-    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
-      if uop is UOps.DEFINE_ACC:
-        assert dtype is not None and src[0].dtype == dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
-        arg = src[0].arg
-      assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+    if uop is UOps.CONST:
+      assert dtype is not None and dtype == dtype.scalar(), f"consts should be scalar, got {dtype}"
+      assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
+    if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
     if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
     if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
     if uop is UOps.VECTORIZE:
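
type_verify now enforces the new contract: a CONST is always scalar, and DEFINE_ACC's initializer carries the full (possibly vector) dtype. A hedged sketch of a graph that satisfies both asserts:

    init = UOp.const(dtypes.half.vec(2), 0.0)   # a VECTORIZE with dtype half.vec(2)
    acc = UOp(UOps.DEFINE_ACC, dtypes.half.vec(2), (init,), (0,))
    # init.dtype == acc.dtype, and each init.src[j] is a scalar half CONST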

@@ -206,7 +206,7 @@ class PTXRenderer(Renderer):
       elif uop is UOps.DEFINE_ACC:
         if dtype.count > 1:
           r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
-          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
+          for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};")
         else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
       elif uop is UOps.SPECIAL:
         assert args[0][0] != "i", "idx not supported"
@@ -217,7 +217,7 @@ class PTXRenderer(Renderer):
         bufs.append((args.expr, dtype))
         r[u] = f"%{args.expr}"
         kk(*self.render_load(args.expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
-      elif uop is UOps.CONST: r[u] = ([const(args, dtype.scalar(), mov=True)] * dtype.count) if dtype.count > 1 else const(args, dtype, mov=True)
+      elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
       elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
       elif uop is UOps.LOAD:
        assert src[0].dtype == dtypes.int64, "load isn't int64"
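
With the initializer now a VECTORIZE, the scalar init value of a vectorized DEFINE_ACC sits one level deeper in the graph, hence src[0].src[0].arg. Roughly (a sketch, not the renderer's exact code):

    init = u.src[0]               # a UOps.VECTORIZE when dtype.count > 1
    scalar_arg = init.src[0].arg  # was u.src[0].arg back when src[0] was a CONST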

@@ -174,7 +174,7 @@ class CStyleLanguage(Renderer):
         bufs[u] = (nm:=f"data{args}", (dtype, False))
         r[u] = nm
       elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
-      elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
+      elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {r[src[0]]};")
       elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
       elif uop is UOps.GEP:
         assert src[0].dtype is not None
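
Because the VECTORIZE source has already been rendered by the time DEFINE_ACC is reached, cstyle can splice in r[src[0]] directly instead of re-rendering the constant. For a float4 zero accumulator the emitted line looks roughly like this (the exact spelling depends on the backend's vector and const rendering):

    dtype_str, init_str = "float4", "(float4)(0.0f,0.0f,0.0f,0.0f)"  # assumed renders
    line = f"{dtype_str} acc0 = {init_str};"  # mirrors the kk(...) call above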

@@ -85,10 +85,9 @@ class PythonProgram:
       elif uop is UOps.SPECIAL:
         if arg[0][0] == 'g': ul[i] = [idxs[2-int(arg[0][-1])]] * warp_size
         elif arg[0][0] == 'l': ul[i] = [x[2-int(arg[0][-1])] for x in warp]
-      elif uop is UOps.CONST:
-        ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
+      elif uop is UOps.CONST: ul[i] = [arg] * warp_size
       elif uop is UOps.DEFINE_ACC:
-        ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
+        ul[i] = [[inp[0][0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
       elif uop is UOps.RANGE:
         if i not in ul: ul[i] = [inp[0][0]] * warp_size
         else:
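
In the Python emulator, a vector UOp's value is a list with one entry per vector element, each entry being a warp-sized list of scalars. For DEFINE_ACC, inp[0] is the VECTORIZE's value, so inp[0][0][0] picks element 0, lane 0 as the scalar init. A hedged sketch of the shapes:

    warp_size, count = 32, 4
    vec_val = [[0.0] * warp_size for _ in range(count)]  # what a VECTORIZE evaluates to
    acc_init = vec_val[0][0]                             # inp[0][0][0]: element 0, lane 0
    acc_val = [[acc_init] * warp_size for _ in range(count)]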