prescheduling refactor (#5300)

* p1

* refactor tuple
qazal 2024-07-06 12:04:03 +03:00 committed by GitHub
parent c1e166c08a
commit d813617742
2 changed files with 21 additions and 27 deletions


@@ -36,7 +36,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
         assign_targets[out.srcs[1]] = out
     for x in ps[2]:
       if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0))
+    si = ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0))
     _exec_si(si, seed)
     for out in ps[0]:
       ground_truth[out] = out.buffer.as_buffer()
@@ -57,7 +57,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
         elif x.device == "NPY": rawbufs[x] = x.buffer
         # copy the pre realized input
         else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
-      si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in (tuple(ps[0])+ps[2]) if x.size != 0))
+      si = ScheduleItem(ps[1], tuple(rawbufs[x] for x in ps[0]+ps[2] if x.size != 0))
       _exec_si(si, seed)
       for out in ps[0]:
         outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
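
The `ps` entry unpacked in both hunks above comes from the prescheduler; after this refactor the output group `ps[0]` and the input list `ps[2]` are both plain lists, so the old `tuple(ps[0])+ps[2]` coercion collapses to `ps[0]+ps[2]`. A minimal self-contained sketch of that buffer assembly, using toy stand-ins rather than tinygrad's LazyBuffer/ScheduleItem (the layout `(outputs, ast, inputs, var_vals)` is inferred from this diff):

from typing import NamedTuple

class FakeBuf(NamedTuple):  # stand-in exposing only the fields the comprehension touches
  size: int
  buffer: str

ps = ([FakeBuf(4, "out0")], ("AST",), [FakeBuf(4, "in0"), FakeBuf(0, "unused")], {})
bufs = tuple(x.buffer for x in ps[0] + ps[2] if x.size != 0)  # list + list, no tuple() needed
assert bufs == ("out0", "in0")  # zero-size buffers are dropped from the ScheduleItem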


@@ -92,30 +92,24 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
     LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outputs, var_vals, st, realizes, assign_targets, cache) for x in buf.srcs), buf.arg)
   return ret

-def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None], reduce_for_op: Dict[LazyBuffer, LazyBuffer]):
-  """create a schedule item from a list of outputs"""
-  inputs: List[LazyBuffer] = []
-  ast: List[LazyOp] = []
+def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, LazyBuffer]):
+  """describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
+  if (out:=outs[0]).op is LoadOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
+    rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
+    return (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), ), [x.base for x in out.srcs], {}
+  if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}: return (LazyOp(out.op, (), out.arg), ), [x.base for x in out.srcs], {}
   var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
-  # single output AST
-  if (op:=(out:=outs[0]).op) in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY, LoadOps.VIEW}:
-    assert len(outs) == 1, f"can't schedule a group of {op}"
-    inputs = [x.base for x in out.srcs]
-    if getenv("USE_COPY_KERNEL") and op is LoadOps.COPY and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
-      rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
-      ast = [LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st))]
-    else: ast = [LazyOp(op, (), out.arg)]
-  # multi output AST
-  else:
-    assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
-    for i, out in enumerate(outs):
-      output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-      output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-      lop = _recursive_lazyop(out, inputs, outs, var_vals, output_st, realizes, assign_targets, cache={})
-      output_view, vv = output_view.simplify().unbind()
-      if vv: var_vals.update(vv)
-      ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
-  return tuple(ast), tuple(inputs), var_vals
+  assign_targets = {x.srcs[1]:x for x in outs if x.op is LoadOps.ASSIGN}
+  ast: List[LazyOp] = []
+  inputs: List[LazyBuffer] = []
+  for i, out in enumerate(outs):
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
+    lop = _recursive_lazyop(out, inputs, tuple(outs), var_vals, output_st, realizes, assign_targets, cache={})
+    output_view, vv = output_view.simplify().unbind()
+    if vv: var_vals.update(vv)
+    ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
+  return tuple(ast), inputs, var_vals

 # *** DAG creation: decide which LazyBuffers should realize ***
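
The refactored `_lower_lazybuffer` above has a single contract: it always returns `(ast, inputs, var_vals)`, with trivial LoadOps handled by the early returns and everything else lowered through one STORE-per-output loop. A hedged sketch of inspecting that return value directly, assuming tinygrad at this commit and that the function lives in tinygrad/engine/schedule.py (the file path is not shown in this diff):

from tinygrad import Tensor
from tinygrad.engine.schedule import _lower_lazybuffer  # import path assumed, see note above

t = Tensor.empty(16)  # its lazydata is a LoadOps.EMPTY LazyBuffer
ast, inputs, var_vals = _lower_lazybuffer([t.lazydata.base], {}, {})
print(ast)               # a single LazyOp(LoadOps.EMPTY, ...) from the second early return
print(inputs, var_vals)  # [] and {} for EMPTY: no real srcs, no bound variables
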
@@ -260,7 +254,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]):
         buf.buffer.options = None

   # preschedule all buffers in realizes
-  prescheduled = {group[0]:(group, *_schedule_group(tuple(group), realizes, reduce_for_op)) for group in output_groups.values()}
+  prescheduled = {group[0]:(group, *_lower_lazybuffer(group, realizes, reduce_for_op)) for group in output_groups.values()}
   schedule_targets = {out:ps for ps in prescheduled.values() for out in ps[0]}

   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
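
After this hunk, every value in `prescheduled` is a 4-tuple `(group, ast, inputs, var_vals)` — the `ps[0]`…`ps[3]` indexing used elsewhere in this commit — and `schedule_targets` maps each individual output back to its group's entry. A toy sketch of that shape with plain Python stand-ins (not tinygrad types):

# every output of a multi-output group shares one prescheduled entry (one kernel)
prescheduled = {"a": (["a", "b"], ("STORE a", "STORE b"), ["x"], {})}
schedule_targets = {out: ps for ps in prescheduled.values() for out in ps[0]}
assert schedule_targets["b"] is prescheduled["a"]
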
@@ -298,7 +292,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
     for out in ps[0]: realized_lazybuffer(out, kernel_number)
     var_vals = merge_dicts([var_vals, ps[3]])
     for out in ps[0]: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in (tuple(ps[0])+ps[2]) if x.size != 0)))
+    schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0)))
     if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
     for x in graph[ps[0][0]]:
       in_degree[x] -= 1
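
For context, a minimal usage sketch of how this scheduling loop is normally driven (assuming tinygrad at this commit, where Tensor.schedule ultimately calls create_schedule_with_vars on the output LazyBuffers):

from tinygrad import Tensor

a, b = Tensor.ones(4), Tensor.full((4,), 2.0)
out = (a * b).sum()
sched = out.schedule()  # List[ScheduleItem] assembled by the loop above
for si in sched: print(si.ast[0].op, len(si.inputs))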