Update251203 (#233)

carrot
2025-12-03 10:28:27 +09:00
committed by GitHub
parent d6899edd97
commit c5ebcbcb97
347 changed files with 8678 additions and 13489 deletions


@@ -2,16 +2,16 @@ from typing import Any, cast
import functools, operator, itertools
from collections import defaultdict
from dataclasses import dataclass
-from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
+from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
-from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
-from tinygrad.helpers import getenv, flatten, AMX, prod, partition
+from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
+from tinygrad.helpers import getenv, flatten, AMX, prod
from tinygrad.renderer import Renderer
# ***** image load valid simplification *****
def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
-if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
# wait for it to be image indexed before running simplification
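Note on the hunk above: the fold for a provably-false valid changes target. Instead of collapsing the whole load to `buf.const_like(0)` on the spot, the index is rewritten to the new `Invalid` sentinel (`UOp.invalid()`), and a later pattern decides how to gate or default it. A minimal standalone sketch of the sentinel flow, with all names hypothetical rather than tinygrad's:

```python
# Illustrative sketch of the Invalid-sentinel idea, not tinygrad's real API.
INVALID = object()  # unique sentinel for a statically-dead index

def simplify_index(idx, valid_provably_false: bool):
    # keep an INVALID marker instead of materializing a 0 here, so later
    # passes still see an index op and can choose the gating strategy
    return INVALID if valid_provably_false else idx

def lower_load(idx, alt=0):
    # a later lowering pass: an INVALID index produces the alt value
    return alt if idx is INVALID else ("load", idx)

assert lower_load(simplify_index(5, True)) == 0
assert lower_load(simplify_index(5, False)) == ("load", 5)
```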
@@ -19,13 +19,13 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
# can drop valid if idx is out of bound when valid is False
drop_stmt = []
-for stmt in split_uop(valid, Ops.AND):
+for stmt in valid.split_uop(Ops.AND):
try: X, is_upper_bound, c = parse_valid(stmt)
except ValueError: return None
# for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
-if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
-testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
+if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
+testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
testidx = testidx.simplify()
if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
drop_stmt.append(stmt)
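The `split_uop(valid, Ops.AND)` → `valid.split_uop(Ops.AND)` edits in this hunk (and the one below) are a mechanical migration of the helper from `tinygrad.uop.symbolic` onto `UOp` itself; what it computes is unchanged: walk an associative op chain and yield the leaves. A standalone sketch of that traversal on a toy AST, not tinygrad's `UOp`:

```python
# Sketch of split_uop-style chain splitting on a hypothetical mini-AST.
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str                      # e.g. "AND" or "VAR"
    src: tuple = ()
    def split_uop(self, sep: str):
        # yield the leaves of an associative chain of `sep` ops
        if self.op == sep:
            for s in self.src: yield from s.split_uop(sep)
        else:
            yield self

a, b, c = Node("VAR", ("a",)), Node("VAR", ("b",)), Node("VAR", ("c",))
expr = Node("AND", (Node("AND", (a, b)), c))
assert list(expr.split_uop("AND")) == [a, b, c]
```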
@@ -42,7 +42,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
break
if not drop_stmt and idx is start_idx: return None
-new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
+new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
return buf.index(idx, new_valid)
def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
@@ -51,12 +51,15 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
load_store_indexing = PatternMatcher([
-# simplify valid
-(UPat(Ops.AND, name="valid"), simplify_valid),
# image load valid idx simplification
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
-# index True is just Index
-(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+# lower turn the invalid into a gate, must come before index dtype lowering
+(UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
+# drop true gate
+(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
# remove hanging cast
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
# delete_redundant_gates (after expand)
(UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
@@ -75,14 +78,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
idx: Any = midx.src[i].src[1]
if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
+elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
else: root_src, arg = idx, 0
if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
offsets_rootsrc[root_src].setdefault(arg, []).append(i)
-# the buf.dtype is always a pointer
-ptrdtype = cast(PtrDType, buf.dtype)
# then rewrite everything we can into groups
ret = []
idxs: list[int|None] = [None]*vec.dtype.count
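This hunk and the matching ones below drop the local `ptrdtype = cast(PtrDType, buf.dtype)` in favor of a `buf.ptrdtype` accessor, so the "this dtype is a pointer" narrowing lives in one place instead of a `typing.cast` at every use site. A sketch of the property pattern, with illustrative class names rather than tinygrad's real definitions:

```python
# Sketch of centralizing dtype narrowing in a property (names are stand-ins).
class DType: pass

class PtrDType(DType):
    def __init__(self, base, size, addrspace):
        self.base, self.size, self.addrspace = base, size, addrspace

class UOpLike:
    def __init__(self, dtype: DType): self.dtype = dtype
    @property
    def ptrdtype(self) -> PtrDType:
        # one checked narrowing instead of cast(PtrDType, ...) everywhere
        assert isinstance(self.dtype, PtrDType), f"not a pointer: {self.dtype}"
        return self.dtype

buf = UOpLike(PtrDType("float", 128, "global"))
assert buf.ptrdtype.size == 128
```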
@@ -92,7 +93,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
for grp in grouped_offsets:
# get the index offset for this element. using [0] is okay, because they are the same
lidx = midx.src[offsets[grp[0]][0]]
-if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
# set the idxs of the output
for i,g in enumerate(grp):
for oo in offsets[g]: idxs[oo] = global_offset+i
@@ -101,7 +102,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
-post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
+post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
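For context on what `expand_index` is grouping here: indices that share a root expression and differ only by consecutive constant offsets can be fused into one wider pointer, and the final `gep` restores the original lane order. The grouping itself is the classic consecutive-run trick; a standalone sketch:

```python
# Sketch of grouping consecutive constant offsets, as expand_index does
# before fusing a run into a single vectorized access.
from itertools import groupby

def group_consecutive(offsets):
    ordered = sorted(offsets)
    # within a consecutive run, value minus position is constant
    return [[v for _, v in grp] for _, grp in
            groupby(enumerate(ordered), key=lambda p: p[1] - p[0])]

assert group_consecutive([4, 1, 2, 3, 8]) == [[1, 2, 3, 4], [8]]
```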
def cat_after_store(cat:UOp, data:UOp, sto:UOp):
@@ -154,7 +155,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
-elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
+elif buf.ptrdtype.addrspace == AddrSpace.REG:
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
@@ -169,13 +170,12 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
# split based on the fold lengths
global_offset = 0
ret = []
-ptrdtype = cast(PtrDType, buf.dtype)
while global_offset < sz:
# with 1 at the end of the lengths list, this will always hit
for fold_length in lengths:
if global_offset+fold_length > sz: continue
lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
-if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
global_offset += fold_length
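The loop above is a greedy cover: starting at offset 0, take the largest `fold_length` from `lengths` that still fits within `sz`, emit one (possibly vectorized) access for it, and advance; the trailing 1 in `lengths` guarantees progress. The offset arithmetic in isolation:

```python
# Greedy fold-length cover, mirroring the split_load_store loop above.
def split_offsets(sz, lengths):
    out, global_offset = [], 0
    while global_offset < sz:
        # with 1 at the end of the lengths list, this always hits
        for fold_length in lengths:
            if global_offset + fold_length > sz: continue
            out.append((global_offset, fold_length))
            global_offset += fold_length
            break
    return out

# a 7-wide access split with allowed widths [4, 2, 1]:
assert split_offsets(7, [4, 2, 1]) == [(0, 4), (4, 2), (6, 1)]
```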
@@ -232,17 +232,20 @@ def no_vectorized_alu(alu:UOp):
alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
return UOp(Ops.VECTORIZE, alu.dtype, alus)
-def no_vectorized_acc(acc:UOp, c:UOp):
-if acc.dtype.count == 1: return None
-assert c.arg == 0, "this only supports index 0"
-new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
-return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))
+def no_vectorized_buf(buf:UOp):
+return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)
+def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
+cnt = cast.dtype.count
+assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
+return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
devectorize = PatternMatcher([
# no ALU on vectorized dtypes
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
(UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
-(UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
+(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
+(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
])
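The devectorize matcher scalarizes vector ops: `no_vectorized_alu` splits an ALU into one scalar op per lane via `gep` and re-wraps them in a VECTORIZE, while the new `no_vectorized_index` rescales a width-`cnt` access so lane `j` of vector index `i` lands at scalar slot `i*cnt + j`. The index arithmetic on its own:

```python
# Lane index rescaling as in no_vectorized_index: a cnt-wide vector access
# at vector index i expands to scalar slots i*cnt + lane.
def devectorized_lanes(i, cnt):
    return [i * cnt + lane for lane in range(cnt)]

# a 4-wide access at vector index 3 touches scalar slots 12..15
assert devectorized_lanes(3, 4) == [12, 13, 14, 15]
```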
pm_render = PatternMatcher([
@@ -255,7 +258,8 @@ pm_render = PatternMatcher([
(UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# give any loads that are masked an alt value
(UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
-lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
+lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
+if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
# gate any stores that aren't gated with ifs
(UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
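Two invariants are enforced here before rendering: every masked load carries an explicit alternative value (zero) to produce when its gate is false, and every gated store gets wrapped in an IF so the write is skipped entirely. The widened `in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER)` check treats those src[1] ops as "no alt value attached yet". The load semantics being pinned down, as a sketch:

```python
# The semantics a masked load must have once an alt value is attached:
def masked_load(mem, idx, gate, alt=0):
    # when the gate is false the address is never dereferenced
    return mem[idx] if gate else alt

assert masked_load([10, 20, 30], 1, True) == 20
assert masked_load([10, 20, 30], 999, False) == 0
```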
@@ -293,94 +297,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
-def no_vectorized_reduce(inp:UOp, red:UOp):
-if inp.dtype != red.dtype:
-red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
-if red.dtype.vcount == 1: return red
-# no_vectorize_alu ignoring ranges
-if red.dtype.vcount == 1: return None
-alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
-return UOp(Ops.VECTORIZE, red.dtype, alus)
-def reduce_rangeless(red:UOp):
-# TODO: share code with reduce_unparented
-if red.arg not in {Ops.ADD, Ops.MAX}: return None
-if red.src[0].dtype != red.dtype: return None
-if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
-ret = red.src[0]
-if red.arg is Ops.ADD:
-for r in red.src[1:]:
-ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-return ret
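`reduce_rangeless`, deleted here, relied on a simple identity: if the reduce body contains no RANGE, an ADD-reduce over a loop of length n is the body times n (and a MAX-reduce is just the body, for a non-empty range). In closed form:

```python
# The identity behind reduce_rangeless: a loop-invariant ADD reduce is a mul.
def sum_invariant(c, n):
    return c * n                       # == sum(c for _ in range(n))

assert sum_invariant(2.5, 8) == sum(2.5 for _ in range(8))
```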
-def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
-pm_reduce_collapse = PatternMatcher([
-# lift x+y out of reduce on lt
-((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
-# lift x*y out of reduce
-((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
-lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
-# lift x+y out of reduce on ne
-((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
-# fold the range
-((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
-lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
-((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
-lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
-# REDUCE on ADD
-((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
-lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
-# MUL casted bool
-((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
-lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
-# WHERE on LOAD (works on max too)
-(UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
-lambda buf,idx,gate: buf.index(idx, gate).load()),
-(UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
-lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
-# INDEX on RANGE / gated RANGE
-(UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
-lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
-# AND on WHERE
-((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
-.where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
-lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
-# remove REDUCEs that no longer have a RANGE in the src
-(UPat(Ops.REDUCE, name="red"), reduce_rangeless),
-# devectorize REDUCE
-(UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
-# index/load/where. TODO: this is more aggressive than needed
-(UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
-])+sym
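The heart of the deleted `pm_reduce_collapse` is the pair of range-fold rules above: a sum over `range(n)` of a constant gated by `i < cut` closes to a clamped count times the constant, and the mirrored `where(0, val)` case counts the iterations with `i >= cut`. Numerically:

```python
# The closed forms used by the two range-fold rules in pm_reduce_collapse.
def clamp(x, lo, hi): return max(lo, min(x, hi))

def fold_lt(n, cut, val):   # sum(val if i < cut else 0 for i in range(n))
    return clamp(cut, 0, n) * val

def fold_ge(n, cut, val):   # sum(0 if i < cut else val for i in range(n))
    return clamp(n - cut, 0, n) * val

for n, cut in [(10, 3), (10, -5), (10, 99)]:
    assert fold_lt(n, cut, 2.0) == sum(2.0 if i < cut else 0 for i in range(n))
    assert fold_ge(n, cut, 2.0) == sum(0 if i < cut else 2.0 for i in range(n))
```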
-def reduce_collapse(red:UOp):
-included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
-if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
-replaces: dict[UOp, UOp] = {}
-for u in included:
-for s in u.src:
-if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
-replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
-collapse_fxn = red.substitute(replaces)
-sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
-if any(x.op is Ops.RANGE for x in sink.toposort()): return None
-return sink.substitute({v:k for k,v in replaces.items()})
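`reduce_collapse`, also removed in this hunk, tried to eliminate a loop symbolically: it abstracted every input feeding the reduce into a bounded DEFINE_VAR, ran `pm_reduce_collapse` on the abstracted graph, and accepted the result only if no RANGE survived, substituting the real inputs back at the end. The same try-and-verify shape, sketched on strings:

```python
# Toy version of reduce_collapse's substitute / rewrite / verify dance.
def try_collapse(expr, inputs, rewrite):
    for real, var in inputs.items():    # abstract inputs into variables
        expr = expr.replace(real, var)
    closed = rewrite(expr)              # attempt a closed form
    if "RANGE" in closed: return None   # the loop survived: give up
    for real, var in inputs.items():    # put the real inputs back
        closed = closed.replace(var, real)
    return closed

# a rewriter that only knows: summing a constant over a range is a multiply
collapse = lambda s: s.replace("sum(in0 for _ in RANGE(n))", "in0 * n")
assert try_collapse("sum(X for _ in RANGE(n))", {"X": "in0"}, collapse) == "X * n"
```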
-def reduce_unparented(red:UOp):
-if red.arg not in {Ops.ADD, Ops.MAX}: return None
-reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
-if len(reduce_unparented) == 0: return None
-ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
-if red.arg is Ops.ADD:
-for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-return ret
pm_reduce = PatternMatcher([
-# remove any ranges from a REDUCE that aren't referenced in the reduce source
-(UPat(Ops.REDUCE, name="red"), reduce_unparented),
-# remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
-(UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
# REDUCE -> DEFINE_ACC+ASSIGN
(UPat(Ops.REDUCE, name="red"), reduce_to_acc),
# tensor core built in accumulate