mirror of https://github.com/commaai/tinygrad.git
only axis is masked [run_process_replay] (#6123)
This commit is contained in:
parent
94aa5f11b5
commit
d9cb45af09
|
@ -108,13 +108,13 @@ class TestRealDoesntSimplify(unittest.TestCase):
|
|||
self.st = ShapeTracker((
|
||||
View.create((8, 3, 1, 2, 11, 1), (33, 11, 0, 0, 1, 0), 0, None),
|
||||
View.create((8, 6, 11), (66, 11, 1), 0, None)))
|
||||
assert self.st.real_strides() == (33, None, 1)
|
||||
self.assertEqual(self.st.real_strides(), (33, None, 1))
|
||||
|
||||
def test_2(self):
|
||||
self.st = ShapeTracker((
|
||||
View.create((2, 2, 4, 3, 3), (72, 9, 18, -3, -1), 8, None),
|
||||
View.create((4, 4, 3, 3), (36, 9, 3, 1), 0, None)))
|
||||
assert self.st.real_strides() == (None, 18, -3, -1)
|
||||
self.assertEqual(self.st.real_strides(), (None, 18, -3, -1))
|
||||
|
||||
class TestRealStrides(unittest.TestCase):
|
||||
def test_1(self):
|
||||
|
@ -131,7 +131,7 @@ class TestRealSimplifies(unittest.TestCase):
|
|||
self.st = self.st.simplify()
|
||||
assert len(self.st.views) == 1
|
||||
print(self.st.views[-1].strides, st)
|
||||
assert self.st.views[-1].strides == st
|
||||
self.assertEqual(self.st.views[-1].strides, st)
|
||||
|
||||
def test_1(self):
|
||||
self.st = ShapeTracker((
|
||||
|
@ -733,6 +733,13 @@ class TestShapeTracker(unittest.TestCase):
|
|||
self.test_expand()
|
||||
self.test_permute()
|
||||
|
||||
def test_axis_is_masked(self):
|
||||
st = ShapeTracker.from_shape((100, 100, 100, 100)).pad(((0,1),(0,0),(2,0), (0,0)))
|
||||
assert st.axis_is_masked(0)
|
||||
assert not st.axis_is_masked(1)
|
||||
assert st.axis_is_masked(2)
|
||||
assert not st.axis_is_masked(3)
|
||||
|
||||
class TestShapeTrackerSize(unittest.TestCase):
|
||||
def test_simple_size(self):
|
||||
st = ShapeTracker.from_shape((100, 100))
|
||||
|
|
|
@ -193,7 +193,7 @@ class UOp:
|
|||
def _min_max(self) -> Tuple[Optional[UOp], Optional[UOp]]:
|
||||
# NOTE: returned UOp is assumed to be CONST
|
||||
if self.op is UOps.DEFINE_VAR and self.src: return self.src[0], self.src[1] if isinstance(self.src[1].arg, int) else None
|
||||
if self.op is UOps.RANGE: return self.src[0], self.const(self.src[1].arg-1) if isinstance(self.src[1].arg, int) else None
|
||||
if self.op is UOps.RANGE: return self.src[0].vmin, (self.src[1]-1).vmax
|
||||
# TODO: UOps.SPECIAL is UOps.DEFINE_VAR
|
||||
if self.op is UOps.SPECIAL: return self.const(0), self.const(self.arg[1]-1) if isinstance(self.arg[1], int) else None
|
||||
if self.op is UOps.CONST: return self, self
|
||||
|
|
|
@ -66,7 +66,9 @@ class ShapeTracker:
|
|||
|
||||
def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
|
||||
|
||||
def to_indexed_uops(self, idxs:List[UOp]) -> Tuple[UOp, UOp]:
|
||||
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
|
||||
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
|
||||
if _idxs is None else _idxs
|
||||
idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
|
||||
for view in reversed(self.views[0:-1]):
|
||||
view = view.minify()
|
||||
|
@ -132,8 +134,8 @@ class ShapeTracker:
|
|||
return idx, valid
|
||||
|
||||
def axis_is_masked(self, axis:int) -> bool:
|
||||
_, valid = self.expr_idxs()
|
||||
return f'idx{axis}' in [v.expr for v in valid.vars()]
|
||||
_, valid = self.to_indexed_uops()
|
||||
return axis in [x.arg for x in valid.sparents if x.op is UOps.RANGE]
|
||||
|
||||
def simplify(self) -> ShapeTracker:
|
||||
if len(self.views) >= 2 and (new_view := self.views[-2] + self.views[-1]) is not None:
|
||||
|
|
Loading…
Reference in New Issue