mirror of https://github.com/commaai/tinygrad.git
rename prescheduled items to lsi [run_process_replay] (#5959)
* rename to lsi
* fuzz_schedule more typings
* rename fuzz_schedule
parent 728b7e189e
commit 39dda3d042
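
For context: the prescheduled values being renamed here are LBScheduleItem instances (imported from tinygrad.engine.schedule in the first hunk below), and the old throwaway name ps becomes lsi wherever one is unpacked. A rough sketch of the shape this diff relies on, with field names taken only from the usages in the hunks (ast, outputs, inputs, var_vals, metadata); this is a hypothetical stand-in, not tinygrad's actual class definition:

from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class LBScheduleItemSketch:
  # hypothetical stand-in for tinygrad's LBScheduleItem, fields inferred from the diff
  ast: Any                                                  # kernel AST handed to ScheduleItem
  outputs: List[Any]                                        # LazyBuffers this item writes
  inputs: List[Any]                                         # LazyBuffers this item reads
  var_vals: Dict[Any, int] = field(default_factory=dict)    # bound symbolic variable values
  metadata: Tuple[Any, ...] = ()                            # extra info forwarded to ScheduleItem
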
@@ -5,7 +5,7 @@ from tinygrad.device import Buffer
 from tinygrad.engine.realize import CustomOp, capturing, lower_schedule_item
 from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv
 from tinygrad.lazy import LazyBuffer
-from tinygrad.engine.schedule import _graph_schedule, ScheduleItem
+from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem
 from tinygrad.ops import MetaOps
 from tinygrad.tensor import Tensor, _to_np_dtype
 
@@ -14,11 +14,12 @@ FUZZ_SCHEDULE_MAX_PATHS = getenv("FUZZ_SCHEDULE_MAX_PATHS", 10)
 
 def fuzz_schedule(outs:List[LazyBuffer]):
   # find toposorts across all tunable params
-  unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict, Dict[LazyBuffer, Tuple]]] = {}
+  unique_ts: Dict[Tuple[LazyBuffer, ...], Tuple[Dict[str, int], Dict[LazyBuffer, LBScheduleItem]]] = {}
   for combination in itertools.product(*ctx_vars.values()):
     for var, val in zip(ctx_vars, combination): var.value = val
+    ctx_var_values = dict(zip([v.key for v in ctx_vars], combination))
     graph, in_degree, prescheduled = _graph_schedule(outs, set())
-    for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = (dict(zip([v.key for v in ctx_vars], combination)), prescheduled)
+    for ts in find_all_toposorts(graph, in_degree): unique_ts[ts] = (ctx_var_values, prescheduled)
   toposorts = list(unique_ts.items())
   if DEBUG >= 1: print(colored(f"fuzzing {len(toposorts)} schedule permutations", "yellow"))
 
@@ -30,16 +31,16 @@ def fuzz_schedule(outs:List[LazyBuffer]):
   seed = Tensor._seed
   ts, (_, prescheduled) = toposorts[0]
   for key in ts:
-    for out in (ps:=prescheduled[key]).outputs:
+    for out in (lsi:=prescheduled[key]).outputs:
       # freeze assign state before exec
       if out.op is MetaOps.ASSIGN:
         prerealized[out] = out.buffer.as_buffer()
         assign_targets[out.srcs[1]] = out
-    for x in ps.inputs:
+    for x in lsi.inputs:
       if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
-    si = ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0))
+    si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0))
     _exec_si(si, seed)
-    for out in ps.outputs:
+    for out in lsi.outputs:
       ground_truth[out] = out.buffer.as_buffer()
       del out.srcs # only schedule the LazyBuffer in this fuzz run
 
@@ -48,19 +49,19 @@ def fuzz_schedule(outs:List[LazyBuffer]):
     if DEBUG >= 1: print(colored(f"testing permutation {i} {ctx}", "yellow"))
     rawbufs: Dict[LazyBuffer, Buffer] = {}
     for key in ts:
-      for out in (ps:=prescheduled[key]).outputs:
+      for out in (lsi:=prescheduled[key]).outputs:
         rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype)
         if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
-      for x in ps.inputs:
+      for x in lsi.inputs:
         if x not in rawbufs:
           # override the assign_target after ASSIGN
           if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
          elif x.device == "NPY": rawbufs[x] = x.buffer
           # copy the pre realized input
           else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
-      si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in ps.outputs+ps.inputs if x.size != 0))
+      si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.outputs+lsi.inputs if x.size != 0))
       _exec_si(si, seed)
-      for out in ps.outputs:
+      for out in lsi.outputs:
         outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
         try: np.testing.assert_allclose(outbuf, np.frombuffer(ground_truth[out], _to_np_dtype(out.dtype)), atol=1e-2, rtol=1e-2)
         except Exception as e:
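
The hunks above appear to be the schedule fuzzer (test/external/fuzz_schedule.py in the tinygrad tree); the permutations it replays come from find_all_toposorts(graph, in_degree), defined elsewhere in that file. A minimal sketch of that kind of enumerator, assuming graph maps each node to its dependents, in_degree has an entry for every node, and the path count is capped the way FUZZ_SCHEDULE_MAX_PATHS caps it; not tinygrad's implementation:

from typing import Dict, List, Tuple, TypeVar

T = TypeVar("T")

def find_all_toposorts_sketch(graph:Dict[T, List[T]], in_degree:Dict[T, int], max_paths:int=10) -> List[Tuple[T, ...]]:
  # enumerate topological orderings of a DAG by backtracking over Kahn's algorithm
  degree, order, found = dict(in_degree), [], []
  def walk() -> None:
    if len(found) >= max_paths: return
    if len(order) == len(degree):
      found.append(tuple(order))
      return
    for n in [m for m in degree if degree[m] == 0 and m not in order]:
      order.append(n)
      for child in graph.get(n, []): degree[child] -= 1   # "schedule" n, unlocking its dependents
      walk()
      for child in graph.get(n, []): degree[child] += 1   # backtrack
      order.pop()
  walk()
  return found

# a diamond (a feeds b and c, both feed d) has exactly two valid schedules:
print(find_all_toposorts_sketch({"a": ["b", "c"], "b": ["d"], "c": ["d"]}, {"a": 0, "b": 1, "c": 1, "d": 2}))
# -> [('a', 'b', 'c', 'd'), ('a', 'c', 'b', 'd')]
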
@@ -346,7 +346,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> \
 
   # preschedule all buffers in realizes
   prescheduled = {group[0]:_lower_lazybuffer(group, realizes) for group in output_groups.values()}
-  schedule_targets = {out:ps for ps in prescheduled.values() for out in ps.outputs}
+  schedule_targets = {out:lsi for lsi in prescheduled.values() for out in lsi.outputs}
 
   graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
   in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
@@ -382,21 +382,21 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
     with Context(FUSE_ARANGE=0, SAVE_SCHEDULE=1): _graph_schedule(outs, set())
     with Context(FUSE_ARANGE=1, SAVE_SCHEDULE=1): graph, in_degree, prescheduled = _graph_schedule(outs, seen)
   else: graph, in_degree, prescheduled = _graph_schedule(outs, seen)
-  queue = deque(si for key, si in prescheduled.items() if in_degree[key] == 0)
+  queue = deque(lsi for key, lsi in prescheduled.items() if in_degree[key] == 0)
   schedule: List[ScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   kernel_number = GlobalCounters.kernel_count
   while queue:
-    ps = queue.popleft()
-    for buf in ps.outputs: seen.add(buf)
+    lsi = queue.popleft()
+    for buf in lsi.outputs: seen.add(buf)
     if GRAPH:
       kernel_number += 1
-      for out in ps.outputs: realized_lazybuffer(out, kernel_number)
-    var_vals = merge_dicts([var_vals, ps.var_vals])
-    for out in ps.outputs: del out.srcs # can only schedule once
-    schedule.append(si:=ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs+ps.inputs if x.size != 0), ps.metadata))
+      for out in lsi.outputs: realized_lazybuffer(out, kernel_number)
+    var_vals = merge_dicts([var_vals, lsi.var_vals])
+    for out in lsi.outputs: del out.srcs # can only schedule once
+    schedule.append(si:=ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata))
     if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
-    for x in graph[ps.outputs[0]]:
+    for x in graph[lsi.outputs[0]]:
       in_degree[x] -= 1
       if in_degree[x] == 0: queue.append(prescheduled[x])
 
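
The loop above (in what looks like tinygrad/engine/schedule.py) is a plain Kahn's-algorithm linearization; the rename only touches the loop variable. A stripped-down sketch of the same control flow, dropping the GRAPH/var_vals/logops bookkeeping and using a hypothetical MiniLSI in place of LBScheduleItem:

from collections import defaultdict, deque
from dataclasses import dataclass
from typing import DefaultDict, Dict, List, Tuple

@dataclass
class MiniLSI:                       # hypothetical stand-in, just enough for the loop
  name: str
  outputs: Tuple[str, ...]           # keys of whatever this item produces

def linearize(prescheduled:Dict[str, MiniLSI], graph:DefaultDict[str, List[str]],
              in_degree:DefaultDict[str, int]) -> List[MiniLSI]:
  # Kahn's algorithm: start from items with no pending dependencies, emit them,
  # then decrement the in-degree of their dependents (the real loop builds a
  # ScheduleItem and merges var_vals at the emit step)
  queue = deque(lsi for key, lsi in prescheduled.items() if in_degree[key] == 0)
  schedule: List[MiniLSI] = []
  while queue:
    lsi = queue.popleft()
    schedule.append(lsi)
    for x in graph[lsi.outputs[0]]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append(prescheduled[x])
  return schedule

# two items where "b" depends on "a" -> schedule is [a, b]
a, b = MiniLSI("a", ("a",)), MiniLSI("b", ("b",))
print([x.name for x in linearize({"a": a, "b": b}, defaultdict(list, {"a": ["b"]}), defaultdict(int, {"a": 0, "b": 1}))])  # ['a', 'b']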