add timing to fuzz_linearizer (#7056)

and applied smaller FUZZ_MAX_SIZE. this is getting quite slow in CI
This commit is contained in:
chenyu 2024-10-14 11:57:41 -04:00 committed by GitHub
parent 0d2462cbdf
commit fbaab30fe3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 9 additions and 10 deletions

View File

@ -345,7 +345,7 @@ jobs:
- name: Test Beam Search
run: PYTHONPATH="." METAL=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
- name: Fuzz Test linearizer
run: PYTHONPATH="." METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=24 FUZZ_MAX_SIZE=10000000 python test/external/fuzz_linearizer.py
run: PYTHONPATH="." METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 FUZZ_N=24 FUZZ_MAX_SIZE=1000000 python test/external/fuzz_linearizer.py
- name: Fuzz Test models schedule
run: FUZZ_SCHEDULE=1 FUZZ_SCHEDULE_MAX_PATHS=5 python -m pytest test/models/test_train.py test/models/test_end2end.py
- name: Run TRANSCENDENTAL math

View File

@ -10,7 +10,7 @@ from tinygrad.codegen.kernel import Kernel
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.engine.search import get_kernel_actions, bufs_from_lin
from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG
from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing
from tinygrad.ops import UnaryOps, UOp, UOps
from test.helpers import is_dtype_supported
@ -231,14 +231,13 @@ if __name__ == "__main__":
print("skipping kernel due to not supported dtype")
continue
print(f"testing ast {i}")
tested += 1
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
if fuzz_failures: failed_ids.append(i)
for k, v in fuzz_failures.items():
for f in v:
failures[k].append(f)
with Timing(f"tested ast {i}: "):
tested += 1
fuzz_failures = fuzz_linearizer(lin, rtol=args.rtol, atol=args.atol)
if fuzz_failures: failed_ids.append(i)
for k, v in fuzz_failures.items():
for f in v:
failures[k].append(f)
for msg, errors in failures.items():
for i, (ast, opts) in enumerate(errors):