mirror of https://github.com/commaai/tinygrad.git
191 lines
6.7 KiB
Python
191 lines
6.7 KiB
Python
import unittest
|
|
from typing import List
|
|
from tinygrad.helpers import prod
|
|
from tinygrad.shape.view import View
|
|
from tinygrad.shape.shapetracker import ShapeTracker
|
|
from tinygrad.shape.symbolic import Variable, sym_infer
|
|
|
|
class MultiShapeTracker:
  """Apply the same movement op to every ShapeTracker in ``sts``.

  Each op method replaces ``self.sts`` with a new list holding the result of calling the
  same-named ShapeTracker method (with the same argument) on every tracker. The op
  methods return None.
  """
  def __init__(self, sts:List[ShapeTracker]): self.sts = sts

  @property
  def shape(self):
    # Only the first tracker is consulted; assumes callers keep all trackers the same shape.
    return self.sts[0].shape

  def _apply(self, op):
    # Rebind to a fresh list (callers may hold or append to the old one between ops).
    self.sts = [op(tracker) for tracker in self.sts]

  def reshape(self, arg): self._apply(lambda tracker: tracker.reshape(arg))

  def permute(self, arg): self._apply(lambda tracker: tracker.permute(arg))

  def expand(self, arg): self._apply(lambda tracker: tracker.expand(arg))

  def shrink(self, arg): self._apply(lambda tracker: tracker.shrink(arg))

  def stride(self, arg): self._apply(lambda tracker: tracker.stride(arg))

  def pad(self, arg): self._apply(lambda tracker: tracker.pad(arg))
|
|
|
|
def st_equal(st1:ShapeTracker, st2:ShapeTracker) -> bool:
  """Return True if two ShapeTrackers index the same elements, checked element by element.

  Both trackers are flattened to 1-D and, for every flat index, their validity and (where
  valid) their offset expressions are evaluated and compared. On the first mismatch the
  differing values and both trackers are printed, and False is returned.

  Only concrete (integer) shapes are supported, since every flat index is enumerated.
  """
  if st1.shape != st2.shape: return False
  if st1 == st2: return True
  numel = prod(st1.shape)
  # A zero-sized shape has no elements to compare, so equal shapes are trivially equal.
  # Without this guard we would build Variable("idx", 0, -1), an empty index range.
  if numel == 0: return True
  idx = Variable("idx", 0, numel-1)
  st1_idx, st1_valid = st1.reshape((st1.size,)).expr_idxs([idx])
  st2_idx, st2_valid = st2.reshape((st2.size,)).expr_idxs([idx])
  for i in range(idx.min, idx.max + 1):
    var_vals = {idx: i}
    st1_off, st2_off = sym_infer(st1_idx, var_vals), sym_infer(st2_idx, var_vals)
    st1_v, st2_v = sym_infer(st1_valid, var_vals), sym_infer(st2_valid, var_vals)
    # Validity must always agree; offsets only matter where the element is valid.
    if st1_v != st2_v or (st1_off != st2_off and st1_v):
      print(f"ST MISMATCH @ {i}, {st1_v=} != {st2_v=}, {st1_off=} != {st2_off=}")
      print(st1)
      print(st2)
      return False
  return True
