mirror of https://github.com/commaai/tinygrad.git
refactor DEFINE_GLOBAL inputs to list [run_process_replay] (#6711)
This commit is contained in:
parent
f932116e05
commit
ae3f3fec38
|
@ -119,7 +119,7 @@ def full_ast_rewrite(sink:UOp) -> UOp:
|
||||||
|
|
||||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||||
|
|
||||||
def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:Dict[LazyBuffer, int],
|
def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
|
||||||
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
|
realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
|
||||||
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
|
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
|
||||||
"""recursively create a UOp"""
|
"""recursively create a UOp"""
|
||||||
|
@ -145,8 +145,9 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||||
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
|
||||||
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
|
||||||
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
|
||||||
|
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
|
||||||
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
|
||||||
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs)))
|
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.index(buf))
|
||||||
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
|
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
|
||||||
|
|
||||||
# reduce ops change ShapeTracker
|
# reduce ops change ShapeTracker
|
||||||
|
@ -173,7 +174,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
||||||
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
|
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
|
||||||
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
|
cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
|
||||||
ast: List[UOp] = []
|
ast: List[UOp] = []
|
||||||
inputs: Dict[LazyBuffer, int] = {}
|
inputs: List[LazyBuffer] = []
|
||||||
for i, out in enumerate(outs):
|
for i, out in enumerate(outs):
|
||||||
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
|
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, realizes, assign_targets, cache=cache)
|
||||||
if out.op is MetaOps.ASSIGN and out.arg:
|
if out.op is MetaOps.ASSIGN and out.arg:
|
||||||
|
@ -184,7 +185,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
|
||||||
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
|
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
|
||||||
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
|
ast.append(UOp(UOps.STORE, dtypes.void, (ubuf, output_st.to_uop(), src)))
|
||||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
|
sink = full_ast_rewrite(ast[0].sink(*ast[1:]))
|
||||||
return LBScheduleItem(sink, outs, list(inputs), dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
|
return LBScheduleItem(sink, outs, inputs, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])), var_vals
|
||||||
|
|
||||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue