minor improvements (#2845)

This commit is contained in:
George Hotz 2023-12-18 22:09:08 -08:00 committed by GitHub
parent d086325b1b
commit b2192b5400
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 3 additions and 2 deletions

View File

@ -46,7 +46,7 @@ def do_stride(st):
def do_flip(st):
c = random.randint(0, len(st.shape)-1)
stride = tuple(random.choice([-1,1]) if i==c else 1 for i in range(len(st.shape)))
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
if DEBUG >= 1: print("st.stride(", stride, ")")
st.stride(stride)

View File

@ -21,7 +21,7 @@ class TestShapeTrackerBasics(unittest.TestCase):
x = x.reshape( (2, 2, 5) )
x1 = x.reshape( (4, 5) )
x1 = x1.reshape( (2, 2, 5) )
assert x == x1
assert x == x1.simplify()
class TestShapeTrackerAdd(unittest.TestCase):
def test_simple_add_reshape(self):

View File

@ -36,6 +36,7 @@ def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
if vm2.contiguous: return vm1
if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None