|
|
|
|
class TestShapeTrackerBasics(unittest.TestCase):
  """Basic pad/shrink/reshape/simplify behavior of ShapeTracker.

  Uses TestCase assertion methods instead of bare ``assert`` so the checks still run under
  ``python -O`` and report both operands on failure.
  """
  def test_pad_shrink_removes_mask(self):
    # Padding by 2 then shrinking back to the original 10x10 extent should leave no mask.
    a = ShapeTracker.from_shape((10, 10))
    a = a.pad(((0,2), (0,2)))
    a = a.shrink(((0,10), (0,10)))
    self.assertEqual(len(a.views), 1)
    self.assertIsNone(a.views[-1].mask)

  def test_pad_shrink_leaves_mask(self):
    # Shrinking the last axis to (0,11) keeps one padded column, so a mask must remain.
    a = ShapeTracker.from_shape((10, 10))
    a = a.pad(((0,2), (0,2)))
    a = a.shrink(((0,10), (0,11)))
    self.assertEqual(len(a.views), 1)
    self.assertIsNotNone(a.views[-1].mask)

  def test_reshape_makes_same(self):
    a = ShapeTracker.from_shape((2, 5))
    x = a.pad( ((2, 0), (0, 0)) )
    x = x.reshape( (2, 2, 5) )
    x1 = x.reshape( (4, 5) )
    x1 = x1.reshape( (2, 2, 5) )
    # Round-tripping through (4, 5) and simplifying should give back the same tracker.
    self.assertEqual(x, x1.simplify())

  def test_simplify_is_correct(self):
    multiv = ShapeTracker(views=(View(shape=(15, 3), strides=(9, 1), offset=6, mask=None, contiguous=False),
                                 View(shape=(4, 3), strides=(12, 4), offset=0, mask=None, contiguous=False)))
    # st_equal checks element-wise that simplify() did not change validity or offsets.
    self.assertTrue(st_equal(multiv, multiv.simplify()))
