diff --git a/test/test_symbolic_shapetracker.py b/test/test_symbolic_shapetracker.py index d1ffdfb4..37676efc 100644 --- a/test/test_symbolic_shapetracker.py +++ b/test/test_symbolic_shapetracker.py @@ -22,6 +22,18 @@ class TestSymbolic(unittest.TestCase): assert e1.render() == "((y*3)+x)" assert e2.render() == "1" + def test_real_strides_0(self): + st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501 + self.assertEqual(st.real_strides(), (8, None)) + + def test_real_strides_1(self): + st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 + self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) + + def test_real_strides_2(self): + st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501 + self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None)) + def test_cat_dim0_strides(self): i = Variable("i", 1, 5).bind(3) j = Variable("j", 1, 5).bind(3)