DEFINE_ACC takes UOps.CONST in vin instead of arg (#4975)

* Change DEFINE_ACC to receive UOps.CONST in vin

* Use localtype instead of acc dtype

* Fix idp

* Fix copy list

* Fix warp

* Fix error

* Fix merge

* Fix testing

* Fix merge

* Use deepcopy

* Change to copy of inp

* Fix lint

* Move const to first place

* Fix issue upat

* Fix upat patterns

* Change to list, to test permutations

* Add condition

* Change pm

* Revert change pm

* Remove unused rule

* Fix

* Change of float4 DEFINE_ACC values

* Cast on PM to correct dtype

* Improve assert message

* Move IFs

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Jhenner Tigreros 2024-06-24 11:25:33 -05:00 committed by GitHub
parent d84beaa6dd
commit dfa562dbc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 18 additions and 16 deletions

View File

@ -145,7 +145,7 @@ class Linearizer(Kernel):
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
if key not in self.load_cache:
if acc is not None:
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
acc_count += 1
elif this_const is not None:
self.load_cache[key] = UOp.const(localtype, this_const)

View File

@ -171,9 +171,9 @@ constant_folder = PatternMatcher([
src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
# sum collapse to mul (with possible GEP)
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)),
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
# deal with UNMUL
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
@ -186,11 +186,11 @@ constant_folder = PatternMatcher([
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
# a DEFINE_ACC without inputs is a const + GEP on a const is the const
(UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
(UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
# max -2147483648
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
@ -404,8 +404,8 @@ class UOpGraph:
while queue:
p,x = heapq.heappop(queue)
if DEBUG >= 7: print(p,x)
if x.op is UOps.DEFINE_ACC and len(x.src):
idx = min([self._uops.index(l) for l in x.src])
if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
self._uops.insert(idx, x)
else:
self._uops.append(x)
@ -460,8 +460,10 @@ class UOpGraph:
def type_verify(self):
for u in self.uops:
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.DEFINE_ACC: arg = arg[0]
if uop in (UOps.CONST, UOps.DEFINE_ACC):
if uop is UOps.DEFINE_ACC:
assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
arg = src[0].arg
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool

View File

@ -164,8 +164,8 @@ class PTXRenderer(Renderer):
elif uop is UOps.DEFINE_ACC:
if dtype.count > 1:
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args[0], dtype.scalar())};")
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args[0], dtype)};")
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
elif uop is UOps.SPECIAL:
assert args[1][0] != "i", "idx not supported"
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")

View File

@ -164,7 +164,7 @@ class CStyleLanguage(Renderer):
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
r[u] = nm
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
elif uop is UOps.GEP:
assert src[0].dtype is not None

View File

@ -132,7 +132,7 @@ class LLVMRenderer(Renderer):
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
loop_blocks.append((bb[-1].block, phis))
elif uop is UOps.DEFINE_ACC:
lvars[u] = const(args[0], dtype)
lvars[u] = const(src[0].arg, dtype)
reduce_phis.append(u)
elif uop is UOps.LOAD:
if len(src) > 2:

View File

@ -41,7 +41,7 @@ class PythonProgram:
while i < len(self.uops):
uop, dtype, idp, arg = self.uops[i]
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
if uop is UOps.DEFINE_ACC: idp.clear()
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
@ -90,7 +90,7 @@ class PythonProgram:
elif uop is UOps.CONST:
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
elif uop is UOps.DEFINE_ACC:
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
elif uop is UOps.RANGE:
if i not in ul: ul[i] = [inp[0][0]] * warp_size
else: