add test to validate lazyops dims (#5845)

This commit is contained in:
George Hotz 2024-07-31 12:59:38 -07:00 committed by GitHub
parent 4fe5b95568
commit 8672a9db3f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 3 additions and 1 deletions

View File

@ -390,7 +390,7 @@ class TestLinearizerFailures(unittest.TestCase):
for st in k.uops.sink.src: self.assertEqual(len(st.src), 4)
self.assertLessEqual(len(ifs[0].src[0].sparents), 16)
@unittest.expectedFailure
@unittest.skip("this is an invalid lazyop")
def test_failure_45(self):
ast = LazyOp(MetaOps.KERNEL, arg=None, src=(
LazyOp(BufferOps.STORE, arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), src=(

View File

@ -165,4 +165,6 @@ def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
assert out.op is BufferOps.STORE, f"kernels must have stores as the output, got {out.op}"
assert out.arg.st.size == ast.src[-1].arg.st.size, f"outputs must have the same size, got {out.arg.st.size}"
assert_valid(out, out.arg.st)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
return sts