Mirror of https://github.com/ajouatom/openpilot.git
Synced 2026-03-02 02:53:55 +08:00

Update251203 (#233)
@@ -2,16 +2,16 @@ from typing import Any, cast
 import functools, operator, itertools
 from collections import defaultdict
 from dataclasses import dataclass
-from tinygrad.dtype import dtypes, ImageDType, PtrDType, DType, AddrSpace
+from tinygrad.dtype import dtypes, ImageDType, DType, AddrSpace, Invalid
 from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
-from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
-from tinygrad.helpers import getenv, flatten, AMX, prod, partition
+from tinygrad.uop.symbolic import uop_given_valid, parse_valid, sym, symbolic_flat, invalid_gate
+from tinygrad.helpers import getenv, flatten, AMX, prod
 from tinygrad.renderer import Renderer
 
 # ***** image load valid simplification *****
 
 def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
-  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.const_like(0)
+  if (idx:=uop_given_valid(valid, start_idx)) is None: return buf.index(UOp.invalid())
   if not isinstance(buf.dtype, ImageDType): return None if idx is start_idx else buf.index(idx, valid)
 
   # wait for it to be image indexed before running simplification
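
Note: the import swap above anticipates the two refactors that run through the rest of this diff: PtrDType casts give way to a ptrdtype accessor on the UOp, and a load whose valid is provably false now returns an explicitly invalid index (via the new Invalid sentinel and invalid_gate pattern) instead of a literal constant 0. A toy model of the sentinel idea, in plain Python with hypothetical names (this is not tinygrad's API):

# Toy model of an "invalid index" sentinel (hypothetical names).
# Instead of loading a fake constant 0 when the guard is provably false, the
# index is tagged invalid so a later pass can lower it into an explicit gate.
from dataclasses import dataclass

INVALID = object()  # stands in for tinygrad's Invalid

@dataclass(frozen=True)
class Index:
    offset: object        # an int, or INVALID
    gate: bool = True     # guard under which the access is defined

def simplify(idx: Index, valid: bool) -> Index:
    # mirrors: if uop_given_valid(...) is None: return buf.index(UOp.invalid())
    return idx if valid else Index(INVALID, False)

def lower_load(idx: Index, memory=(10, 20, 30), alt=0):
    # "turn the invalid into a gate": a gated access with an alt value
    return memory[idx.offset] if idx.offset is not INVALID and idx.gate else alt

print(lower_load(simplify(Index(1), valid=True)))   # 20
print(lower_load(simplify(Index(1), valid=False)))  # 0, via the gate, not a fake load
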
@@ -19,13 +19,13 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
 
   # can drop valid if idx is out of bound when valid is False
   drop_stmt = []
-  for stmt in split_uop(valid, Ops.AND):
+  for stmt in valid.split_uop(Ops.AND):
     try: X, is_upper_bound, c = parse_valid(stmt)
     except ValueError: return None
 
     # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i
-    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)):
-      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx)
+    if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in X.split_uop(Ops.ADD)):
+      testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), X.split_uop(Ops.ADD), idx)
       testidx = testidx.simplify()
       if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0:
         drop_stmt.append(stmt)
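
Note: this hunk is the free-function-to-method migration: split_uop(valid, Ops.AND) becomes valid.split_uop(Ops.AND), with the same semantics of walking every conjunct of a chained AND. A minimal stand-in with a toy expression type (hypothetical, for illustration only):

# Toy split_uop: yield the conjuncts of a chained AND, mirroring the
# method form used by the new code (names hypothetical).
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
    op: str              # "AND", "VAR", ...
    src: tuple = ()
    name: str = ""

    def split_uop(self, sep: str):
        # method form, as in the new valid.split_uop(Ops.AND)
        if self.op == sep:
            for s in self.src: yield from s.split_uop(sep)
        else:
            yield self

a, b, c = (Node("VAR", name=n) for n in "abc")
cond = Node("AND", (Node("AND", (a, b)), c))
print([n.name for n in cond.split_uop("AND")])  # ['a', 'b', 'c']
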
@@ -42,7 +42,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> UOp|None:
         break
 
   if not drop_stmt and idx is start_idx: return None
-  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None
+  new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in valid.split_uop(Ops.AND) if s not in drop_stmt]) else None
   return buf.index(idx, new_valid)
 
 def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp, cast:UOp|None=None) -> UOp|None:
@@ -51,12 +51,15 @@ def delete_redundant_gates(store:UOp, buf:UOp, idx:UOp, val:UOp, store_gate:UOp,
   return UOp.store(buf.index(idx).cast(cast.dtype) if cast is not None else buf.index(idx), val, *store.src[2:])
 
 load_store_indexing = PatternMatcher([
-  # simplify valid
-  (UPat(Ops.AND, name="valid"), simplify_valid),
   # image load valid idx simplification
   (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat.var("valid"))), simplify_valid_load),
-  # index True is just Index
-  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("start_idx"), UPat(Ops.CONST, arg=True))), lambda buf,start_idx: buf.index(start_idx)),
+  # lower turn the invalid into a gate, must come before index dtype lowering
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), invalid_gate,),), lambda buf,x,cond,i: buf.index(x, cond)),
+  # drop true gate
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("x"), UPat.const(dtypes.bool, True)),), lambda buf,x: buf.index(x)),
+  # remove hanging cast
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast()),), lambda buf,idx: buf.index(idx)),
+  (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx", dtypes.int).cast(), UPat.var("valid"))), lambda buf,idx,valid: buf.index(idx, valid)),
   # delete_redundant_gates (after expand)
   (UPat(Ops.STORE, src=(UPat.any(stidx:=UPat.var("buf").index(UPat.var("idx"), UPat.var("store_gate")), stidx.cast().named("cast")),
                         UPat.var("val")), name="store", allow_any_len=True), delete_redundant_gates),
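
Note: the matcher now normalizes gated INDEX ops in stages: lower an Invalid index into a real gate, drop gates that are constant True, and peel casts off integer indices. As the invalid_gate rule's comment says, ordering matters. A toy rewrite loop showing why a constant-True gate can simply be dropped (hypothetical mini-matcher, not tinygrad's PatternMatcher):

# Toy rewriter: each rule maps a (buf, idx, gate) "INDEX" tuple to a simpler
# one, mimicking the drop-true-gate rule above (hypothetical shapes).
def drop_true_gate(node):
    buf, idx, gate = node
    return (buf, idx, None) if gate is True else None  # None = no match

def apply_rules(node, rules):
    changed = True
    while changed:
        changed = False
        for rule in rules:
            if (out := rule(node)) is not None:
                node, changed = out, True
    return node

print(apply_rules(("buf", 4, True), [drop_true_gate]))  # ('buf', 4, None)
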
@@ -75,14 +78,12 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     idx: Any = midx.src[i].src[1]
     if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg
     elif idx.op is Ops.ADD and idx.src[0].op is Ops.CONST: root_src, arg = idx.src[1], idx.src[0].arg
+    elif idx.op is Ops.CONST and idx.arg is Invalid: root_src, arg = "INVALID", 0
     elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg
     else: root_src, arg = idx, 0
     if len(midx.src[i].src) == 3: root_src = (midx.src[i].src[2], root_src)
     offsets_rootsrc[root_src].setdefault(arg, []).append(i)
 
-  # the buf.dtype is always a pointer
-  ptrdtype = cast(PtrDType, buf.dtype)
-
   # then rewrite everything we can into groups
   ret = []
   idxs: list[int|None] = [None]*vec.dtype.count
@@ -92,7 +93,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     for grp in grouped_offsets:
       # get the index offset for this element. using [0] is okay, because they are the same
       lidx = midx.src[offsets[grp[0]][0]]
-      if len(grp) > 1: lidx = lidx.cast(ptrdtype.base.vec(len(grp)).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+      if len(grp) > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(len(grp)).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
       # set the idxs of the output
       for i,g in enumerate(grp):
         for oo in offsets[g]: idxs[oo] = global_offset+i
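
Note: expand_index buckets each lane by its root index expression and constant offset, then loads a run of lanes as one vectorized pointer (the .vec(len(grp)) cast above). A toy version of the contiguous-run grouping this appears to rely on (assumption: grouped_offsets holds maximal consecutive-offset runs; that helper is not shown in this diff):

# Toy grouping: split constant offsets into maximal consecutive runs,
# the shape that can be serviced by one vec(len(grp)) access (illustrative).
def group_consecutive(offsets: list[int]) -> list[list[int]]:
    groups: list[list[int]] = []
    for o in sorted(offsets):
        if groups and o == groups[-1][-1] + 1: groups[-1].append(o)
        else: groups.append([o])
    return groups

print(group_consecutive([0, 1, 2, 7, 8, 4]))  # [[0, 1, 2], [4], [7, 8]]
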
@@ -101,7 +102,7 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     global_offset += len(grp)
   assert None not in idxs, f"some idxs are missing {idxs}"
   # this base thing is for image, we want the CAT to be a normal pointer
-  post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
+  post_cat = UOp(Ops.PTRCAT, buf.ptrdtype.base.ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace).vec(vec.dtype.count), tuple(ret))
   return post_cat.gep(tuple(cast(list[int], idxs)))
 
 def cat_after_store(cat:UOp, data:UOp, sto:UOp):
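
Note: every cast(PtrDType, buf.dtype) in this file becomes buf.ptrdtype, presumably a property on UOp that asserts the dtype really is a pointer and returns it already typed; its actual definition lives outside this diff, so the following is an assumed sketch of the pattern, not the real accessor:

# Sketch of a ptrdtype-style accessor (assumed). It centralizes the type
# check that cast(PtrDType, buf.dtype) used to sidestep at every call site.
class PtrDType:
    def __init__(self, base, size, addrspace):
        self.base, self.size, self.addrspace = base, size, addrspace

class ToyUOp:
    def __init__(self, dtype): self.dtype = dtype
    @property
    def ptrdtype(self) -> PtrDType:
        assert isinstance(self.dtype, PtrDType), f"not a pointer: {self.dtype}"
        return self.dtype

buf = ToyUOp(PtrDType(base="float", size=16, addrspace="GLOBAL"))
print(buf.ptrdtype.addrspace)  # GLOBAL
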
@@ -154,7 +155,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
     must_divide = False
   elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
     pass
-  elif cast(PtrDType, buf.dtype).addrspace == AddrSpace.REG:
+  elif buf.ptrdtype.addrspace == AddrSpace.REG:
     pass
   elif isinstance(buf.dtype, ImageDType):
     lengths = [4]
@@ -169,13 +170,12 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   # split based on the fold lengths
   global_offset = 0
   ret = []
-  ptrdtype = cast(PtrDType, buf.dtype)
   while global_offset < sz:
     # with 1 at the end of the lengths list, this will always hit
     for fold_length in lengths:
       if global_offset+fold_length > sz: continue
       lidx = buf.index(idx.src[1] + global_offset, idx.src[2] if len(idx.src) > 2 else None)
-      if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, addrspace=ptrdtype.addrspace))
+      if fold_length > 1: lidx = lidx.cast(buf.ptrdtype.base.vec(fold_length).ptr(size=buf.ptrdtype.size, addrspace=buf.ptrdtype.addrspace))
       if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
       else: ret.append(ls.replace(src=(lidx,)+ls.src[1:], dtype=ls.dtype.scalar().vec(fold_length)))
       global_offset += fold_length
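
Note: the loop greedily tiles a length-sz access with the largest fold length that still fits; with 1 terminating the lengths list, every offset is eventually covered, which is what the "this will always hit" comment promises. The tiling logic in isolation (toy, same greedy rule):

# Greedy fold-length tiling as in split_load_store: cover [0, sz) with the
# largest allowed vector width at each step (a trailing 1 guarantees progress).
def tile(sz: int, lengths: list[int]) -> list[tuple[int, int]]:
    out, global_offset = [], 0
    while global_offset < sz:
        for fold_length in lengths:
            if global_offset + fold_length > sz: continue
            out.append((global_offset, fold_length))
            global_offset += fold_length
            break
    return out

print(tile(7, [4, 2, 1]))  # [(0, 4), (4, 2), (6, 1)]
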
@@ -232,17 +232,20 @@ def no_vectorized_alu(alu:UOp):
   alus = tuple(UOp(alu.op, alu.dtype.scalar(), tuple(s.gep(i) for s in alu.src), alu.arg) for i in range(alu.dtype.vcount))
   return UOp(Ops.VECTORIZE, alu.dtype, alus)
 
-def no_vectorized_acc(acc:UOp, c:UOp):
-  if acc.dtype.count == 1: return None
-  assert c.arg == 0, "this only supports index 0"
-  new_acc = acc.replace(dtype=acc.dtype.base.scalar().ptr(acc.dtype.count, cast(PtrDType, acc.dtype).addrspace))
-  return UOp(Ops.PTRCAT, acc.dtype, tuple([new_acc.index(UOp.const(dtypes.int, i)) for i in range(acc.dtype.count)]))
+def no_vectorized_buf(buf:UOp):
+  return buf.replace(dtype=buf.ptrdtype.base.scalar().ptr(buf.ptrdtype.size*buf.ptrdtype.count, buf.ptrdtype.addrspace)).cast(buf.dtype)
+
+def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
+  cnt = cast.dtype.count
+  assert idx.dtype.count == 1, f"idx dtype must be 1 {idx.dtype}"
+  return buf.broadcast(cnt).index(idx.broadcast(cnt)*cnt+UOp.const(dtypes.int.vec(cnt), tuple(range(cnt))))
 
 devectorize = PatternMatcher([
   # no ALU on vectorized dtypes
   (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name="alu"), no_vectorized_alu),
   (UPat(Ops.WMMA, name="wmma"), no_vectorized_wmma),
-  (UPat(Ops.DEFINE_REG, name="acc").index(UPat.cvar("c")), no_vectorized_acc),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
+  (UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf").cast(name="cast").index(UPat.var("idx")), no_vectorized_index),
 ])
 
 pm_render = PatternMatcher([
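
Note: the new no_vectorized_index rewrites an index into a vec-cnt local/register buffer as a broadcast base index scaled by cnt plus a 0..cnt-1 offset vector, so each lane addresses its own scalar slot of the devectorized buffer. The address arithmetic by itself (toy):

# The lane-address arithmetic from no_vectorized_index: a scalar index idx
# into a vec(cnt) buffer becomes per-lane scalar addresses idx*cnt + (0..cnt-1).
def lane_addresses(idx: int, cnt: int) -> list[int]:
    return [idx * cnt + i for i in range(cnt)]

print(lane_addresses(3, 4))  # [12, 13, 14, 15] -- the lanes of element 3 in a float4 buffer
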
@@ -255,7 +258,8 @@ pm_render = PatternMatcher([
   (UPat(Ops.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
   # give any loads that are masked an alt value
   (UPat(Ops.LOAD, src=(UPat(Ops.INDEX, src=(UPat(), UPat(), UPat())).or_casted(),), allow_any_len=True, name="x"),
-   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:]) if len(x.src) == 1 or x.src[1].op is Ops.CUSTOM else None),
+   lambda x: x.replace(src=(x.src[0], x.const_like(0))+x.src[1:])
+    if len(x.src) == 1 or x.src[1].op in (Ops.CUSTOM, Ops.STORE, Ops.BARRIER) else None),
   # gate any stores that aren't gated with ifs
   (UPat(Ops.STORE, src=(UPat(src=(UPat(), UPat(), UPat(dtype=dtypes.bool)), name="idx").or_casted(), UPat()), name="store", allow_any_len=True),
    lambda store,idx: UOp(Ops.STORE, dtype=store.dtype, src=store.src[:2]+(UOp(Ops.IF, src=(idx.src[2],)),)+store.src[2:]) if \
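
Note: this rule gives masked loads an explicit alternative value of zero, and the change also applies it when the load's second source is a STORE or BARRIER rather than only CUSTOM. The semantics of the alt value, in miniature:

# A masked load with an alt value: lanes whose gate is False read alt (here 0)
# instead of touching memory, which is what the added const_like(0) encodes.
def masked_load(memory: list[int], idxs: list[int], gates: list[bool], alt: int = 0) -> list[int]:
    return [memory[i] if g else alt for i, g in zip(idxs, gates)]

print(masked_load([10, 20, 30, 40], [0, 2, 3], [True, False, True]))  # [10, 0, 40]
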
@@ -293,94 +297,7 @@ def reduce_to_acc(ctx:ReduceContext, red:UOp):
   ret = functools.reduce(lambda x,y: x.alu(red.arg, y), lst)
   return acc.load(acc.store(ret, *reduce_range)) if len(reduce_range) != 0 else ret
 
 def no_vectorized_reduce(inp:UOp, red:UOp):
   if inp.dtype != red.dtype:
     red = red.replace(src=(functools.reduce(lambda x,y: x.alu(red.arg, y), horizontal_reduce(inp, red.dtype)),)+red.src[1:])
+    if red.dtype.vcount == 1: return red
-  # no_vectorize_alu ignoring ranges
-  if red.dtype.vcount == 1: return None
-  alus = tuple(UOp(red.op, red.dtype.scalar(), (red.src[0].gep(i),)+red.src[1:], red.arg) for i in range(red.dtype.vcount))
-  return UOp(Ops.VECTORIZE, red.dtype, alus)
-
-def reduce_rangeless(red:UOp):
-  # TODO: share code with reduce_unparented
-  if red.arg not in {Ops.ADD, Ops.MAX}: return None
-  if red.src[0].dtype != red.dtype: return None
-  if any(x.op in {Ops.RANGE} for x in red.src[0].toposort()): return None
-  ret = red.src[0]
-  if red.arg is Ops.ADD:
-    for r in red.src[1:]:
-      ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-  return ret
-
-def no_range(u:UOp) -> bool: return not any(x.op is Ops.RANGE for x in u.sparents)
-
-pm_reduce_collapse = PatternMatcher([
-  # lift x+y out of reduce on lt
-  ((UPat.var("x")+UPat.var("y")) < UPat.var("c"), lambda x,y,c: (x < (c-y)) if no_range(y) and no_range(c) else None),
-  # lift x*y out of reduce
-  ((UPat.var("x")*UPat.var("y")) < UPat.var("c"),
-   lambda x,y,c: (x < ((c+y-1) // y)) if no_range(y) and no_range(c) and y.vmin > 0 else None),
-  # lift x+y out of reduce on ne
-  ((UPat.var("x")+UPat.var("y")) != UPat.var("c"), lambda x,y,c: (x != (c-y)) if no_range(y) and no_range(c) else None),
-  # fold the range
-  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(0, UPat.cvar("val")).reduce(arg=Ops.ADD, allow_any_len=True),
-   lambda r,cut,val: (r.src[0]-cut).maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
-  ((UPat(Ops.RANGE, name="r") < UPat.var("cut")).where(UPat.cvar("val"), 0).reduce(arg=Ops.ADD, allow_any_len=True),
-   lambda r,cut,val: cut.maximum(0).minimum(r.src[0]).cast(val.dtype) * val),
-  # REDUCE on ADD
-  ((UPat.var("x")+UPat.var("y")).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
-   lambda x,y,r: x.reduce(*r.src[1:], arg=Ops.ADD) + y.reduce(*r.src[1:],arg=Ops.ADD)),
-  # MUL casted bool
-  ((UPat.var("x") * UPat.var("gate", dtype=dtypes.bool).cast().or_broadcasted(name="b")),
-   lambda x,gate,b=None: gate.broadcast(x.dtype.count).where(x, 0) if b is not None else gate.where(x, 0)),
-  # WHERE on LOAD (works on max too)
-  (UPat.var("gate").where(UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load(), 0).reduce(arg=Ops.ADD, allow_any_len=True),
-   lambda buf,idx,gate: buf.index(idx, gate).load()),
-  (UPat.var("gate").where(0, UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var("idx"))).load()).reduce(arg=Ops.ADD, allow_any_len=True),
-   lambda buf,idx,gate: buf.index(idx, gate.logical_not()).load()),
-  # INDEX on RANGE / gated RANGE
-  (UPat.var("buf").index(UPat.var("expr"), UPat.var("idx").eq(UPat(Ops.RANGE, name="r").or_casted())),
-   lambda buf,r,idx,expr: buf.index(expr.substitute({r:idx.cast(r.dtype)}), (idx.cast(r.dtype) >= 0) & (idx.cast(r.dtype) < r.src[0]))),
-  # AND on WHERE
-  ((UPat.any(UPat(Ops.DEFINE_VAR, name="x"), UPat(Ops.DEFINE_VAR).gep(name="x")) & UPat.var("y")) \
-    .where(UPat.cvar("c"), 0).reduce(arg=Ops.ADD, allow_any_len=True, name="r"),
-   lambda x,y,c,r: y.where(c, 0).reduce(*r.src[1:], arg=Ops.ADD)*x.cast(c.dtype)),
-  # remove REDUCEs that no longer have a RANGE in the src
-  (UPat(Ops.REDUCE, name="red"), reduce_rangeless),
-  # devectorize REDUCE
-  (UPat(Ops.VECTORIZE, name="inp").reduce(name="red", allow_any_len=True), no_vectorized_reduce),
-  # index/load/where. TODO: this is more aggressive than needed
-  (UPat((Ops.INDEX, Ops.LOAD, Ops.WHERE), name="alu"), no_vectorized_alu),
-])+sym
-
-def reduce_collapse(red:UOp):
-  included, not_included = partition(red.parents, lambda x: any(y in x.sparents for y in red.src[1:]))
-  if any(x.op in {Ops.STORE, Ops.REDUCE} for x in included): return None
-  replaces: dict[UOp, UOp] = {}
-  for u in included:
-    for s in u.src:
-      if s in not_included and s not in replaces and s.op not in {Ops.CONST, Ops.VCONST, Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_VAR}:
-        replaces[s] = UOp(Ops.DEFINE_VAR, dtype=s.dtype, arg=(f'in{len(replaces)}', s.vmin, s.vmax))
-  collapse_fxn = red.substitute(replaces)
-  sink = graph_rewrite(collapse_fxn, pm_reduce_collapse, name="reduce_collapse")
-  if any(x.op is Ops.RANGE for x in sink.toposort()): return None
-  return sink.substitute({v:k for k,v in replaces.items()})
-
-def reduce_unparented(red:UOp):
-  if red.arg not in {Ops.ADD, Ops.MAX}: return None
-  reduce_parented, reduce_unparented = partition(red.src[1:], lambda x: x in red.src[0].sparents)
-  if len(reduce_unparented) == 0: return None
-  ret = red.replace(src=(red.src[0],)+tuple(reduce_parented)) if len(reduce_parented) or red.dtype != red.src[0].dtype else red.src[0]
-  if red.arg is Ops.ADD:
-    for r in reduce_unparented: ret = ret * r.src[0].cast(ret.dtype.scalar()).broadcast(ret.dtype.count)
-  return ret
-
-pm_reduce = PatternMatcher([
-  # remove any ranges from a REDUCE that aren't referenced in the reduce source
-  (UPat(Ops.REDUCE, name="red"), reduce_unparented),
-  # remove REDUCE without loads (generic arange opt / indexing). TODO: support multi range
-  (UPat(Ops.REDUCE, src=(UPat(), UPat()), name="red"), reduce_collapse),
-  # REDUCE -> DEFINE_ACC+ASSIGN
-  (UPat(Ops.REDUCE, name="red"), reduce_to_acc),
-  # tensor core built in accumulate
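
Note: among the helpers this hunk removes from the file, reduce_unparented folded an ADD reduce over a range its source never reads into a multiply by the trip count (the ret * r.src[0] line above). The identity it relied on, checked in plain Python:

# The identity behind the removed reduce_unparented: if x does not depend on
# the reduction range r, then sum(x for _ in range(n)) == x * n.
def reduce_unparented_add(x: float, n: int) -> float:
    return x * n

n, x = 7, 3.5
assert sum(x for _ in range(n)) == reduce_unparented_add(x, n)
print(reduce_unparented_add(x, n))  # 24.5
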