mirror of https://github.com/commaai/tinygrad.git
type verify intermediate UOps [run_process_replay] (#6140)
* type verify intermediate UOps [run_process_replay] * merge asserts * variable const
This commit is contained in:
parent
478145cb8e
commit
2242ff84be
|
@ -63,7 +63,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
|||
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
|
||||
|
||||
def test_failure_6(self):
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=10.0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
|
||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
|
||||
# COMPILE FAILED, KeyError: UOps.CONST
|
||||
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
|
||||
|
|
|
@ -4,7 +4,7 @@ from dataclasses import dataclass, replace
|
|||
from collections import defaultdict
|
||||
from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
|
||||
|
||||
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops
|
||||
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||
from tinygrad.dtype import DType, ImageDType, PtrDType
|
||||
|
@ -768,7 +768,8 @@ class Kernel:
|
|||
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
|
||||
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
|
||||
|
||||
# the living definition of UOp st_arg
|
||||
# the living definition of intermediate UOps
|
||||
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
|
||||
if uop in sts: return
|
||||
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
|
||||
|
@ -799,4 +800,5 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
|
|||
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
|
||||
type_verify(list(sts))
|
||||
return sts
|
||||
|
|
|
@ -312,10 +312,13 @@ def type_verify(uops):
|
|||
if uop is UOps.DEFINE_LOCAL: assert isinstance(dtype, PtrDType), f"invalid dtype for local buffer {dtype}"
|
||||
if uop is UOps.DEFINE_GLOBAL: assert isinstance(dtype, (PtrDType, ImageDType)), f"invalid dtype for global buffer {dtype}"
|
||||
if isinstance(dtype, ImageDType): assert uop is UOps.DEFINE_GLOBAL, f"{uop} can't be image"
|
||||
if uop is UOps.SHAPETRACKER: assert len(src) == 0, f"SHAPETRACKER must only define a ShapeTracker arg {uop}"
|
||||
if uop is UOps.REDUCE_AXIS: assert isinstance(arg, tuple) and len(arg) == 2 and arg[0] in ReduceOps, f"invalid arg for REDUCE_AXIS {arg}"
|
||||
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
|
||||
if uop is UOps.CONST:
|
||||
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
|
||||
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
|
||||
# TODO: intermediate CONST of Variable is DEFINE_VAR
|
||||
assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
|
||||
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
|
||||
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
|
||||
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
|
||||
|
|
Loading…
Reference in New Issue