mirror of https://github.com/commaai/tinygrad.git
replace the lowerer with a contextual PatternMatcher [run_process_replay] (#6646)
* replace the lowerer with a contextual PatternMatcher [run_process_replay]
* todo
* it's REDUCE by the time it's in lowerer
parent 4751159139
commit 84703d5b77
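The mechanism behind the commit, sketched before the diff below: rewrite rules become plain functions that take a shared context object as their first argument, and graph_rewrite threads that context through PatternMatcher.rewrite, so the lowerer's per-kernel indexing state (idxs/ridxs on IndependentLowerer) is passed as ctx=self instead of being consumed by recursive to_uop methods. The following is a minimal standalone illustration of that calling convention only; ToyMatcher, LowererState, and lower_load are hypothetical stand-ins, not tinygrad's real classes.

# toy sketch of a "contextual" pattern matcher: rules receive ctx as their first argument
from typing import Any, Callable, List, Optional, Tuple

class ToyMatcher:
  def __init__(self, rules: List[Tuple[str, Callable]]):
    self.rules = rules  # each rule: (op name to match, fxn called as fxn(ctx, node))
  def rewrite(self, node: dict, ctx: Any = None) -> Optional[dict]:
    for op, fxn in self.rules:
      if node["op"] == op:
        # mirror the patched rewrite(): pass ctx only when one was supplied
        ret = fxn(ctx, node) if ctx is not None else fxn(node)
        if ret is not None: return ret
    return None

class LowererState:  # stand-in for the per-kernel state (like IndependentLowerer's idxs)
  def __init__(self, idxs): self.idxs = idxs

def lower_load(ctx: LowererState, node: dict) -> dict:
  # the rule reads shared state from ctx instead of from a method's self
  return {"op": "INDEXED_LOAD", "idx": ctx.idxs[node["axis"]]}

pm = ToyMatcher([("LOAD", lower_load)])
print(pm.rewrite({"op": "LOAD", "axis": 0}, ctx=LowererState(idxs=["gidx0"])))  # {'op': 'INDEXED_LOAD', 'idx': 'gidx0'}

In the diff itself the same shape appears as PatternMatcher.rewrite(uop, ctx=None), RewriteContext storing ctx, graph_rewrite(sink, pm, ctx=None), and pm_lowerer rules such as lower_load_store(ctx, x) reading ctx.idxs and ctx.ridxs.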
@@ -1,10 +1,11 @@
 # the job of the lowerer is to do indexing
 from __future__ import annotations
 import functools
-from typing import List, Tuple, cast, Optional, Dict
+from typing import List, Tuple, cast, Optional
 from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
 from tinygrad.shape.symbolic import sint
 from tinygrad.dtype import dtypes
-from tinygrad.ops import KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
+from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
+
@@ -34,6 +35,53 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
           idx //= dims[c]
   return ret[::-1] if reverse else ret
 
+# TODO: move this to kernel.py, it doesn't depend on axes
+def lower_wmma(ctx: IndependentLowerer, x: UOp):
+  upcast_axes = x.arg[-2]
+  wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
+  ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
+    UOp(UOps.CONTRACT, dtype=x.src[0].dtype.vec(wmma_sz[0]), src=(x.src[0],), arg=upcast_axes[0]),
+    UOp(UOps.CONTRACT, dtype=x.src[1].dtype.vec(wmma_sz[1]), src=(x.src[1],), arg=upcast_axes[1]),
+    UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
+  return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
+
+def lower_reduce_axis(ctx: IndependentLowerer, x: UOp):
+  # NOTE: always using ridxs is fine here
+  reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
+  alu_op: BinaryOps = x.arg[0]
+  ret = x.src[0]
+  if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
+    ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
+    ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
+  return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
+
+def lower_load_store(ctx: IndependentLowerer, x: UOp):
+  idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
+  # TODO: check has_valid in UPat, not here
+  has_valid = valid.op is not UOps.CONST or valid.arg is not True
+  buf = x.src[0]
+  if x.op is UOps.LOAD:
+    barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
+    return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
+  # NOTE: only store the local reduceop in the threads that are actually doing the reduce
+  store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
+    x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
+  # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
+  if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
+  if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
+    for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
+      if oidx != ridx: valid = valid * oidx.eq(0)
+    has_valid = valid.op is not UOps.CONST or valid.arg is not True
+  return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
+
+pm_lowerer = PatternMatcher([
+  (UPat(UOps.WMMA, src=(UPat(), UPat()), name="x"), lower_wmma), # 2 param -> 3 param WMMA
+  (UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
+  (UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
+  # rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed
+  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.SHAPETRACKER)), allow_any_len=True, name="x"), lower_load_store),
+])
+
 class IndependentLowerer:
   def lower(self, ast:UOp, opts:Renderer) -> UOp:
     self.output_count = len(ast.src)
@@ -79,54 +127,7 @@ class IndependentLowerer:
     for a in range(first_reduce, first_reduce+group_for_reduces):
       self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
 
-    self.uop_cache: Dict[UOp, UOp] = {}
-    return self.to_uop(ast)
-
-  def to_uop(self, x:UOp) -> UOp:
-    if uop:=self.uop_cache.get(x, None): return uop
-    ret = self._to_uop(x)
-    self.uop_cache[x] = ret
-    return ret
-
-  def _to_uop(self, x:UOp) -> UOp:
-    if x.op in BUFFER_UOPS:
-      idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
-      if x.op is UOps.VALID: return valid
-      # TODO: check has_valid in UPat, not here
-      has_valid = valid.op is not UOps.CONST or valid.arg is not True
-      buf = x.src[0]
-      if x.op is UOps.LOAD:
-        barrier = (UOp(UOps.BARRIER, dtypes.void, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
-        return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
-      # NOTE: only store the local reduceop in the threads that are actually doing the reduce
-      store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
-        x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
-      # NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
-      if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
-      if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
-        for oidx, ridx in zip(self.idxs, self.ridxs):
-          if oidx != ridx: valid = valid * oidx.eq(0)
-        has_valid = valid.op is not UOps.CONST or valid.arg is not True
-      return UOp(UOps.STORE, dtypes.void, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
-
-    in_uops = tuple(self.to_uop(y) for y in x.src)
-    if x.op is UOps.WMMA:
-      upcast_axes = x.arg[-2]
-      wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
-      ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
-        UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
-        UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
-        UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
-      return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
-    if x.op is UOps.REDUCE_AXIS:
-      # NOTE: always using ridxs is fine here
-      reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
-      alu_op: BinaryOps = x.arg[0]
-      ret = in_uops[0]
-      if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
-        ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
-        ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
-      return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
-    return x if x.src == in_uops else UOp(x.op, x.dtype, in_uops, x.arg)
+    # rewrite to add the index
+    return graph_rewrite(ast, pm_lowerer, ctx=self)
 
 def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
@@ -486,11 +486,11 @@ class PatternMatcher:
   @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
   def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
 
-  def rewrite(self, uop:UOp) -> Optional[UOp]:
+  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
     ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
     for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
       if not early_reject.issubset(ler): continue
-      if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
+      if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
     return None
 
 # *** tracking pattern matcher ***
@@ -510,7 +510,7 @@ class TrackedPatternMatcher(PatternMatcher):
     for p,_ in self.patterns:
       if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
 
-  def rewrite(self, uop:UOp) -> Optional[UOp]:
+  def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
     ret = None
     ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
     for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
@@ -519,7 +519,7 @@ class TrackedPatternMatcher(PatternMatcher):
         match_stats[p][2] += time.perf_counter()-st
         continue
       match_stats[p][1] += 1
-      if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None:
+      if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None:
         match_stats[p][0] += 1
         match_stats[p][2] += (et:=time.perf_counter()-st)
         match_stats[p][3] += et
@@ -551,8 +551,9 @@ if TRACK_MATCH_STATS:
 # *** simple graph rewrite engine ***
 
 class RewriteContext:
-  def __init__(self, pm):
+  def __init__(self, pm, ctx):
     self.pm: PatternMatcher = pm
+    self.ctx = ctx
     self.nodes: Dict[Tuple, UOp] = {}
     self.replace: Dict[UOp, UOp] = {}
   def rewrite(self, n:UOp) -> UOp:
@@ -561,12 +562,12 @@ class RewriteContext:
     if found := self.nodes.get(replace_source): self.replace[n] = found
     else:
       x = UOp(*replace_source) if new_src != n.src else n
-      self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
+      self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x
     return found
-def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
+def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
   if TRACK_MATCH_STATS >= 2:
     contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
-  return RewriteContext(pm).rewrite(sink)
+  return RewriteContext(pm, ctx).rewrite(sink)
 
 # ***** uop type spec *****
 