si.inputs+outputs -> bufs (#4279)

This commit is contained in:
George Hotz 2024-04-24 11:12:34 +04:00 committed by GitHub
parent 8401de9922
commit ad28fdecb1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 20 additions and 14 deletions

View File

@ -48,7 +48,7 @@ for si in schedule: print(str(si)[:80])
# 4. Lower a schedule.
from tinygrad.engine.realize import lower_schedule_item, ExecItem
lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si), list(si.outputs+si.inputs)) for si in tqdm(schedule)]
lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si), list(si.bufs)) for si in tqdm(schedule)]
# *****
# 5. Run the schedule

View File

@ -49,7 +49,7 @@ if __name__ == "__main__":
src = Device["CLANG"].compiler.render(to_function_name(k.name), k.uops).strip(CLANG_PROGRAM_HEADER)
srcs[ast] = (k.name, src)
print("functions:", len(srcs))
used_buffers = dedup(flatten([si.outputs+si.inputs for si in sched]))
used_buffers = dedup(flatten([si.bufs for si in sched]))
numbered_bufs = {x:i for i,x in enumerate(used_buffers)}
print("buffers:", len(numbered_bufs))
@ -74,7 +74,7 @@ if __name__ == "__main__":
all_bufs = []
for i,si in enumerate(sched):
bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.outputs+si.inputs]
bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
all_bufs += bufs
if si.ast[0].op is not BufferOps.STORE:
print(f"// {si.ast[0].op}", bufs)

View File

@ -33,7 +33,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
ground_truth[out] = out.buffer.as_buffer()
@ -52,7 +52,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
if x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs if x.size != 0), tuple(rawbufs[x] for x in ps.inputs if x.size != 0))
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))
_exec_si(si, seed)
for out in ps.outputs:
outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np)
@ -62,7 +62,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
raise e
def _exec_si(si: ScheduleItem, seed:int):
ei = ExecItem(lower_schedule_item(si), list(si.outputs+si.inputs))
ei = ExecItem(lower_schedule_item(si), list(si.bufs))
if len(capturing): capturing[0].add(ei)
if isinstance(ei.prg, CustomOp): Tensor._seed = seed
ei.run()

View File

@ -37,7 +37,7 @@ class EmptyOp(Runner):
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass
def lower_schedule_item(si:ScheduleItem) -> Runner:
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
out, ast = si.outputs[0], si.ast[0]
@ -52,7 +52,7 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
raise RuntimeError(f"don't know how to lower {ast}")
def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.bufs))
capturing: List = [] # put classes with an add method in here
@ -81,9 +81,8 @@ def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") ->
return assigned
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
assigned = _internal_memory_planner([si.outputs+si.inputs for si in schedule])
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.outputs),
tuple(assigned.get(x, x) for x in si.inputs)) for si in schedule]
assigned = _internal_memory_planner([si.bufs for si in schedule])
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]
def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
for ei in lower_schedule(schedule):

View File

@ -254,7 +254,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
var_vals = merge_dicts([var_vals, ps.var_vals])
for out in ps.outputs: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)))
schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps.outputs[0]]:
in_degree[x] -= 1

View File

@ -42,8 +42,15 @@ class ConstBuffer:
@dataclass(frozen=True)
class ScheduleItem:
ast: Tuple[LazyOp, ...]
outputs: Tuple[Buffer, ...]
inputs: Tuple[Buffer, ...]
bufs: Tuple[Buffer, ...]
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
return self.bufs[:len(self.ast)]
@property
def inputs(self) -> Tuple[Buffer, ...]:
"""Read only buffers in the schedule."""
return self.bufs[len(self.ast):]
@dataclass(frozen=True, eq=False)
class LazyOp: