From d9cb45af09afd8c233ec7b74014fabf9f0b74880 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 16 Aug 2024 21:01:17 -0700 Subject: [PATCH] only axis is masked [run_process_replay] (#6123) --- test/unit/test_shapetracker.py | 13 ++++++++++--- tinygrad/ops.py | 2 +- tinygrad/shape/shapetracker.py | 8 +++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index ed0b6ade..9775a3f1 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -108,13 +108,13 @@ class TestRealDoesntSimplify(unittest.TestCase): self.st = ShapeTracker(( View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None), View.create((8, 6, 11), (66, 11, 1), 0, None))) - assert self.st.real_strides() == (33, None, 1) + self.assertEqual(self.st.real_strides(), (33, None, 1)) def test_2(self): self.st = ShapeTracker(( View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None), View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None))) - assert self.st.real_strides() == (None, 18, -3, -1) + self.assertEqual(self.st.real_strides(), (None, 18, -3, -1)) class TestRealStrides(unittest.TestCase): def test_1(self): @@ -131,7 +131,7 @@ class TestRealSimplifies(unittest.TestCase): self.st = self.st.simplify() assert len(self.st.views) == 1 print(self.st.views[-1].strides, st) - assert self.st.views[-1].strides == st + self.assertEqual(self.st.views[-1].strides, st) def test_1(self): self.st = ShapeTracker(( @@ -733,6 +733,13 @@ class TestShapeTracker(unittest.TestCase): self.test_expand() self.test_permute() + def test_axis_is_masked(self): + st = ShapeTracker.from_shape((100, 100, 100, 100)).pad(((0,1),(0,0),(2,0), (0,0))) + assert st.axis_is_masked(0) + assert not st.axis_is_masked(1) + assert st.axis_is_masked(2) + assert not st.axis_is_masked(3) + class TestShapeTrackerSize(unittest.TestCase): def test_simple_size(self): st = ShapeTracker.from_shape((100, 100)) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 613b8673..6301ed6b 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -193,7 +193,7 @@ class UOp: def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]: # NOTE: returned UOp is assumed to be CONST if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None - if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None + if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax # TODO: UOps.SPECIAL is UOps.DEFINE_VAR if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None if self.op is UOps.CONST: return self, self diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index 2dc65a04..7e147173 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -66,7 +66,9 @@ class ShapeTracker: def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self) - def to_indexed_uops(self, idxs:List[UOp]) -> Tuple[UOp, UOp]: + def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]: + idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \ + if _idxs is None else _idxs idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True)) for view in reversed(self.views[0:-1]): view = view.minify() @@ -132,8 +134,8 @@ class ShapeTracker: return idx, valid def axis_is_masked(self, axis:int) -> bool: - _, valid = self.expr_idxs() - return f'idx{axis}' in [v.expr for v in valid.vars()] + _, valid = self.to_indexed_uops() + return axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE] def simplify(self) -> ShapeTracker: if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: