tests of real_stride of symbolic shape (#6409)

these would have failed in #6365
This commit is contained in:
chenyu 2024-09-08 21:37:19 -04:00 committed by GitHub
parent 0fbd141038
commit ad05302232
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 0 deletions

View File

@ -22,6 +22,18 @@ class TestSymbolic(unittest.TestCase):
assert e1.render() == "((y*3)+x)"
assert e2.render() == "1"
def test_real_strides_0(self):
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
self.assertEqual(st.real_strides(), (8, None))
def test_real_strides_1(self):
st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
def test_real_strides_2(self):
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
def test_cat_dim0_strides(self):
i = Variable("i", 1, 5).bind(3)
j = Variable("j", 1, 5).bind(3)