mirror of https://github.com/commaai/tinygrad.git
si.inputs+outputs -> bufs (#4279)
This commit is contained in: parent 8401de9922, commit ad28fdecb1
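Summary of the refactor: ScheduleItem no longer carries separate outputs and inputs tuples. It stores a single bufs tuple with the output buffers first, and outputs/inputs become properties sliced at len(self.ast) (one output buffer per LazyOp in the ast). Call sites that previously built si.outputs+si.inputs now just read si.bufs. Below is a minimal, self-contained sketch of that convention; the string buffers and the ScheduleItemSketch name are stand-ins for illustration, not tinygrad's real Buffer/LazyOp types.

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ScheduleItemSketch:
  # mirrors the new layout: one flat tuple, outputs first, then inputs
  ast: Tuple[str, ...]    # stand-in for Tuple[LazyOp, ...]
  bufs: Tuple[str, ...]   # stand-in for Tuple[Buffer, ...]

  @property
  def outputs(self) -> Tuple[str, ...]:
    # one output buffer per op in the ast
    return self.bufs[:len(self.ast)]

  @property
  def inputs(self) -> Tuple[str, ...]:
    # everything after the outputs is a read-only input
    return self.bufs[len(self.ast):]

# a two-output kernel with three inputs, all in one tuple
si = ScheduleItemSketch(ast=("store0", "store1"), bufs=("out0", "out1", "in0", "in1", "in2"))
assert si.outputs == ("out0", "out1")
assert si.inputs == ("in0", "in1", "in2")
# old call sites wrote list(si.outputs+si.inputs); that is now simply list(si.bufs)
assert list(si.outputs + si.inputs) == list(si.bufs)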
@@ -48,7 +48,7 @@ for si in schedule: print(str(si)[:80])
 # 4. Lower a schedule.

 from tinygrad.engine.realize import lower_schedule_item, ExecItem
-lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si), list(si.outputs+si.inputs)) for si in tqdm(schedule)]
+lowered: List[ExecItem] = [ExecItem(lower_schedule_item(si), list(si.bufs)) for si in tqdm(schedule)]

 # *****
 # 5. Run the schedule

@@ -49,7 +49,7 @@ if __name__ == "__main__":
     src = Device["CLANG"].compiler.render(to_function_name(k.name), k.uops).strip(CLANG_PROGRAM_HEADER)
     srcs[ast] = (k.name, src)
   print("functions:", len(srcs))
-  used_buffers = dedup(flatten([si.outputs+si.inputs for si in sched]))
+  used_buffers = dedup(flatten([si.bufs for si in sched]))
   numbered_bufs = {x:i for i,x in enumerate(used_buffers)}
   print("buffers:", len(numbered_bufs))


@@ -74,7 +74,7 @@ if __name__ == "__main__":

   all_bufs = []
   for i,si in enumerate(sched):
-    bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.outputs+si.inputs]
+    bufs = [(named_buffers.get(b, f"b{numbered_bufs[b]}"), b) for b in si.bufs]
     all_bufs += bufs
     if si.ast[0].op is not BufferOps.STORE:
       print(f"// {si.ast[0].op}", bufs)

@@ -33,7 +33,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
       if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
     for x in ps.inputs:
       if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0))
+    si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))
     _exec_si(si, seed)
     for out in ps.outputs:
       ground_truth[out] = out.buffer.as_buffer()

@@ -52,7 +52,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
       if x.device == "NPY": rawbufs[x] = x.buffer
       # copy the pre realized input
       else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
-    si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs if x.size != 0), tuple(rawbufs[x] for x in ps.inputs if x.size != 0))
+    si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))
     _exec_si(si, seed)
     for out in ps.outputs:
       outbuf = np.frombuffer(rawbufs[out].as_buffer(), out.dtype.np)

@@ -62,7 +62,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
       raise e

 def _exec_si(si: ScheduleItem, seed:int):
-  ei = ExecItem(lower_schedule_item(si), list(si.outputs+si.inputs))
+  ei = ExecItem(lower_schedule_item(si), list(si.bufs))
   if len(capturing): capturing[0].add(ei)
   if isinstance(ei.prg, CustomOp): Tensor._seed = seed
   ei.run()

@@ -37,7 +37,7 @@ class EmptyOp(Runner):
   def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False): pass

 def lower_schedule_item(si:ScheduleItem) -> Runner:
-  assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
+  assert len(set(x.device for x in si.bufs)) == 1 or si.ast[0].op is LoadOps.COPY
   if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
   assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
   out, ast = si.outputs[0], si.ast[0]

@@ -52,7 +52,7 @@ def lower_schedule_item(si:ScheduleItem) -> Runner:
   raise RuntimeError(f"don't know how to lower {ast}")

 def lower_schedule(schedule:List[ScheduleItem]) -> Generator[ExecItem, None, None]:
-  while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.outputs+si.inputs))
+  while len(schedule): yield ExecItem(lower_schedule_item(si:=schedule.pop(0)), list(si.bufs))

 capturing: List = [] # put classes with an add method in here

@@ -81,9 +81,8 @@ def _internal_memory_planner(buffers:List[Iterable[Buffer]], debug_prefix="") ->
   return assigned

 def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
-  assigned = _internal_memory_planner([si.outputs+si.inputs for si in schedule])
-  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.outputs),
-                       tuple(assigned.get(x, x) for x in si.inputs)) for si in schedule]
+  assigned = _internal_memory_planner([si.bufs for si in schedule])
+  return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs)) for si in schedule]

 def run_schedule(schedule:List[ScheduleItem], var_vals:Optional[Dict[Variable, int]]=None):
   for ei in lower_schedule(schedule):

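Because bufs keeps the outputs ahead of the inputs, memory_planner can remap the whole tuple in one pass and the output/input split survives the rebuilt ScheduleItem. A toy illustration of that invariant; the assigned dict and string buffers are placeholders, not the real _internal_memory_planner output.

# pretend the planner decided to fold in0 into a reused allocation
assigned = {"in0": "reused_buf"}
ast, bufs = ("store0",), ("out0", "in0", "in1")
new_bufs = tuple(assigned.get(x, x) for x in bufs)
# the positional split is untouched: the first len(ast) entries are still the outputs
assert new_bufs[:len(ast)] == ("out0",)
assert new_bufs[len(ast):] == ("reused_buf", "in1")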
@@ -254,7 +254,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
    for out in ps.outputs: realized_lazybuffer(out, kernel_number)
    var_vals = merge_dicts([var_vals, ps.var_vals])
    for out in ps.outputs: del out.srcs  # can only schedule once
-    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0), tuple(x.buffer for x in ps.inputs if x.size != 0)))
+    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0)))
    if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
    for x in graph[ps.outputs[0]]:
      in_degree[x] -= 1

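Construction sites (create_schedule_with_vars above, and fuzz_schedule earlier) now build the single tuple by concatenating outputs and inputs before the zero-size filter. A toy version of that call-site pattern; the Lazy stand-in is a placeholder for LazyBuffer, not a tinygrad type.

from dataclasses import dataclass

@dataclass(frozen=True)
class Lazy:  # stand-in for LazyBuffer: a name, a size, and a precomputed buffer
  name: str
  size: int
  buffer: str

outputs = (Lazy("o0", 16, "buf_o0"),)
inputs = (Lazy("i0", 16, "buf_i0"), Lazy("i1", 0, "buf_i1"))  # zero-size buffer gets filtered out
bufs = tuple(x.buffer for x in (outputs + inputs) if x.size != 0)
assert bufs == ("buf_o0", "buf_i0")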
@@ -42,8 +42,15 @@ class ConstBuffer:
 @dataclass(frozen=True)
 class ScheduleItem:
   ast: Tuple[LazyOp, ...]
-  outputs: Tuple[Buffer, ...]
-  inputs: Tuple[Buffer, ...]
+  bufs: Tuple[Buffer, ...]
+  @property
+  def outputs(self) -> Tuple[Buffer, ...]:
+    """Read/write or write only buffers in the schedule."""
+    return self.bufs[:len(self.ast)]
+  @property
+  def inputs(self) -> Tuple[Buffer, ...]:
+    """Read only buffers in the schedule."""
+    return self.bufs[len(self.ast):]

 @dataclass(frozen=True, eq=False)
 class LazyOp: