only axis is masked [run_process_replay] (#6123)

This commit is contained in:
George Hotz 2024-08-16 21:01:17 -07:00 committed by GitHub
parent 94aa5f11b5
commit d9cb45af09
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 16 additions and 7 deletions

View File

@ -108,13 +108,13 @@ class TestRealDoesntSimplify(unittest.TestCase):
self.st = ShapeTracker((
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
View.create((8, 6, 11), (66, 11, 1), 0, None)))
assert self.st.real_strides() == (33, None, 1)
self.assertEqual(self.st.real_strides(), (33, None, 1))
def test_2(self):
self.st = ShapeTracker((
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
assert self.st.real_strides() == (None, 18, -3, -1)
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1))
class TestRealStrides(unittest.TestCase):
def test_1(self):
@ -131,7 +131,7 @@ class TestRealSimplifies(unittest.TestCase):
self.st = self.st.simplify()
assert len(self.st.views) == 1
print(self.st.views[-1].strides, st)
assert self.st.views[-1].strides == st
self.assertEqual(self.st.views[-1].strides, st)
def test_1(self):
self.st = ShapeTracker((
@ -733,6 +733,13 @@ class TestShapeTracker(unittest.TestCase):
self.test_expand()
self.test_permute()
def test_axis_is_masked(self):
st = ShapeTracker.from_shape((100, 100, 100, 100)).pad(((0,1),(0,0),(2,0), (0,0)))
assert st.axis_is_masked(0)
assert not st.axis_is_masked(1)
assert st.axis_is_masked(2)
assert not st.axis_is_masked(3)
class TestShapeTrackerSize(unittest.TestCase):
def test_simple_size(self):
st = ShapeTracker.from_shape((100, 100))

View File

@ -193,7 +193,7 @@ class UOp:
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
# NOTE: returned UOp is assumed to be CONST
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
if self.op is UOps.CONST: return self, self

View File

@ -66,7 +66,9 @@ class ShapeTracker:
def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
def to_indexed_uops(self, idxs:List[UOp]) -> Tuple[UOp, UOp]:
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
if _idxs is None else _idxs
idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
for view in reversed(self.views[0:-1]):
view = view.minify()
@ -132,8 +134,8 @@ class ShapeTracker:
return idx, valid
def axis_is_masked(self, axis:int) -> bool:
_, valid = self.expr_idxs()
return f'idx{axis}' in [v.expr for v in valid.vars()]
_, valid = self.to_indexed_uops()
return axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE]
def simplify(self) -> ShapeTracker:
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None: