mirror of https://github.com/commaai/tinygrad.git
beam capture and replay in fuzz (#7099)
* beam capture and reply in fuzz * clean a bit
This commit is contained in:
parent
eac58eaaba
commit
39ab67e9ef
|
@ -140,7 +140,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No
|
|||
|
||||
return ("PASS", rawbufs, var_vals, ground_truth, run_state)
|
||||
|
||||
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
|
||||
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
|
||||
SEED = getenv("SEED", 42)
|
||||
random.seed(SEED)
|
||||
np.random.seed(SEED)
|
||||
|
@ -162,10 +162,18 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
|
|||
print("skipping simple kernel")
|
||||
return failures
|
||||
|
||||
for depth in range(getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)):
|
||||
test_depth = 1 if opts_list is not None else getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)
|
||||
for depth in range(test_depth):
|
||||
next_lins = []
|
||||
for lin in last_lins:
|
||||
actions = get_kernel_actions(lin, include_0=False)
|
||||
if opts_list is None: actions = get_kernel_actions(lin, include_0=False)
|
||||
else:
|
||||
actions = {}
|
||||
for oi,opts in enumerate(opts_list):
|
||||
lin2 = lin.copy()
|
||||
for o in opts: lin2.apply_opt(o)
|
||||
actions[oi] = lin2
|
||||
|
||||
if not actions: continue
|
||||
if depth == 0 and getenv("FUZZ_REQUIRE_TC", 0):
|
||||
tc_acts = {i: k for k in actions.values() if k.applied_opts[0].op == OptOps.TC}
|
||||
|
@ -174,7 +182,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
|
|||
|
||||
test_lins = list(actions.values())
|
||||
if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
|
||||
else: test_lins = [random.choice(test_lins)]
|
||||
elif opts_list is None: test_lins = [random.choice(test_lins)]
|
||||
|
||||
for test_lin in test_lins:
|
||||
if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
|
||||
|
@ -230,12 +238,14 @@ if __name__ == "__main__":
|
|||
parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
||||
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
|
||||
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
|
||||
parser.add_argument("--beamreplay", type=str, default=None, help="replay asts and opts got from beam with CAPTURE_BEAM")
|
||||
parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
|
||||
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
|
||||
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
|
||||
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
|
||||
args = parser.parse_args()
|
||||
|
||||
opts_list = None
|
||||
if args.ast is not None:
|
||||
print("loaded AST from CLI")
|
||||
ast_strs = [args.ast]
|
||||
|
@ -243,6 +253,16 @@ if __name__ == "__main__":
|
|||
print(f"loading ASTs from file '{args.file}'")
|
||||
with open(args.file, 'r') as file:
|
||||
ast_strs = file.readlines()
|
||||
elif args.beamreplay is not None:
|
||||
print(f"loading BEAM replay from file '{args.beamreplay}'")
|
||||
with open(args.beamreplay, 'r') as file: fdata = file.readlines()
|
||||
ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata]
|
||||
|
||||
# dedup ast_strs and opts_list
|
||||
dct = defaultdict(list)
|
||||
for i in range(len(ast_strs)): dct[ast_strs[i]].append(eval(opts_list[i]))
|
||||
ast_strs_items = list(dct.keys())
|
||||
opts_list = [dct[c] for c in ast_strs_items]
|
||||
elif args.logfile is not None:
|
||||
print(f"loading ASTs from LOGKERNS file '{args.file}'")
|
||||
with open(args.logfile, 'r') as file:
|
||||
|
@ -273,7 +293,7 @@ if __name__ == "__main__":
|
|||
|
||||
with Timing(f"tested ast {i}: "):
|
||||
tested += 1
|
||||
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
|
||||
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol, opts_list=(opts_list[i] if opts_list else None))
|
||||
if fuzz_failures: failed_ids.append(i)
|
||||
for k, v in fuzz_failures.items():
|
||||
for f in v:
|
||||
|
|
|
@ -117,7 +117,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
|
|||
except KernelOptError: pass
|
||||
return acted_lins
|
||||
|
||||
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
|
||||
beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
|
||||
def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
|
||||
global beam_pool
|
||||
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
|
||||
|
@ -154,7 +154,8 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
|
|||
# filter out kernels that use 1000x more compute than the smallest
|
||||
least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
|
||||
if least_compute_ops*1000 < this_compute_ops: continue
|
||||
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
|
||||
if len(CAPTURE_BEAM) > 0:
|
||||
with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
|
||||
seen_libs.add(lib)
|
||||
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
|
||||
except RuntimeError: continue # for runtime issues
|
||||
|
|
Loading…
Reference in New Issue