mirror of https://github.com/commaai/tinygrad.git
update fuzz_linearizer (#3648)
included non-reduce kernel and kernel with variables. green msg when everything passed it's possible that creating rawbufs failed due to memory error, included that in failure cases
This commit is contained in:
parent
b282a45e39
commit
57df8e8d82
|
@ -7,7 +7,7 @@ from tinygrad.codegen.linearizer import Linearizer
|
|||
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.features.graph import print_tree
|
||||
from tinygrad.helpers import getenv, from_mv, prod, Context
|
||||
from tinygrad.helpers import getenv, from_mv, prod, colored, Context
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.codegen.linearizer import UOp
|
||||
|
||||
|
@ -70,7 +70,6 @@ def fuzz_linearizer(lin: Linearizer):
|
|||
np.random.seed(42)
|
||||
print_tree(lin.ast)
|
||||
print(lin.colored_shape())
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
seen_uops = {}
|
||||
last_lins = [lin]
|
||||
failures = defaultdict(list)
|
||||
|
@ -84,6 +83,15 @@ def fuzz_linearizer(lin: Linearizer):
|
|||
# get baseline unoptimized output
|
||||
unoptimized = Linearizer(lin.ast)
|
||||
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
|
||||
|
||||
try:
|
||||
rawbufs = get_fuzz_rawbufs(lin)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
print("RAWBUFS FAILED!!")
|
||||
failures["RAWBUFS_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
|
||||
if run_linearizer(unoptimized, rawbufs, var_vals) != "PASS":
|
||||
failures["BASELINE_ERROR"].append((unoptimized.ast, unoptimized.applied_opts))
|
||||
return failures
|
||||
|
@ -134,7 +142,7 @@ def fuzz_linearizer(lin: Linearizer):
|
|||
return failures
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds()
|
||||
ast_strs = load_worlds(filter_reduce=False, filter_novariable=False)
|
||||
print(f"{len(ast_strs)=}")
|
||||
tested = 0
|
||||
failures = defaultdict(list)
|
||||
|
@ -151,5 +159,8 @@ if __name__ == "__main__":
|
|||
print(f"{msg} {i} AST: {ast}")
|
||||
print(f"{msg} {i} OPTS: {opts}\n")
|
||||
print(f"{tested=}")
|
||||
for msg, errors in failures.items():
|
||||
print(f"{msg}: {len(errors)}")
|
||||
if failures:
|
||||
for msg, errors in failures.items():
|
||||
print(f"{msg}: {len(errors)}")
|
||||
else:
|
||||
print(colored("all passed", "green"))
|
||||
|
|
Loading…
Reference in New Issue