mirror of https://github.com/commaai/tinygrad.git
assert if expr_idxs return might be outside of int32 (#4157)
This commit is contained in:
parent
24a27a01a9
commit
f6c8032e5d
|
@ -750,5 +750,13 @@ class TestShapeTrackerSize(unittest.TestCase):
|
|||
st = st.shrink(((0, 100), (0, 50)))
|
||||
self.assertEqual(st.real_size(), 9950) # careful here
|
||||
|
||||
class TestIdxs(unittest.TestCase):
|
||||
def test_check_idx_range(self):
|
||||
# generated from: (Tensor.rand(4096,599*64) @ Tensor.rand(599*64,1024)).realize()
|
||||
# TODO: use int64
|
||||
st = ShapeTracker(views=(View(shape=(4096, 1024, 599, 1), strides=(613376, 599, 1, 0), offset=0, mask=None, contiguous=True),))
|
||||
with self.assertRaises(AssertionError):
|
||||
st.expr_idxs()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -89,6 +89,8 @@ class ShapeTracker:
|
|||
idxs.append((idx//acc)%d)
|
||||
acc *= d
|
||||
idx, valid = _expr_view(view, idxs[::-1], valid)
|
||||
assert not isinstance(idx.min, int) or idx.min >= -2**31, f"idx.min too small. {idx=}, {idx.min=}"
|
||||
assert not isinstance(idx.max, int) or idx.max < 2**31, f"idx.max too big. {idx=}, {idx.max=}"
|
||||
return idx, valid
|
||||
|
||||
def axis_is_masked(self, axis:int) -> bool:
|
||||
|
|
Loading…
Reference in New Issue