fixup non-void SINKs in tests [run_process_replay] (#6624)

This commit is contained in:
qazal 2024-09-21 13:29:18 +08:00 committed by GitHub
parent 391d14438e
commit d2351af019
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 3 deletions

View File

@ -13,7 +13,7 @@ from tinygrad.shape.view import View
class InvalidASTException(Exception): pass
def helper_test_verify_ast(*stores:UOp) -> Kernel:
sink = UOp(UOps.SINK, None, stores)
sink = UOp(UOps.SINK, dtypes.void, stores)
if DEBUG >= 3:
for op in stores: print(op)
try: verify_ast(sink)
@ -50,7 +50,7 @@ class TestVerifyAST(unittest.TestCase):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
st = UOp(UOps.STORE, None, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_shrink_ok(self):

View File

@ -636,7 +636,7 @@ spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.b
# NOTE: for testing, we let sinks be anything
#(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
(UPat(UOps.SINK), lambda: True),
(UPat(UOps.SINK, dtypes.void), lambda: True),
# PTX LOAD/STORE
(UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),