mirror of https://github.com/commaai/tinygrad.git
DEFINE_ACC takes UOps.CONST in vin instead of arg (#4975)
* Change DEFINE_ACC to receive UOps.CONST in vin
* Use localtype instead of acc dtype
* Fix idp
* Fix copy list
* Fix warp
* Fix error
* Fix merge
* Fix testing
* Fix merge
* Use deepcopy
* Change to copy of inp
* Fix lint
* Move const to first place
* Fix issue upat
* Fix upat patterns
* Change to list, to test permutations
* Add condition
* Change pm
* Revert change pm
* Remove unused rule
* Fix
* Change of float4 DEFINE_ACC values
* Cast on PM to correct dtype
* Improve assert message
* Move IFs
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
parent
d84beaa6dd
commit
dfa562dbc1
|
@ -145,7 +145,7 @@ class Linearizer(Kernel):
|
|||
key = f"{'' if acc is None else self.reduceops.index(acc)}{localtype}{'CONST'+str(this_const) if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}" # noqa: E501
|
||||
if key not in self.load_cache:
|
||||
if acc is not None:
|
||||
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, loop_ctx, (self.get_reduce_acc(acc), i, acc_count))
|
||||
self.load_cache[key] = UOp(UOps.DEFINE_ACC, localtype, (UOp.const(localtype.scalar(), self.get_reduce_acc(acc)), *loop_ctx), (i, acc_count))
|
||||
acc_count += 1
|
||||
elif this_const is not None:
|
||||
self.load_cache[key] = UOp.const(localtype, this_const)
|
||||
|
|
|
@ -171,9 +171,9 @@ constant_folder = PatternMatcher([
|
|||
src=[UPat(UOps.CONST, name="mval"), UPat(UOps.RANGE, src=(UPat(name="loop_start"), UPat(name="loop_end")))])]),
|
||||
UPat(UOps.CONST, name="compval"))), UPat(UOps.CONST, name="multconst"), UPat(UOps.CONST, 0))), loop_collapse),
|
||||
# sum collapse to mul (with possible GEP)
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=(UPat(UOps.RANGE, name="loop"),)),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="phi_input", src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.RANGE, name="loop"),)),)),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.GEP, name="phi_input", src=(UPat(UOps.DEFINE_ACC, src=[UPat(UOps.CONST), UPat(UOps.RANGE, name="loop")]),)),
|
||||
UPat(UOps.ALU, BinaryOps.ADD, src=(UPat(name="val1"), UPat(name="val2"))))), sum_collapse),
|
||||
# deal with UNMUL
|
||||
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"), UPat(UOps.UNMUL, src=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
|
||||
|
@ -186,11 +186,11 @@ constant_folder = PatternMatcher([
|
|||
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
(UPat(UOps.CAST, name="root", src=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
|
||||
# a phi on a DEFINE_ACC without loops or a CONST is a noop. this is for correctness, not just speed
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.const(acc.dtype, acc.arg[0])),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=tuple()), UPat(name="x"))), lambda x: x),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, name="acc"), UPat(name="acc"))), lambda acc: UOp.cast(acc.src[0], acc.dtype)),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST),)), UPat(name="x"))), lambda x: x),
|
||||
(UPat(UOps.PHI, src=(UPat(UOps.CONST), UPat(name="x"))), lambda x: x),
|
||||
# a DEFINE_ACC with no loop inputs (only its CONST initial value) is a const + GEP on a const is the const
|
||||
(UPat(UOps.DEFINE_ACC, name="root", src=tuple()), lambda root: UOp.const(root.dtype, root.arg[0])),
|
||||
(UPat(UOps.DEFINE_ACC, name="root", src=(UPat(UOps.CONST),)), lambda root: UOp.cast(root.src[0], root.dtype)),
|
||||
(UPat(UOps.GEP, name="root", src=(UPat(UOps.CONST, name="x"),)), lambda root,x: UOp.const(root.dtype, x.arg)),
|
||||
# max -2147483648
|
||||
(UOp.max(UOp.var('x'), UOp.const(dtypes.int, -2147483648)), lambda x: x),
|
||||
|
@ -404,8 +404,8 @@ class UOpGraph:
|
|||
while queue:
|
||||
p,x = heapq.heappop(queue)
|
||||
if DEBUG >= 7: print(p,x)
|
||||
if x.op is UOps.DEFINE_ACC and len(x.src):
|
||||
idx = min([self._uops.index(l) for l in x.src])
|
||||
if x.op is UOps.DEFINE_ACC and len(x.src) > 1:
|
||||
idx = min([self._uops.index(l) for l in x.src if l.op is UOps.RANGE])
|
||||
self._uops.insert(idx, x)
|
||||
else:
|
||||
self._uops.append(x)
|
||||
|
@ -460,8 +460,10 @@ class UOpGraph:
|
|||
def type_verify(self):
|
||||
for u in self.uops:
|
||||
uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
|
||||
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
||||
if uop is UOps.DEFINE_ACC: arg = arg[0]
|
||||
if uop in (UOps.CONST, UOps.DEFINE_ACC):
|
||||
if uop is UOps.DEFINE_ACC:
|
||||
assert dtype is not None and src[0].dtype is dtype.scalar(), f"type of {src[0].dtype=} must be a scalar {dtype.scalar()}"
|
||||
arg = src[0].arg
|
||||
assert dtype is not None and type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST}: assert arg is None # type is the output type, not an arg
|
||||
if uop is UOps.LOAD and len(src) > 2 and src[2].op not in {UOps.IF, UOps.BARRIER}: assert src[2].dtype == dtypes.bool
|
||||
|
|
|
@ -164,8 +164,8 @@ class PTXRenderer(Renderer):
|
|||
elif uop is UOps.DEFINE_ACC:
|
||||
if dtype.count > 1:
|
||||
r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)]
|
||||
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(args[0], dtype.scalar())};")
|
||||
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(args[0], dtype)};")
|
||||
for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].arg, dtype.scalar())};")
|
||||
else: kk(f"mov.b{self.types[dtype][1:]} {ssa('acc', u)}, {const(src[0].arg, dtype)};")
|
||||
elif uop is UOps.SPECIAL:
|
||||
assert args[1][0] != "i", "idx not supported"
|
||||
kk(f"mov.u32 %{args[1]}, {(self.gid if args[1][0] == 'g' else self.lid)[args[0]]};")
|
||||
|
|
|
@ -164,7 +164,7 @@ class CStyleLanguage(Renderer):
|
|||
bufs.append((nm:=f"data{args[0]}", (dtype,args[1])))
|
||||
r[u] = nm
|
||||
elif uop is UOps.WMMA: kk(f"{self.render_dtype(dtype)} {ssa('wmma',u)} = __{args[0]}({r[src[0]]}, {r[src[1]]}, {r[src[2]]});")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(args[0], dtype)};")
|
||||
elif uop is UOps.DEFINE_ACC: kk(f"{self.render_dtype(dtype)} {ssa('acc',u)} = {self.render_const(src[0].arg, dtype)};")
|
||||
elif uop is UOps.CONST: r[u] = self.render_const(args, dtype) if args >= 0 else f"({self.render_const(args, dtype)})"
|
||||
elif uop is UOps.GEP:
|
||||
assert src[0].dtype is not None
|
||||
|
|
|
@ -132,7 +132,7 @@ class LLVMRenderer(Renderer):
|
|||
lvars[u].add_incoming(lvars[src[0]], bb[-2].block)
|
||||
loop_blocks.append((bb[-1].block, phis))
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
lvars[u] = const(args[0], dtype)
|
||||
lvars[u] = const(src[0].arg, dtype)
|
||||
reduce_phis.append(u)
|
||||
elif uop is UOps.LOAD:
|
||||
if len(src) > 2:
|
||||
|
|
|
@ -41,7 +41,7 @@ class PythonProgram:
|
|||
while i < len(self.uops):
|
||||
uop, dtype, idp, arg = self.uops[i]
|
||||
void_ops = {UOps.STORE, UOps.ENDRANGE, UOps.BARRIER, UOps.IF, UOps.ENDIF}
|
||||
if uop is UOps.DEFINE_ACC: idp.clear()
|
||||
if uop is UOps.DEFINE_ACC: idp = [idp[0]]
|
||||
inp = [ul[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
dtp = [dl[v] for v in idp if self.uops[v][0] not in void_ops]
|
||||
if getenv("TRACE"): print(i, uop, dtype, arg, inp, dtp)
|
||||
|
@ -90,7 +90,7 @@ class PythonProgram:
|
|||
elif uop is UOps.CONST:
|
||||
ul[i] = [[arg] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg] * warp_size
|
||||
elif uop is UOps.DEFINE_ACC:
|
||||
ul[i] = [[arg[0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [arg[0]] * warp_size
|
||||
ul[i] = [[inp[0][0]] * warp_size for _ in range(dtype.count)] if dtype.count > 1 else [inp[0][0]] * warp_size
|
||||
elif uop is UOps.RANGE:
|
||||
if i not in ul: ul[i] = [inp[0][0]] * warp_size
|
||||
else:
|
||||
|
|
Loading…
Reference in New Issue