refactor DEFINE_GLOBAL inputs to list [run_process_replay] (#6711)

qazal 2024-09-24 17:43:24 +08:00 committed by GitHub
parent f932116e05
commit ae3f3fec38
GPG Key ID: B5690EEEBB952194
1 changed file with 5 additions and 4 deletions


@@ -119,7 +119,7 @@ def full_ast_rewrite(sink:UOp) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***
-def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int],
+def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
                    realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
                    cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
@@ -145,8 +145,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+    if buf not in assign_targets and buf not in inputs: inputs.append(buf)
     ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
-               outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs)))
+               outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
     return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
   # reduce ops change ShapeTracker
@@ -173,7 +174,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
   ast: List[UOp] = []
-  inputs: Dict[LazyBuffer, int] = {}
+  inputs: List[LazyBuffer] = []
   for i, out in enumerate(outs):
     src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
@@ -184,7 +185,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
     ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
     ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
-  return LBScheduleItem(sink, outs, list(inputs), dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
+  return LBScheduleItem(sink, outs, inputs, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
 
 # *** DAG creation: decide which LazyBuffers should realize ***
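
Both bookkeeping styles produce the same first-appearance ordering of inputs, so DEFINE_GLOBAL argument numbering is unchanged by the refactor: outputs keep slots 0..len(outs)-1, inputs follow at len(outputs)+index, and assign targets stay out of inputs because they reuse their output's slot via outputs.index(...). A minimal standalone sketch of the equivalence (plain strings stand in for LazyBuffers; the helper names here are made up for illustration, not tinygrad APIs):

def idx_dict(inputs, buf):
  # old style: the dict doubles as an ordered set; setdefault assigns the
  # next slot on first sight and returns the stored index on later hits
  return inputs.setdefault(buf, len(inputs))

def idx_list(inputs, buf):
  # new style: append on first sight, then recover the slot with .index
  if buf not in inputs: inputs.append(buf)
  return inputs.index(buf)

d, l = {}, []
for buf in ["a", "b", "a", "c"]:
  assert idx_dict(d, buf) == idx_list(l, buf)
assert list(d) == l == ["a", "b", "c"]  # same first-appearance order either way

The list version trades the dict's O(1) setdefault for an O(n) membership test plus index lookup, presumably acceptable since a kernel's input list is short; it also drops the list(inputs) conversion at the LBScheduleItem call site.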