move assign st override to upat (#7122)

* move assign st override to upat

* merge view
This commit is contained in:
qazal 2024-10-17 13:33:37 +03:00 committed by GitHub
parent ded1b38b84
commit a2eefa6f97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 7 additions and 10 deletions

View File

@ -112,6 +112,9 @@ view_left = merge_views+PatternMatcher([
# push VIEW to stores # push VIEW to stores
view_right = merge_views+PatternMatcher([ view_right = merge_views+PatternMatcher([
# ASSIGN can override st
(UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))),
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
# view on reduce creates a new VIEW # view on reduce creates a new VIEW
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r), (UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes) # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
@ -158,7 +161,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs) src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg) if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0])) elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1])) elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src) elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src) elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
else: ret = UOp(UOps.ALU, dtype, src, buf.op) else: ret = UOp(UOps.ALU, dtype, src, buf.op)
@ -172,16 +175,10 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
(out.metadata,) if out.metadata is not None else None) (out.metadata,) if out.metadata is not None else None)
# create the stores # create the stores
cache: Dict[LazyBuffer, UOp] = {} cache: Dict[LazyBuffer, UOp] = {}
ast: List[UOp] = []
inputs: List[LazyBuffer] = [] inputs: List[LazyBuffer] = []
for out in outs: sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
src = to_uop(out, outs, inputs, buf_uops, cache) to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
if out.op is MetaOps.ASSIGN and out.arg: sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0]
else: output_st = ShapeTracker.from_shape(out.shape)
ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0: if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \ if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \