update fuzz_linearizer (#3648)

included non-reduce kernel and kernel with variables. green msg when everything passed
it's possible that creating rawbufs failed due to memory error, included that in failure cases
This commit is contained in:
chenyu 2024-03-07 18:41:22 -05:00 committed by GitHub
parent b282a45e39
commit 57df8e8d82
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 16 additions and 5 deletions

View File

@ -7,7 +7,7 @@ from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.tensor import Tensor
from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, prod, Context
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
from tinygrad.device import Device, Compiled
from tinygrad.codegen.linearizer import UOp
@ -70,7 +70,6 @@ def fuzz_linearizer(lin: Linearizer):
np.random.seed(42)
print_tree(lin.ast)
print(lin.colored_shape())
rawbufs = get_fuzz_rawbufs(lin)
seen_uops = {}
last_lins = [lin]
failures = defaultdict(list)
@ -84,6 +83,15 @@ def fuzz_linearizer(lin: Linearizer):
# get baseline unoptimized output
unoptimized = Linearizer(lin.ast)
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
try:
rawbufs = get_fuzz_rawbufs(lin)
except Exception:
traceback.print_exc()
print("RAWBUFS FAILED!!")
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
return failures
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
return failures
@ -134,7 +142,7 @@ def fuzz_linearizer(lin: Linearizer):
return failures
if __name__ == "__main__":
ast_strs = load_worlds()
ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)
print(f"{len(ast_strs)=}")
tested = 0
failures = defaultdict(list)
@ -151,5 +159,8 @@ if __name__ == "__main__":
print(f"{msg} {i} AST: {ast}")
print(f"{msg} {i} OPTS: {opts}\n")
print(f"{tested=}")
for msg, errors in failures.items():
print(f"{msg}: {len(errors)}")
if failures:
for msg, errors in failures.items():
print(f"{msg}: {len(errors)}")
else:
print(colored("all passed", "green"))