update CI tests in extra with UOp AST (#6290)

This commit is contained in:
gswangg 2024-08-28 12:26:50 -07:00 committed by GitHub
parent 3517aa89d9
commit 94a72d44d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 4 additions and 5 deletions

View File

@ -11,7 +11,7 @@ inf, nan = float('inf'), float('nan')
# kernel unpacker # kernel unpacker
from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val def ast_str_to_ast(ast_str:str) -> UOp: return eval(ast_str)
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts) def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
def kern_str_to_lin(kern_str:str, opts=None): def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str) (ast, applied_opts,) = eval(kern_str)
@ -28,7 +28,7 @@ from tinygrad.helpers import dedup
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True): def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
fn = Path(__file__).parent.parent / "datasets/sops.gz" fn = Path(__file__).parent.parent / "datasets/sops.gz"
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n")) ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x] if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x] if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x] if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
random.seed(1337) random.seed(1337)

View File

@ -3,9 +3,8 @@ from enum import Enum, auto
from collections import defaultdict from collections import defaultdict
from typing import List, Tuple, DefaultDict from typing import List, Tuple, DefaultDict
from extra.optimization.helpers import load_worlds, ast_str_to_ast from extra.optimization.helpers import load_worlds, ast_str_to_ast
from extra.ops import LazyOp
from tinygrad.helpers import prod, tqdm from tinygrad.helpers import prod, tqdm
from tinygrad.ops import UOps from tinygrad.ops import UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import sym_infer, Node from tinygrad.shape.symbolic import sym_infer, Node
@ -136,7 +135,7 @@ def test_rebuild(st: ShapeTracker):
last_v2 = rebuilt_st.views[-1] last_v2 = rebuilt_st.views[-1]
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
def test_rebuild_bufferop_st(ast:LazyOp): def test_rebuild_bufferop_st(ast:UOp):
if ast.op is UOps.SHAPETRACKER: if ast.op is UOps.SHAPETRACKER:
test_rebuild(ast.arg) test_rebuild(ast.arg)
for src in ast.src: test_rebuild_bufferop_st(src) for src in ast.src: test_rebuild_bufferop_st(src)