From 9ed2b8b818dd04f5a78085aeb4544b62c4175e87 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 6 Sep 2024 05:32:12 -0400 Subject: [PATCH] fix DEFINE_VAR setup in test_uop_graph [run_process_replay] (#6392) making sure arg always have 3 items --- test/test_uop_graph.py | 26 ++++++++++++++------------ tinygrad/ops.py | 3 +-- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 263465a4..4b25b8d6 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -268,19 +268,21 @@ class TestUOpGraph(unittest.TestCase): for i in [4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + - tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2))) - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + tuple(UOp(UOps.DEFINE_VAR, dtypes.half, + arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2))) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) for i in [4, 8]: - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + - tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + tuple(UOp(UOps.DEFINE_VAR, dtypes.half, + arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2))) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) @@ -288,17 +290,17 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) for i in [2, 4, 8]: - var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),)) + var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) - acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),)) + acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[-1], wmma) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 56e9f713..2b3921a7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -414,8 +414,7 @@ class UOp(MathTrait): @functools.cached_property def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: # NOTE: returned UOp is assumed to be CONST - # TODO: fix DEFINE_VAR arg in tests and remove checking len(self.arg) - if self.op is UOps.DEFINE_VAR and self.arg and len(self.arg) > 1: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None + if self.op is UOps.DEFINE_VAR and self.arg: return self.arg[1], self.arg[2] if isinstance(self.arg[2].arg, int) else None if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.SPECIAL: return self.const_like(0), self.const_like(self.arg[1]-1) if isinstance(self.arg[1], int) else None