update CI tests in extra with UOp AST (#6290)

This commit is contained in:
gswangg 2024-08-28 12:26:50 -07:00 committed by GitHub
parent 3517aa89d9
commit 94a72d44d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 5 deletions

View File

@ -11,7 +11,7 @@ inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val
def ast_str_to_ast(ast_str:str) -> UOp: return eval(ast_str)
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)
@ -28,7 +28,7 @@ from tinygrad.helpers import dedup
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
fn = Path(__file__).parent.parent / "datasets/sops.gz"
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
random.seed(1337)

View File

@ -3,9 +3,8 @@ from enum import Enum, auto
from collections import defaultdict
from typing import List, Tuple, DefaultDict
from extra.optimization.helpers import load_worlds, ast_str_to_ast
from extra.ops import LazyOp
from tinygrad.helpers import prod, tqdm
from tinygrad.ops import UOps
from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sym_infer, Node
@ -136,7 +135,7 @@ def test_rebuild(st: ShapeTracker):
last_v2 = rebuilt_st.views[-1]
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
def test_rebuild_bufferop_st(ast:LazyOp):
def test_rebuild_bufferop_st(ast:UOp):
if ast.op is UOps.SHAPETRACKER:
test_rebuild(ast.arg)
for src in ast.src: test_rebuild_bufferop_st(src)