mirror of https://github.com/commaai/tinygrad.git
tests of real_stride of symbolic shape (#6409)
these would have failed in #6365
This commit is contained in:
parent
0fbd141038
commit
ad05302232
|
@ -22,6 +22,18 @@ class TestSymbolic(unittest.TestCase):
|
|||
assert e1.render() == "((y*3)+x)"
|
||||
assert e2.render() == "1"
|
||||
|
||||
def test_real_strides_0(self):
|
||||
st = ShapeTracker(views=(View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8)), 1, 1), strides=(8, 1, 0, 0), offset=0, mask=((0, 2), (0, Variable('start_pos', 1, 8)), (0, 1), (0, 1)), contiguous=False), View(shape=(2, (NumNode(1)+Variable('start_pos', 1, 8))), strides=((NumNode(1)+Variable('start_pos', 1, 8)), 1), offset=0, mask=None, contiguous=True))) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (8, None))
|
||||
|
||||
def test_real_strides_1(self):
|
||||
st = ShapeTracker(views=(View(shape=(3, (NumNode(2)+Variable('i', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
||||
|
||||
def test_real_strides_2(self):
|
||||
st = ShapeTracker(views=(View(shape=(3, (Variable('i', 1, 10)+Variable('j', 1, 10))), strides=(Variable('i', 1, 10), 1), offset=NumNode(0), mask=((0, 3), (0, Variable('i', 1, 10))), contiguous=False),)) # noqa: E501
|
||||
self.assertEqual(st.real_strides(), (Variable('i', 1, 10), None))
|
||||
|
||||
def test_cat_dim0_strides(self):
|
||||
i = Variable("i", 1, 5).bind(3)
|
||||
j = Variable("j", 1, 5).bind(3)
|
||||
|
|
Loading…
Reference in New Issue