type verify intermediate UOps [run_process_replay] (#6140)

* type verify intermediate UOps [run_process_replay]

* merge asserts

* variable const
This commit is contained in:
qazal 2024-08-20 01:59:01 +08:00 committed by GitHub
parent 478145cb8e
commit 2242ff84be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 9 additions and 4 deletions

View File

@ -63,7 +63,7 @@ class TestLinearizerFailures(unittest.TestCase):
helper_test_lin(Kernel(ast), opts, failed_platforms=[])
def test_failure_6(self):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=10.0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=10, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),))))
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
# COMPILE FAILED, KeyError: UOps.CONST
helper_test_lin(Kernel(ast), opts, failed_platforms=[])

View File

@ -4,7 +4,7 @@ from dataclasses import dataclass, replace
from collections import defaultdict
from typing import Literal, Optional, List, Tuple, Union, cast, Dict, Final, DefaultDict
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops
from tinygrad.ops import BinaryOps, ReduceOps, UNSAFE_PAD_OPS, KernelInfo, BUFFER_UOPS, UOp, UOps, print_uops, type_verify
from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import DType, ImageDType, PtrDType
@ -768,7 +768,8 @@ class Kernel:
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
# the living definition of UOp st_arg
# the living definition of intermediate UOps
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if uop in sts: return
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
@ -799,4 +800,5 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
type_verify(list(sts))
return sts

View File

@ -312,10 +312,13 @@ def type_verify(uops):
if uop is UOps.DEFINE_LOCAL: assert isinstance(dtype, PtrDType), f"invalid dtype for local buffer {dtype}"
if uop is UOps.DEFINE_GLOBAL: assert isinstance(dtype, (PtrDType, ImageDType)), f"invalid dtype for global buffer {dtype}"
if isinstance(dtype, ImageDType): assert uop is UOps.DEFINE_GLOBAL, f"{uop} can't be image"
if uop is UOps.SHAPETRACKER: assert len(src) == 0, f"SHAPETRACKER must only define a ShapeTracker arg {uop}"
if uop is UOps.REDUCE_AXIS: assert isinstance(arg, tuple) and len(arg) == 2 and arg[0] in ReduceOps, f"invalid arg for REDUCE_AXIS {arg}"
if uop in {UOps.CONST, UOps.DEFINE_ACC}:
if uop is UOps.CONST:
assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
assert type(arg) is type(dtypes.as_const(arg, dtype)), f"type of {arg=} does not match {dtype}"
# TODO: intermediate CONST of Variable is DEFINE_VAR
assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
if uop is UOps.DEFINE_ACC: assert dtype is not None and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype is not None # type is the output type, not an arg
if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1