mirror of https://github.com/commaai/tinygrad.git
minor improvements (#2845)
This commit is contained in:
parent
d086325b1b
commit
b2192b5400
|
@ -46,7 +46,7 @@ def do_stride(st):
|
|||
|
||||
def do_flip(st):
|
||||
c = random.randint(0, len(st.shape)-1)
|
||||
stride = tuple(random.choice([-1,1]) if i==c else 1 for i in range(len(st.shape)))
|
||||
stride = tuple(-1 if i==c else 1 for i in range(len(st.shape)))
|
||||
if DEBUG >= 1: print("st.stride(", stride, ")")
|
||||
st.stride(stride)
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ class TestShapeTrackerBasics(unittest.TestCase):
|
|||
x = x.reshape( (2, 2, 5) )
|
||||
x1 = x.reshape( (4, 5) )
|
||||
x1 = x1.reshape( (2, 2, 5) )
|
||||
assert x == x1
|
||||
assert x == x1.simplify()
|
||||
|
||||
class TestShapeTrackerAdd(unittest.TestCase):
|
||||
def test_simple_add_reshape(self):
|
||||
|
|
|
@ -36,6 +36,7 @@ def expr_idxs(view:View, idxs:Tuple[Node, ...]) -> Node:
|
|||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
if vm1.contiguous and vm1.shape == vm2.shape: return vm2
|
||||
if vm2.contiguous: return vm1
|
||||
if vm2.mask or vm1.offset != 0: return None # this isn't supported yet
|
||||
if None in (strides := ShapeTracker((vm2, vm1)).real_strides()): return None
|
||||
|
|
Loading…
Reference in New Issue