assert if expr_idxs return might be outside of int32 (#4157)

This commit is contained in:
chenyu 2024-04-12 14:18:35 -04:00 committed by GitHub
parent 24a27a01a9
commit f6c8032e5d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 0 deletions

View File

@ -750,5 +750,13 @@ class TestShapeTrackerSize(unittest.TestCase):
st = st.shrink(((0, 100), (0, 50)))
self.assertEqual(st.real_size(), 9950) # careful here
class TestIdxs(unittest.TestCase):
def test_check_idx_range(self):
# generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
# TODO: use int64
st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
with self.assertRaises(AssertionError):
st.expr_idxs()
if __name__ == '__main__':
unittest.main()

View File

@ -89,6 +89,8 @@ class ShapeTracker:
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _expr_view(view, idxs[::-1], valid)
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
return idx, valid
def axis_is_masked(self, axis:int) -> bool: