beam capture and replay in fuzz (#7099)

* beam capture and reply in fuzz

* clean a bit
This commit is contained in:
nimlgen 2024-10-16 20:26:58 +03:00 committed by GitHub
parent eac58eaaba
commit 39ab67e9ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 28 additions and 7 deletions

View File

@ -140,7 +140,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No
return ("PASS", rawbufs, var_vals, ground_truth, run_state)
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None):
SEED = getenv("SEED", 42)
random.seed(SEED)
np.random.seed(SEED)
@ -162,10 +162,18 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
print("skipping simple kernel")
return failures
for depth in range(getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)):
test_depth = 1 if opts_list is not None else getenv("DEPTH", 1 if FUZZ_ALL_ACTIONS else 10)
for depth in range(test_depth):
next_lins = []
for lin in last_lins:
actions = get_kernel_actions(lin, include_0=False)
if opts_list is None: actions = get_kernel_actions(lin, include_0=False)
else:
actions = {}
for oi,opts in enumerate(opts_list):
lin2 = lin.copy()
for o in opts: lin2.apply_opt(o)
actions[oi] = lin2
if not actions: continue
if depth == 0 and getenv("FUZZ_REQUIRE_TC", 0):
tc_acts = {i: k for k in actions.values() if k.applied_opts[0].op == OptOps.TC}
@ -174,7 +182,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2):
test_lins = list(actions.values())
if FUZZ_ALL_ACTIONS: print(f"testing {lin.applied_opts=} with {len(actions)} actions")
else: test_lins = [random.choice(test_lins)]
elif opts_list is None: test_lins = [random.choice(test_lins)]
for test_lin in test_lins:
if not FUZZ_ALL_ACTIONS and test_lin.applied_opts: print(f"applied opts: {test_lin.applied_opts}")
@ -230,12 +238,14 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a fuzz testing on one or more kernels", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--ast", type=str, default=None, help="the ast for the kernel to be optimized")
parser.add_argument("--file", type=str, default=None, help="a file containing asts to be optimized, one per line")
parser.add_argument("--beamreplay", type=str, default=None, help="replay asts and opts got from beam with CAPTURE_BEAM")
parser.add_argument("--logfile", type=str, default=None, help="a file containing a tuple of ast and applied_opts, one per line")
parser.add_argument("--expected-failures", type=int, default=0, help="the number of expected failed kernels")
parser.add_argument("--rtol", type=float, default=1e-2, help="relative tolerance for numerical comparison")
parser.add_argument("--atol", type=float, default=1e-2, help="absolute tolerance for numerical comparison")
args = parser.parse_args()
opts_list = None
if args.ast is not None:
print("loaded AST from CLI")
ast_strs = [args.ast]
@ -243,6 +253,16 @@ if __name__ == "__main__":
print(f"loading ASTs from file '{args.file}'")
with open(args.file, 'r') as file:
ast_strs = file.readlines()
elif args.beamreplay is not None:
print(f"loading BEAM replay from file '{args.beamreplay}'")
with open(args.beamreplay, 'r') as file: fdata = file.readlines()
ast_strs, opts_list = [x.split(' :: ')[0] for x in fdata], [x.split(' :: ')[1] for x in fdata]
# dedup ast_strs and opts_list
dct = defaultdict(list)
for i in range(len(ast_strs)): dct[ast_strs[i]].append(eval(opts_list[i]))
ast_strs_items = list(dct.keys())
opts_list = [dct[c] for c in ast_strs_items]
elif args.logfile is not None:
print(f"loading ASTs from LOGKERNS file '{args.file}'")
with open(args.logfile, 'r') as file:
@ -273,7 +293,7 @@ if __name__ == "__main__":
with Timing(f"tested ast {i}: "):
tested += 1
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol, opts_list=(opts_list[i] if opts_list else None))
if fuzz_failures: failed_ids.append(i)
for k, v in fuzz_failures.items():
for f in v:

View File

@ -117,7 +117,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> Dict[int, Kernel]:
except KernelOptError: pass
return acted_lins
beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
beam_pool, BEAM_DEBUG, CAPTURE_BEAM = None, getenv("BEAM_DEBUG"), getenv("CAPTURE_BEAM", "")
def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True, disable_cache=getenv("IGNORE_BEAM_CACHE")) -> Kernel:
global beam_pool
key = {"ast": lin.ast.key, "amt": amt, "allow_test_size": allow_test_size, "device": lin.opts.device, "suffix": lin.opts.suffix}
@ -154,7 +154,8 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
# filter out kernels that use 1000x more compute than the smallest
least_compute_ops = min(this_compute_ops:=sym_infer(p.op_estimate, var_vals), least_compute_ops)
if least_compute_ops*1000 < this_compute_ops: continue
#print(acted_lins[i].colored_shape(), acted_lins[i].applied_opts) # for debugging BEAMs that segfault
if len(CAPTURE_BEAM) > 0:
with open(CAPTURE_BEAM, 'a') as f: f.write(str(acted_lins[i].ast).replace('\n','')+f" :: {acted_lins[i].applied_opts}\n")
seen_libs.add(lib)
try: tms = _time_program(p, lib, var_vals, rawbufs, early_stop=beam[0][1]*3 if len(beam) else 1.0, clear_l2=hasattr(dev, 'invalidate_caches'))
except RuntimeError: continue # for runtime issues