mirror of https://github.com/commaai/tinygrad.git
move assign st override to upat (#7122)
* move assign st override to upat * merge view
This commit is contained in:
parent
ded1b38b84
commit
a2eefa6f97
|
@ -112,6 +112,9 @@ view_left = merge_views+PatternMatcher([
|
||||||
|
|
||||||
# push VIEW to stores
|
# push VIEW to stores
|
||||||
view_right = merge_views+PatternMatcher([
|
view_right = merge_views+PatternMatcher([
|
||||||
|
# ASSIGN can override st
|
||||||
|
(UPat(UOps.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(UOps.ASSIGN, name="a"))),
|
||||||
|
lambda a,b,st: UOp.store(b, (a.arg[0]+st.arg).to_uop(), a.replace(arg=())) if a.arg else None),
|
||||||
# view on reduce creates a new VIEW
|
# view on reduce creates a new VIEW
|
||||||
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
(UPat(UOps.VIEW, src=(UPat(UOps.REDUCE_AXIS, src=UPat.var("rsrc"), name="r"),), name="view"), view_r),
|
||||||
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
# push a SWIZZLE down to STORE, through a reduce (ONLY reshapes)
|
||||||
|
@ -158,7 +161,7 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
|
||||||
src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
|
src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
|
||||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
|
||||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
|
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
|
||||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
|
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
|
||||||
elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
|
elif buf.op is UnaryOps.CAST: ret = UOp(UOps.CAST, dtype, src)
|
||||||
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
|
elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
|
||||||
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
|
else: ret = UOp(UOps.ALU, dtype, src, buf.op)
|
||||||
|
@ -172,16 +175,10 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
|
||||||
(out.metadata,) if out.metadata is not None else None)
|
(out.metadata,) if out.metadata is not None else None)
|
||||||
# create the stores
|
# create the stores
|
||||||
cache: Dict[LazyBuffer, UOp] = {}
|
cache: Dict[LazyBuffer, UOp] = {}
|
||||||
ast: List[UOp] = []
|
|
||||||
inputs: List[LazyBuffer] = []
|
inputs: List[LazyBuffer] = []
|
||||||
for out in outs:
|
sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
|
||||||
src = to_uop(out, outs, inputs, buf_uops, cache)
|
to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
|
||||||
if out.op is MetaOps.ASSIGN and out.arg:
|
sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
|
||||||
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
|
|
||||||
output_st = out.arg[0]
|
|
||||||
else: output_st = ShapeTracker.from_shape(out.shape)
|
|
||||||
ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
|
|
||||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
|
|
||||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||||
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
|
if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
|
||||||
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
|
||||||
|
|
Loading…
Reference in New Issue