mirror of https://github.com/commaai/tinygrad.git
update CI tests in extra with UOp AST (#6290)
This commit is contained in:
parent
3517aa89d9
commit
94a72d44d2
|
@ -11,7 +11,7 @@ inf, nan = float('inf'), float('nan')
|
||||||
|
|
||||||
# kernel unpacker
|
# kernel unpacker
|
||||||
from tinygrad.codegen.kernel import Kernel
|
from tinygrad.codegen.kernel import Kernel
|
||||||
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val
|
def ast_str_to_ast(ast_str:str) -> UOp: return eval(ast_str)
|
||||||
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
|
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
|
||||||
def kern_str_to_lin(kern_str:str, opts=None):
|
def kern_str_to_lin(kern_str:str, opts=None):
|
||||||
(ast, applied_opts,) = eval(kern_str)
|
(ast, applied_opts,) = eval(kern_str)
|
||||||
|
@ -28,7 +28,7 @@ from tinygrad.helpers import dedup
|
||||||
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
|
def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True):
|
||||||
fn = Path(__file__).parent.parent / "datasets/sops.gz"
|
fn = Path(__file__).parent.parent / "datasets/sops.gz"
|
||||||
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
|
ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n"))
|
||||||
if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x]
|
if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x]
|
||||||
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
|
if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x]
|
||||||
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
|
if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x]
|
||||||
random.seed(1337)
|
random.seed(1337)
|
||||||
|
|
|
@ -3,9 +3,8 @@ from enum import Enum, auto
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Tuple, DefaultDict
|
from typing import List, Tuple, DefaultDict
|
||||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||||
from extra.ops import LazyOp
|
|
||||||
from tinygrad.helpers import prod, tqdm
|
from tinygrad.helpers import prod, tqdm
|
||||||
from tinygrad.ops import UOps
|
from tinygrad.ops import UOp, UOps
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.symbolic import sym_infer, Node
|
from tinygrad.shape.symbolic import sym_infer, Node
|
||||||
|
|
||||||
|
@ -136,7 +135,7 @@ def test_rebuild(st: ShapeTracker):
|
||||||
last_v2 = rebuilt_st.views[-1]
|
last_v2 = rebuilt_st.views[-1]
|
||||||
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
|
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
|
||||||
|
|
||||||
def test_rebuild_bufferop_st(ast:LazyOp):
|
def test_rebuild_bufferop_st(ast:UOp):
|
||||||
if ast.op is UOps.SHAPETRACKER:
|
if ast.op is UOps.SHAPETRACKER:
|
||||||
test_rebuild(ast.arg)
|
test_rebuild(ast.arg)
|
||||||
for src in ast.src: test_rebuild_bufferop_st(src)
|
for src in ast.src: test_rebuild_bufferop_st(src)
|
||||||
|
|
Loading…
Reference in New Issue