|
|
|
|
class TestShapeTrackerAdd(unittest.TestCase):
  """Composition of ShapeTrackers with ``+``.

  Uses TestCase assertion methods instead of bare ``assert`` so the checks still run under
  ``python -O`` and report both operands on failure.
  """
  def test_simple_add_reshape(self):
    a = ShapeTracker.from_shape((10, 10))
    a = a.reshape((100,))
    b = ShapeTracker.from_shape((100,))
    self.assertEqual(a+b, b)

  def test_simple_add_permute(self):
    a = ShapeTracker.from_shape((10, 10))
    a = a.permute((1,0))
    b = ShapeTracker.from_shape((10, 10))
    b = b.permute((1,0))
    # Two transposes composed together are expected to yield the contiguous tracker.
    self.assertEqual(a+b, ShapeTracker.from_shape((10, 10)))

  def test_plus_real1(self):
    st = MultiShapeTracker([ShapeTracker.from_shape((15, 9))])
    st.shrink( ((0, 15), (6, 9)) )
    backup = st.sts[0]
    # From here on, also apply the remaining ops to a fresh tracker of the shrunk shape.
    st.sts.append(ShapeTracker.from_shape(backup.shape))
    st.reshape( (45,) )
    st.stride( (4,) )
    st.reshape( (4, 3) )
    # backup + (ops on fresh tracker) should index the same as the ops applied to backup directly.
    self.assertTrue(st_equal(backup + st.sts[1], st.sts[0]))

  def test_off_by_one(self):
    st1 = ShapeTracker(views=(View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True),
                              View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
    st2 = ShapeTracker(views=(View(shape=(4,), strides=(1,), offset=0, mask=None, contiguous=True),
                              View(shape=(5,), strides=(1,), offset=0, mask=None, contiguous=True)))
    # The trackers differ only in the first view's shape ((5,) vs (4,)); st_equal must notice.
    self.assertFalse(st_equal(st1, st2))
|
|
|
|
class TestShapeTrackerAddVariable(unittest.TestCase):
  """``+`` on ShapeTrackers whose shapes contain symbolic Variables.

  Uses TestCase assertion methods instead of bare ``assert`` so the checks still run under
  ``python -O`` and report both operands on failure.
  """
  def test_self_add(self):
    j = Variable("j", 0, 20).bind(10)
    a = ShapeTracker.from_shape((10,10))
    x = a.reshape((10, j))
    out = x + x
    self.assertEqual(out, x)

  def test_self_add_reshape(self):
    j = Variable("j", 0, 20).bind(10)
    a = ShapeTracker.from_shape((10,10))
    x = a.reshape((10, j))
    out = x.reshape((5, 2, j)) + x
    self.assertEqual(out, x)

  def test_merge_symbolic_views(self):
    # NOTE(review): var_j reuses the name 'i'; the distinct-name variant below is skipped as
    # "two vars not supported", so this looks deliberate -- confirm before renaming.
    var_i = Variable('i', 1, 10)
    var_j = Variable('i', 1, 10)
    vm1 = View(shape=(var_i, var_j, 3), strides=(3, 0, 1), offset=0, mask=None, contiguous=False)
    vm2 = View(shape=(var_i, var_j, 3), strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True)
    # Only checks that the addition does not raise; the result is not inspected.
    ShapeTracker((vm1,)) + ShapeTracker((vm2,))

  @unittest.skip("two vars not supported")
  def test_merge_symbolic_views_2(self):
    var_i = Variable('i', 1, 10)
    var_j = Variable('j', 1, 10)
    vm1 = View(shape=(var_i, var_j), strides=(0, 0), offset=0, mask=None, contiguous=False)
    vm2 = View(shape=(var_i, var_j), strides=(var_j, 1), offset=0, mask=None, contiguous=True)
    # Reshaping after the add should match adding onto an already-reshaped tracker.
    ret = (ShapeTracker((vm1,)) + ShapeTracker((vm2,))).reshape((var_i, var_j, 1))
    ret_2 = ShapeTracker((vm1,)) + ShapeTracker((vm2,)).reshape((var_i, var_j, 1))
    self.assertEqual(ret, ret_2)
|
|
|
|
class TestShapeTrackerInvert(unittest.TestCase):
  """ShapeTracker.invert(original_shape).

  Adding a tracker to its inverse should recover the original tracker. expand, shrink and
  stride (2, 2) are expected to be non-invertible (invert returns None), while flips via
  stride -1 are invertible. Uses TestCase assertion methods instead of bare ``assert`` so
  the checks still run under ``python -O``.
  """
  def test_invert_reshape(self):
    a = ShapeTracker.from_shape((10, 10))
    x = a.reshape((5, 20))
    ap = ShapeTracker.from_shape(x.shape) + x.invert(a.shape)
    self.assertEqual(ap, a, f"{ap} != {a}")

  def test_invert_permute(self):
    a = ShapeTracker.from_shape((5, 20))
    x = a.permute((1,0))
    ap = x + x.invert(a.shape)
    self.assertEqual(ap, a, f"{ap} != {a}")

  def test_invert_permute_3(self):
    a = ShapeTracker.from_shape((8, 4, 5))
    x = a.permute((1,2,0))
    ap = x + x.invert(a.shape)
    self.assertEqual(ap, a, f"{ap} != {a}")

  def test_invert_real1(self):
    a = ShapeTracker.from_shape((3, 6, 10))
    x = a.reshape( (3, 3, 2, 10) )
    x = x.permute( (2, 1, 3, 0) )
    ap = x + x.invert(a.shape)
    self.assertEqual(ap, a, f"{ap} != {a}")

  def test_cant_invert_expand(self):
    a = ShapeTracker.from_shape((10, 1))
    x = a.expand((10,10))
    self.assertIsNone(x.invert(a.shape))

  def test_cant_invert_shrink(self):
    a = ShapeTracker.from_shape((10, 10))
    x = a.shrink(((0,10),(2,8)))
    self.assertIsNone(x.invert(a.shape))

  def test_can_invert_flip(self):
    a = ShapeTracker.from_shape((20, 10))
    x = a.stride((-1,1))
    ap = x + x.invert(a.shape)
    self.assertTrue(st_equal(ap, a))

  def test_can_invert_flip_permute(self):
    a = ShapeTracker.from_shape((20, 10))
    x = a.permute((1,0))
    x = x.stride((-1,1))
    ap = x + x.invert(a.shape)
    self.assertTrue(st_equal(ap, a))

  def test_cant_invert_stride(self):
    a = ShapeTracker.from_shape((10, 10))
    x = a.stride((2,2))
    self.assertIsNone(x.invert(a.shape))

  def test_invert_failure(self):
    # Pad followed by two reshapes; the name presumably records a case that once failed.
    a = ShapeTracker.from_shape((2, 5))
    x = a.pad( ((2, 0), (0, 0)) )
    x = x.reshape( (2, 2, 5) )
    x = x.reshape( (4, 5) )
    ap = x + x.invert(a.shape)
    self.assertTrue(st_equal(ap, a))
|
|
|
|
# Script entry point: run this module's TestCases with unittest's default runner.
if __name__ == '__main__': unittest.main()
|
|
|