diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f5caaf74..52c6f9c0 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -112,6 +112,9 @@ view_left = merge_views+PatternMatcher([
 
 # push VIEW to stores
 view_right = merge_views+PatternMatcher([
+  # ASSIGN can override st
+  (UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))),
+   lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
   # view on reduce creates a new VIEW
   (UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
   # push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
@@ -158,7 +161,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
   src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
   if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
-  elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
+  elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
   elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
   elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
   else: ret = UOp(UOps.ALU, dtype, src, buf.op)
@@ -172,16 +175,10 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
                         (out.metadata,) if out.metadata is not None else None)
   # create the stores
   cache: Dict[LazyBuffer, UOp] = {}
-  ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
-  for out in outs:
-    src = to_uop(out, outs, inputs, buf_uops, cache)
-    if out.op is MetaOps.ASSIGN and out.arg:
-      assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
-      output_st = out.arg[0]
-    else: output_st = ShapeTracker.from_shape(out.shape)
-    ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
-  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
+  sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
+    to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
+  sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
   if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
     if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
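
The net effect of the diff: the ASSIGN's override ShapeTracker now travels in the UOp's arg (set in to_uop from buf.arg), _lower_lazybuffer always builds STOREs with the default contiguous st, and the new view_right rule swaps in the override during graph rewrite. Below is a minimal standalone sketch (plain Python, not tinygrad code) of what that rule does; all names here (ToyST, ToyOp, rewrite_store) are invented for illustration, and "+" stands in for ShapeTracker composition.

# toy model of: lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None
from dataclasses import dataclass, replace
from typing import Optional, Tuple

@dataclass(frozen=True)
class ToyST:                                     # stand-in for ShapeTracker
  views: Tuple[str, ...]
  def __add__(self, other: "ToyST") -> "ToyST":  # stand-in for st composition
    return ToyST(self.views + other.views)

@dataclass(frozen=True)
class ToyOp:                                     # stand-in for UOp
  op: str
  src: Tuple["ToyOp", ...] = ()
  arg: object = None

def rewrite_store(store: ToyOp) -> Optional[ToyOp]:
  # if the STORE's value is an ASSIGN carrying an st in its arg, compose that st
  # with the STORE's st and clear the ASSIGN's arg; otherwise no match (None)
  b, st, a = store.src
  if a.op != "ASSIGN" or not a.arg: return None
  return ToyOp("STORE", (b, ToyOp("ST", arg=a.arg[0] + st.arg), replace(a, arg=())))

if __name__ == "__main__":
  buf = ToyOp("BUFFER")
  default_st = ToyOp("ST", arg=ToyST(("contiguous",)))
  assign = ToyOp("ASSIGN", (buf, ToyOp("ALU")), (ToyST(("masked_view",)),))
  out = rewrite_store(ToyOp("STORE", (buf, default_st, assign)))
  assert out is not None and out.src[1].arg.views == ("masked_view", "contiguous")
  print(out.src[1].arg)  # ToyST(views=('masked_view', 'contiguous'))

One consequence of doing this in the rewrite rules rather than in _lower_lazybuffer: the old per-output loop and its shape assert disappear, and the masked-view check after full_ast_rewrite (visible at the end of the diff) is what now validates the resulting STORE views.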