replace the lowerer with a contextual PatternMatcher [run_process_replay] (#6646)

* replace the lowerer with a contextual PatternMatcher [run_process_replay]

* todo

* it's REDUCE, not REDUCE_AXIS, by the time it's in the lowerer
George Hotz 2024-09-22 13:22:26 +08:00 committed by GitHub
parent 4751159139
commit 84703d5b77
2 changed files with 61 additions and 59 deletions
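
In plain terms, this commit turns the lowerer from a stateful class that walks the AST itself (to_uop/_to_uop plus a uop_cache) into a set of (UPat, function) rules, with graph_rewrite threading a context object through to each matched rule. A minimal self-contained sketch of that mechanism, using toy stand-ins rather than tinygrad's real classes:

from typing import Any, Optional

class ToyMatcher:
  def __init__(self, patterns): self.patterns = patterns  # list of (predicate, rule) pairs
  def rewrite(self, node: Any, ctx=None) -> Optional[Any]:
    for pred, rule in self.patterns:
      if not pred(node): continue
      ret = rule(ctx, node) if ctx is not None else rule(node)  # as in the diff: pass ctx only when supplied
      if ret is not None: return ret  # None means "no rewrite happened", keep trying other rules
    return None

class LowerCtx:  # stands in for IndependentLowerer, which carries idxs/ridxs for the real rules
  def __init__(self): self.hits = 0

def double_int(ctx: LowerCtx, n: int) -> int:
  ctx.hits += 1   # rules can read and update the shared context
  return n * 2

toy = ToyMatcher([(lambda n: isinstance(n, int), double_int)])
c = LowerCtx()
assert toy.rewrite(3, c) == 6 and c.hits == 1

The real change below does the same thing with PatternMatcher.rewrite(uop, ctx) and graph_rewrite(sink, pm, ctx), passing the IndependentLowerer instance as ctx so the lowering rules can read ctx.idxs and ctx.ridxs.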


@@ -1,10 +1,11 @@
# the job of the lowerer is to do indexing
from __future__ import annotations
import functools
from typing import List, Tuple, cast, Optional, Dict
from typing import List, Tuple, cast, Optional
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
from tinygrad.shape.symbolic import sint
from tinygrad.dtype import dtypes
from tinygrad.ops import KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
@@ -34,6 +35,53 @@ def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int
idx //= dims[c]
return ret[::-1] if reverse else ret
# TODO: move this to kernel.py, it doesn't depend on axes
def lower_wmma(ctx: IndependentLowerer, x: UOp):
upcast_axes = x.arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=x.src[0].dtype.vec(wmma_sz[0]), src=(x.src[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=x.src[1].dtype.vec(wmma_sz[1]), src=(x.src[1],), arg=upcast_axes[1]),
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
def lower_reduce_axis(ctx: IndependentLowerer, x: UOp):
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([ctx.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
alu_op: BinaryOps = x.arg[0]
ret = x.src[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
def lower_load_store(ctx: IndependentLowerer, x: UOp):
idx, valid = x.st_arg.to_indexed_uops(ctx.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else ctx.idxs)
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, dtypes.void, (x.src[2],)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if u in x.src[2].src else u for u in ctx.idxs])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(ctx.idxs, ctx.ridxs):
if oidx != ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
return UOp(UOps.STORE, dtypes.void, (buf, idx, x.src[2]) + ((valid,) if has_valid else ()))
pm_lowerer = PatternMatcher([
(UPat(UOps.WMMA, src=(UPat(), UPat()), name="x"), lower_wmma), # 2 param -> 3 param WMMA
(UPat(UOps.REDUCE_AXIS, name="x"), lower_reduce_axis),
(UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lambda ctx,x: x.st_arg.to_indexed_uops(ctx.idxs)[1]),
# rewrite LOAD/STORE SHAPETRACKER to LOAD/STORE with indexed
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(), UPat(UOps.SHAPETRACKER)), allow_any_len=True, name="x"), lower_load_store),
])
class IndependentLowerer:
def lower(self, ast:UOp, opts:Renderer) -> UOp:
self.output_count = len(ast.src)
@@ -79,54 +127,7 @@ class IndependentLowerer:
for a in range(first_reduce, first_reduce+group_for_reduces):
self.ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
self.uop_cache: Dict[UOp, UOp] = {}
return self.to_uop(ast)
def to_uop(self, x:UOp) -> UOp:
if uop:=self.uop_cache.get(x, None): return uop
ret = self._to_uop(x)
self.uop_cache[x] = ret
return ret
def _to_uop(self, x:UOp) -> UOp:
if x.op in BUFFER_UOPS:
idx, valid = x.st_arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
if x.op is UOps.VALID: return valid
# TODO: check has_valid in UPat, not here
has_valid = valid.op is not UOps.CONST or valid.arg is not True
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, dtypes.void, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((x.const_like(0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the threads that are actually doing the reduce
store_back = x.src[0].op is UOps.DEFINE_LOCAL and x.src[2].op is UOps.REDUCE_AXIS and \
x.src[2].src[0].op is UOps.LOAD and x.src[2].src[0].src[0].op is UOps.DEFINE_LOCAL
# NOTE: If we're storing the reduced value back into each thread, need to zero-out the reduced axes
if store_back: idx, _ = x.st_arg.to_indexed_uops([u.const_like(0) if i in x.src[2].arg[1] else u for i,u in enumerate(self.idxs)])
if x.src[0].op is UOps.DEFINE_GLOBAL or store_back:
for oidx, ridx in zip(self.idxs, self.ridxs):
if oidx != ridx: valid = valid * oidx.eq(0)
has_valid = valid.op is not UOps.CONST or valid.arg is not True
return UOp(UOps.STORE, dtypes.void, (buf, idx, self.to_uop(x.src[2])) + ((valid,) if has_valid else ()))
in_uops = tuple(self.to_uop(y) for y in x.src)
if x.op is UOps.WMMA:
upcast_axes = x.arg[-2]
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
ret = UOp(UOps.WMMA, dtype=x.dtype.vec(wmma_sz[2]), src=(
UOp(UOps.CONTRACT, dtype=in_uops[0].dtype.vec(wmma_sz[0]), src=(in_uops[0],), arg=upcast_axes[0]),
UOp(UOps.CONTRACT, dtype=in_uops[1].dtype.vec(wmma_sz[1]), src=(in_uops[1],), arg=upcast_axes[1]),
UOp.const(x.dtype.vec(wmma_sz[2]), 0.0)), arg=x.arg)
return UOp(UOps.EXPAND, x.dtype, (ret,), arg=upcast_axes[2])
if x.op is UOps.REDUCE_AXIS:
# NOTE: always using ridxs is fine here
reduce_range, reduce_expand = partition([self.ridxs[i] for i in x.arg[1]], lambda y: y.op is UOps.RANGE)
alu_op: BinaryOps = x.arg[0]
ret = in_uops[0]
if len(contract_axis:=flatten(x.arg for x in reduce_expand)):
ret = UOp(UOps.CONTRACT, x.dtype.vec(prod(x[1] for x in contract_axis)), (ret,), tuple(contract_axis))
ret = functools.reduce(lambda x,y: x.alu(alu_op, y), [ret.gep(i) for i in range(ret.dtype.count)])
return UOp(UOps.REDUCE, x.dtype, (ret,) + tuple(reduce_range), alu_op) if len(reduce_range) else ret
return x if x.src == in_uops else UOp(x.op, x.dtype, in_uops, x.arg)
# rewrite to add the index
return graph_rewrite(ast, pm_lowerer, ctx=self)
def ast_to_uop(ast:UOp, opts:Renderer) -> UOp: return IndependentLowerer().lower(ast, opts)
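
For reference, each entry in pm_lowerer pairs a UPat with a handler whose first parameter is the ctx passed to graph_rewrite (the IndependentLowerer, carrying the idxs/ridxs built in lower()), followed by the UPat captures by name. A hypothetical single-rule matcher that restates the VALID rule above as a named function, purely to show the handler signature convention:

from tinygrad.ops import UOp, UOps, UPat, PatternMatcher

def lower_valid(ctx, x: UOp) -> UOp:
  # same as the lambda registered for UOps.VALID above: index the shapetracker with the
  # range/special uops the lowerer built, keeping only the validity expression
  return x.st_arg.to_indexed_uops(ctx.idxs)[1]

pm_valid_only = PatternMatcher([
  (UPat(UOps.VALID, src=(UPat(UOps.SHAPETRACKER),), name="x"), lower_valid),
])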


@@ -486,11 +486,11 @@ class PatternMatcher:
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def __add__(self, more:PatternMatcher): return PatternMatcher(self.patterns+more.patterns)
def rewrite(self, uop:UOp) -> Optional[UOp]:
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
if not early_reject.issubset(ler): continue
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None: return ret # NOTE: if it returns None, we keep trying to match
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None: return ret
return None
# *** tracking pattern matcher ***
@@ -510,7 +510,7 @@ class TrackedPatternMatcher(PatternMatcher):
for p,_ in self.patterns:
if p not in match_stats: match_stats[p] = [0,0,0.0,0.0]
def rewrite(self, uop:UOp) -> Optional[UOp]:
def rewrite(self, uop:UOp, ctx=None) -> Optional[UOp]:
ret = None
ler = set([v for u in uop.src for v in ((u.op, u.arg), (u.op, None))])
for p,fxn,early_reject in self.pdict[(uop.op, uop.arg)] + ([] if uop.arg is None else self.pdict[(uop.op, None)]):
@@ -519,7 +519,7 @@ class TrackedPatternMatcher(PatternMatcher):
match_stats[p][2] += time.perf_counter()-st
continue
match_stats[p][1] += 1
if (matches := p.match(uop, {})) and (ret:=fxn(**matches[0])) is not None:
if (matches := p.match(uop, {})) and (ret:=(fxn(ctx, **matches[0]) if ctx is not None else fxn(**matches[0]))) is not None:
match_stats[p][0] += 1
match_stats[p][2] += (et:=time.perf_counter()-st)
match_stats[p][3] += et
@@ -551,8 +551,9 @@ if TRACK_MATCH_STATS:
# *** simple graph rewrite engine ***
class RewriteContext:
def __init__(self, pm):
def __init__(self, pm, ctx):
self.pm: PatternMatcher = pm
self.ctx = ctx
self.nodes: Dict[Tuple, UOp] = {}
self.replace: Dict[UOp, UOp] = {}
def rewrite(self, n:UOp) -> UOp:
@@ -561,12 +562,12 @@ class RewriteContext:
if found := self.nodes.get(replace_source): self.replace[n] = found
else:
x = UOp(*replace_source) if new_src != n.src else n
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x)) else x
self.nodes[replace_source] = self.replace[n] = found = self.rewrite(new_x) if (new_x := self.pm.rewrite(x, self.ctx)) else x
return found
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None) -> UOp:
if TRACK_MATCH_STATS >= 2:
contexts.append(TrackedRewriteContext(f"{(f:=sys._getframe(1)).f_code.co_filename.split('/')[-1]}:{f.f_lineno}", sink, _CURRENT_KERNEL.get()))
return RewriteContext(pm).rewrite(sink)
return RewriteContext(pm, ctx).rewrite(sink)
# ***** uop type spec *****
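
A hedged usage sketch of the new contextual rewrite entry point (the CountCtx class and count_const rule are invented for illustration; graph_rewrite, PatternMatcher, UPat, UOp.const, and dtypes.pyint are the names this diff uses): when a ctx is supplied, every matching rule receives it as its first positional argument.

from tinygrad.ops import UOp, UOps, UPat, PatternMatcher, graph_rewrite
from tinygrad.dtype import dtypes

class CountCtx:
  def __init__(self): self.consts = 0

def count_const(ctx: "CountCtx", x: UOp):
  ctx.consts += 1
  return None  # no rewrite: the node is kept and other rules may still fire

pm_count = PatternMatcher([(UPat(UOps.CONST, name="x"), count_const)])
ctx = CountCtx()
graph_rewrite(UOp.const(dtypes.pyint, 0), pm_count, ctx=ctx)
assert ctx.consts == 1  # the single CONST node was visited once with the shared ctx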