From 53586eac56645bf56c4959fbe114ed732571393a Mon Sep 17 00:00:00 2001
From: qazal <77887910+Qazalin@users.noreply.github.com>
Date: Wed, 16 Oct 2024 06:26:04 +0300
Subject: [PATCH] late assert post permuted assign [pr] (#7084)

* late assert post permuted assign [pr]

* a lil earlier
---
 tinygrad/engine/schedule.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 5053e887..87c2bec8 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -139,7 +139,7 @@ def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
 def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
-                   buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
+                   buf_uops:Dict[Buffer, UOp], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
   if (buf, st) in cache: return cache[(buf, st)]
@@ -154,17 +154,12 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
     if buf.op is MetaOps.CONST:
       if isinstance(val:=buf.arg, UOp): var_vals.update([val.unbind()])
       return ubuf.view(unbound_st)
-    if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
-        ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
-      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                         +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
-    if buf not in assign_targets and buf not in inputs: inputs.append(buf)
+    if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
     return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
 
   # only reduceop changes shape
   src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
-  src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs]
+  src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, cache) for x in buf.srcs]
   if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg).view(st)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
   elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
@@ -181,12 +176,11 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
                           (out.metadata,) if out.metadata is not None else None), {}
   # create the stores
   var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
-  assign_targets = {x.srcs[0]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
@@ -194,6 +188,12 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
     var_vals.update(vv)
     ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))
+  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+  if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
+    if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
+        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
+      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                         +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs]))), var_vals
 
 # *** DAG creation: decide which LazyBuffers should realize ***