mirror of https://github.com/commaai/tinygrad.git
big graph var_vals as rewrite context (#7007)
* var_vals as rewrite context
* no default arg
* add st var_vals
* delete some stuff
* add the rewrite rule again
* extra
* this whole part is preschedule
* test with a second context
* redo
* i always forget tensor variable
This commit is contained in:
parent 390171d686
commit 40f33c110b
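The overall shape of the change: instead of each lowering step building its own var_vals dict and the caller merging the results, a single mutable dict is created once in _graph_schedule and threaded down (and into the graph-rewrite context), so callees record variable bindings in place. The toy sketch below only shows that before/after shape; the names and types are stand-ins, not tinygrad code.

# Toy before/after sketch of the refactor; names and types are stand-ins, not tinygrad code.
from typing import Dict, List, Tuple

Variable = str  # stand-in for a symbolic Variable

# before: each lowering call returns its own bindings and the caller merges them
def lower_old(group: List[str]) -> Tuple[str, Dict[Variable, int]]:
  return f"item({group})", {"i": len(group)}

def schedule_old(groups: List[List[str]]) -> Tuple[List[str], Dict[Variable, int]]:
  items, var_vals = [], {}
  for g in groups:
    item, vv = lower_old(g)
    items.append(item)
    var_vals.update(vv)               # the merge step in the old code path
  return items, var_vals

# after: one dict is created up front and every callee updates it in place
def lower_new(group: List[str], var_vals: Dict[Variable, int]) -> str:
  var_vals["i"] = len(group)          # callee records its bindings directly
  return f"item({group})"

def schedule_new(groups: List[List[str]]) -> Tuple[List[str], Dict[Variable, int]]:
  var_vals: Dict[Variable, int] = {}
  items = [lower_new(g, var_vals) for g in groups]
  return items, var_vals

print(schedule_old([["a"], ["b", "c"]]))  # two items plus {'i': 2}
print(schedule_new([["a"], ["b", "c"]]))  # same result, with no per-call dicts to merge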
@@ -30,7 +30,7 @@ if REF == "master": SKIP_PROCESS_REPLAY = True
 
 # *** recreators
 
-def recreate_sched(sink:UOp, ctx, _) -> UOp: return full_ast_rewrite(sink, ctx)
+def recreate_sched(sink:UOp, ctx, _) -> UOp: return full_ast_rewrite(sink, ctx, {})
 def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str, ctx:ProcessReplayContext, _) -> str:
   with Context(**{k:v for k,v in ctx.ctx_vars.items() if k in ContextVar._cache and k != "DEBUG"}):
     k = Kernel(ast, opts=opts)
@@ -121,7 +121,14 @@ view_right = merge_views+PatternMatcher([
   (UPat(UOps.REDUCE_AXIS, src=(UPat(UOps.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])
 
-enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx.index(x.arg[0])))])
+def simplify_and_unbind(ctx, x:UOp) -> Optional[UOp]:
+  if (st:=unwrap(x.st)) in ctx[2]: return None
+  st, var_vals = st.simplify().unbind()
+  ctx[0].update(var_vals)
+  ctx[2].add(st)
+  return st.to_uop()
+append_vars = PatternMatcher([(UPat(UOps.VIEW, name="x"), simplify_and_unbind)])
+enumerate_bufs = PatternMatcher([(UPat(UOps.BUFFER, name="x"), lambda ctx,x: UOp(UOps.DEFINE_GLOBAL, x.dtype, (), ctx[1].index(x.arg[0])))])
 
 PROCESS_REPLAY_CAPTURE: List[Tuple[UOp, Tuple[int, ...], UOp]] = []
 if getenv("RUN_PROCESS_REPLAY"):
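The new append_vars and enumerate_bufs matchers share one rewrite context: a three-slot tuple of (var_vals dict, ordered buffer ids, set of ShapeTrackers already in final form). A minimal sketch of how such callbacks index that tuple, with strings standing in for UOp/ShapeTracker:

# Toy sketch of the (var_vals, bufs, seen) rewrite context; strings stand in for UOp/ShapeTracker.
from typing import Dict, Optional, Set, Tuple

Ctx = Tuple[Dict[str, int], Tuple[int, ...], Set[str]]

def unbind_view(ctx: Ctx, st: str) -> Optional[str]:
  var_vals, _bufs, seen = ctx
  if st in seen: return None            # already in final unbound form: no rewrite, so the matcher converges
  new_st = st + "_unbound"              # stand-in for st.simplify().unbind()'s ShapeTracker
  var_vals["i"] = 42                    # stand-in for the bindings that unbind() reports
  seen.add(new_st)                      # remember the result, mirroring ctx[2].add(st)
  return new_st

def number_buffer(ctx: Ctx, buf_id: int) -> int:
  return ctx[1].index(buf_id)           # ctx[1] holds the buffer ids in argument order (the DEFINE_GLOBAL index)

ctx: Ctx = ({}, (7, 3, 9), set())       # built like (var_vals, bufs, set()) in full_ast_rewrite
print(unbind_view(ctx, "view0"))        # 'view0_unbound', and ctx[0] now carries the binding
print(unbind_view(ctx, "view0_unbound"))  # None: the rewritten node is left alone on the next pass
print(number_buffer(ctx, 9))            # 2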
@@ -130,15 +137,15 @@ if getenv("RUN_PROCESS_REPLAY"):
     for base_sink,ctx,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(base_sink.key), (base_sink, ctx, ret))
 
 @track_rewrites
-def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
+def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable, int]) -> UOp:
   sink = graph_rewrite(graph_rewrite(base_sink, view_left), view_right)
-  ret = graph_rewrite(sink, enumerate_bufs, bufs)
+  ret = graph_rewrite(sink, append_vars+enumerate_bufs, (var_vals, bufs, set()))
   PROCESS_REPLAY_CAPTURE.append((base_sink, bufs, ret))
   return ret
 
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
-def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
+def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], inputs:List[LazyBuffer],
                    buf_uops:Dict[Buffer, UOp], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -149,17 +156,13 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   # buffer ops define ShapeTracker
   # if it's realized, it's a load and we add it to the inputs
   if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    if buf.op is MetaOps.CONST:
-      if isinstance(val:=buf.arg, UOp): var_vals.update([val.unbind()])
-      return ubuf.view(unbound_st)
+    if buf.op is MetaOps.CONST: return ubuf.view(st)
     if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
-    return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (ubuf, st.to_uop()))
 
   # only reduceop changes shape
   src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
-  src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, cache) for x in buf.srcs]
+  src: List[UOp] = [_recursive_uop(x, src_st, outputs, inputs, buf_uops, cache) for x in buf.srcs]
   if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg).view(st)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
   elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
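With this hunk, _recursive_uop no longer simplifies and unbinds the ShapeTracker at every load or const; the bound tracker flows into the AST as a VIEW and append_vars unbinds it once during full_ast_rewrite. A toy sketch of what "unbind" means here, assuming the same contract the diff relies on (an unbound tracker plus a Dict[Variable, int] of stripped-out values); the classes below are illustrative stand-ins, not tinygrad's ShapeTracker or Variable.

# Illustrative stand-ins for a bound symbolic shape and its unbind(); not tinygrad's classes.
from dataclasses import dataclass
from typing import Dict, Tuple, Union

@dataclass(frozen=True)
class Var:
  name: str
  min: int
  max: int

Dim = Union[int, Var, Tuple[Var, int]]  # a dimension is concrete, an unbound Var, or a (Var, value) binding

@dataclass(frozen=True)
class BoundShape:
  shape: Tuple[Dim, ...]

  def unbind(self) -> Tuple["BoundShape", Dict[Var, int]]:
    # strip the concrete values out of the bound dims and report them separately
    var_vals = {d[0]: d[1] for d in self.shape if isinstance(d, tuple)}
    unbound = tuple(d[0] if isinstance(d, tuple) else d for d in self.shape)
    return BoundShape(unbound), var_vals

st = BoundShape((32, (Var("i", 1, 64), 17)))
unbound_st, vv = st.unbind()
print(unbound_st.shape)   # (32, Var(name='i', min=1, max=64))
print(vv)                 # {Var(name='i', min=1, max=64): 17}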
@@ -169,32 +172,29 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   cache[(buf, st)] = ret
   return ret
 
-def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in METAOPS:
     return LBScheduleItem(UOp(METAOPS[cast(MetaOps, out.op)], out.dtype, (), out.arg), (out,)+tuple(x.base for x in out.srcs),
-                          (out.metadata,) if out.metadata is not None else None), {}
+                          (out.metadata,) if out.metadata is not None else None)
   # create the stores
-  var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), inputs, buf_uops, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
-    output_st, vv = output_st.simplify().unbind()
-    var_vals.update(vv)
     ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
-  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))
+  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
   if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
     if not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
         and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs]))), var_vals
+  return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs])))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
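_lower_lazybuffer no longer seeds a fresh dict from each output's ShapeTracker (the removed merge_dicts line) and no longer returns one; the caller-owned var_vals is filled while full_ast_rewrite runs. For context, a merge_dicts-style helper in the spirit of the removed line could look like the sketch below; its semantics (duplicate keys must agree) are an assumption for illustration, not tinygrad's helpers implementation.

# Assumed semantics for a merge_dicts-style helper (illustrative only, not tinygrad's implementation).
from typing import Dict, List, TypeVar

K, V = TypeVar("K"), TypeVar("V")

def merge_dicts_sketch(ds: List[Dict[K, V]]) -> Dict[K, V]:
  merged: Dict[K, V] = {}
  for d in ds:
    for k, v in d.items():
      # the same symbolic variable must not carry two different values across outputs
      assert merged.get(k, v) == v, f"conflicting value for {k}"
      merged[k] = v
  return merged

print(merge_dicts_sketch([{"i": 3}, {"j": 5}, {"i": 3}]))  # {'i': 3, 'j': 5}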
@@ -353,6 +353,7 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
 
   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   buf_uops: Dict[Buffer, UOp] = {}
+  var_vals: Dict[Variable, int] = {}
   for buf in realizes:
     if buf.realized is None and buf.op is not MetaOps.CONST:
       output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
@@ -368,17 +369,14 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
       buf.buffer.dtype = dtypes.float32
       buf.buffer.options = None
     if buf.op is MetaOps.CONST:
+      if isinstance(val:=buf.arg, UOp): var_vals.update([val.unbind()])
       uop = UOp(UOps.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(buf.dtype.scalar(), buf.arg), v.const_like(0))
     # NOTE: UOps.BUFFER creation must come after the ImageDType fixup
     else: uop = UOp(UOps.BUFFER, buf.buffer.dtype.ptr(), (), (len(buf_uops), (buf.buffer.device, buf.buffer.size, buf.buffer.dtype)))
     buf_uops.setdefault(buf.buffer, uop)
 
   # preschedule all buffers in realizes
-  prescheduled: List[LBScheduleItem] = []
-  var_vals: Dict[Variable, int] = {}
-  for outs in output_groups.values():
-    prescheduled.append((ret:=_lower_lazybuffer(outs, buf_uops))[0])
-    var_vals = merge_dicts([var_vals, ret[1]])
+  prescheduled = [_lower_lazybuffer(outs, buf_uops, var_vals) for outs in output_groups.values()]
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}
 
   graph: DefaultDict[LBScheduleItem, List[LBScheduleItem]] = defaultdict(list)
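The CONST bookkeeping removed from _recursive_uop reappears here while buf_uops is built: a bound variable's unbind() is used as a single (Variable, value) pair, and wrapping it in a list lets plain dict.update absorb it, since update accepts any iterable of key/value pairs. A tiny standalone illustration of that idiom (the unbind below is a stand-in):

# The idiom behind var_vals.update([val.unbind()]): dict.update with an iterable of pairs.
from typing import Dict, Tuple

def unbind() -> Tuple[str, int]:        # stand-in for a bound Variable UOp's unbind()
  return ("i", 3)

var_vals: Dict[str, int] = {}
var_vals.update([unbind()])             # a single (key, value) pair wrapped in a list
var_vals.update([("j", 5), ("k", 7)])   # update also takes several pairs at once
print(var_vals)                         # {'i': 3, 'j': 5, 'k': 7}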