From b36a7273c61c4315d5d01925a0d2a67534bcb965 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 19 Aug 2024 00:34:52 -0400 Subject: [PATCH] RUF018 assignment-in-assert [run_process_replay] (#6172) Assertions should not have side effects: `python -O` strips assert statements, so a name bound with `:=` inside an assert is never assigned and any later use of it breaks. Initially this was only meant to fix the one in rearrange, but it also made some long lines shorter. --- ruff.toml | 1 + test/external/external_test_amd.py | 3 +- test/external/external_test_hcq.py | 45 ++++++++++++++++++++---------- test/external/external_test_nv.py | 6 ++-- test/test_assign.py | 6 ++-- test/test_hcq.py | 18 ++++++++---- test/test_linearizer.py | 3 +- tinygrad/codegen/uopgraph.py | 4 +-- tinygrad/device.py | 4 +-- tinygrad/helpers.py | 3 +- tinygrad/lazy.py | 3 +- tinygrad/ops.py | 7 +++-- tinygrad/tensor.py | 5 ++-- 13 files changed, 70 insertions(+), 38 deletions(-) diff --git a/ruff.toml b/ruff.toml index ad26e413..83c85189 100644 --- a/ruff.toml +++ b/ruff.toml @@ -30,6 +30,7 @@ lint.select = [ "A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing "SIM105", # suppressible-exception "FURB110",# if-exp-instead-of-or-operator + "RUF018", # assignment-in-assert ] line-length = 150 diff --git a/test/external/external_test_amd.py b/test/external/external_test_amd.py index bd1206b1..aabff989 100644 --- a/test/external/external_test_amd.py +++ b/test/external/external_test_amd.py @@ -20,7 +20,8 @@ class TestAMD(unittest.TestCase): global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size) TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf, global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size) - assert (val:=TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]) == 4000.0, f"got val {val}" + val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 4000.0, f"got val {val}" if __name__ == "__main__": unittest.main() diff --git a/test/external/external_test_hcq.py 
b/test/external/external_test_hcq.py index a7be0047..6d3c7045 100644 --- a/test/external/external_test_hcq.py +++ b/test/external/external_test_hcq.py @@ -66,7 +66,8 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}" + val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 2000.0, f"got val {val}" def test_run_1000_times(self): temp_signal = TestHCQ.d0._alloc_signal(value=0) @@ -81,7 +82,8 @@ class TestHCQ(unittest.TestCase): TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}" + val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 2000.0, f"got val {val}" def test_run_to_3(self): temp_signal = TestHCQ.d0._alloc_signal(value=0) @@ -94,7 +96,8 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 3.0, f"got val {val}" def test_update_exec(self): q = TestHCQ.compute_queue() @@ -104,8 +107,10 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" - assert 
(val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 0.0, f"got val {val}, should not be updated" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") def test_bind_run(self): @@ -122,7 +127,8 @@ class TestHCQ(unittest.TestCase): TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}" + val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 2000.0, f"got val {val}" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") def test_update_exec_binded(self): @@ -136,8 +142,10 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 0.0, f"got val {val}, should not be updated" @unittest.skipIf(CI, "Can't handle async update on CPU") def test_wait_signal(self): @@ -167,7 +175,8 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 
1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" def test_submit_empty_queues(self): TestHCQ.compute_queue().submit(TestHCQ.d0) @@ -198,7 +207,8 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" def test_copy_1000_times(self): q = TestHCQ.copy_queue() @@ -212,7 +222,8 @@ class TestHCQ(unittest.TestCase): # confirm the signal didn't exceed the put value with self.assertRaises(RuntimeError): TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50) - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 0.0, f"got val {val}" def test_copy(self): q = TestHCQ.copy_queue() @@ -221,7 +232,8 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 1.0, f"got val {val}" @unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind") def test_bind_copy(self): @@ -237,7 +249,8 @@ class TestHCQ(unittest.TestCase): # confirm the signal didn't exceed the put value with self.assertRaises(RuntimeError): TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50) - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 0.0, f"got val {val}" def 
test_copy_bandwidth(self): # THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least. @@ -276,7 +289,8 @@ class TestHCQ(unittest.TestCase): q.submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" + val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" def test_cross_device_signal(self): d1 = Device[f"{Device.DEFAULT}:1"] @@ -306,7 +320,8 @@ class TestHCQ(unittest.TestCase): q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" if __name__ == "__main__": unittest.main() diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index 1aebcd3a..d4d92428 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -46,7 +46,8 @@ class TestNV(unittest.TestCase): ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 temp_runner = get_runner(TestNV.d0.dname, (ast,)) temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={}) - assert abs((val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) - 0.80647) < 0.001, 
f"got val {val}" + val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] + assert abs(val - 0.80647) < 0.001, f"got val {val}" def test_kernargs_no_oob_access(self): kernargs_start = TestNV.d0._gpu_alloc((2 << 20), map_to_cpu=True).va_addr @@ -59,7 +60,8 @@ class TestNV(unittest.TestCase): q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0) TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value) TestNV.d0.timeline_value += 1 - assert (val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" + val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" if __name__ == "__main__": unittest.main() diff --git a/test/test_assign.py b/test/test_assign.py index 60272339..0f4df9d8 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -57,10 +57,12 @@ class TestAssign(unittest.TestCase): x.realize() x = Tensor([0]) f(x) - assert (out:=x.item()) == 1, f"expected 1, got {out}" + out = x.item() + assert out == 1, f"expected 1, got {out}" x = Tensor([0]) f(x) - assert (out:=x.item()) == 1, f"expected 1, got {out}" + out = x.item() + assert out == 1, f"expected 1, got {out}" def test_assign_add_jit(self): @TinyJit diff --git a/test/test_hcq.py b/test/test_hcq.py index 4b60b0fb..066c600e 100644 --- a/test/test_hcq.py +++ b/test/test_hcq.py @@ -107,7 +107,8 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" def test_exec_2_kernels_100_times(self): q = TestHCQ.d0.hw_compute_queue_t() @@ -120,7 +121,8 @@ class TestHCQ(unittest.TestCase): q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0) TestHCQ.d0.timeline_value += 1 - assert 
(val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 200.0, f"got val {val}" + val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 200.0, f"got val {val}" def test_exec_update(self): q = TestHCQ.d0.hw_compute_queue_t() @@ -132,8 +134,10 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}" - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0] + assert val == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 0.0, f"got val {val}, should not be updated" def test_exec_update_fuzz(self): a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize() @@ -178,7 +182,8 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 1.0, f"got val {val}" def test_copy_long(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") @@ -211,7 +216,8 @@ class TestHCQ(unittest.TestCase): TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value) TestHCQ.d0.timeline_value += 1 - assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}" + val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1] + assert val == 1.0, f"got val {val}" def test_update_copy_long(self): if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue") diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 05c6c9f4..ead71ef8 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ 
-1532,7 +1532,8 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[] for opt in opts: k.apply_opt(opt) if expected_color_size is not None: - assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}" + cs = list(zip(k.colors(), k.full_shape)) + assert cs == expected_color_size, f"expected={expected_color_size} got={cs}" prg = get_prg(k) for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled prg.exec(real_bufs) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index bef4d3c3..b7305cd5 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -598,8 +598,8 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s # NOTE: multiple identical stores to DEFINE_LOCAL is okay # NOTE: for PTX you have to propogate through some the calculations to determine if it is a store to DEFINE_LOCAL def _islocalbuf(u: UOp): return u.op is UOps.DEFINE_LOCAL or any(_islocalbuf(x) for x in u.src if u.op in [UOps.ALU, UOps.CAST]) - assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])]) \ - == len(dedup(all_stores)), "repeated stores in uops" + all_stores = [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])] + assert len(all_stores) == len(dedup(all_stores)), "repeated stores in uops" except AssertionError as e: print_uops(_uops) if not CI: diff --git a/tinygrad/device.py b/tinygrad/device.py index 7a56c05f..dfaa7cc7 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -19,8 +19,8 @@ class _Device: def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix)) @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none def __get_canonicalized_item(self, ix:str) -> 
Compiled: - assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \ - f"can only open device {ix} from parent, not {cpn}" + cpn = multiprocessing.current_process().name + assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}" x = ix.split(":")[0].upper() ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501 if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}") diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index c28c1df1..5e7a117e 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -39,7 +39,8 @@ def round_up(num, amt:int): return (num+amt-1)//amt * amt def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF) def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32) def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]: - assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501 + kvs = set([(k,v) for d in ds for k,v in d.items()]) + assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" return {k:v for d in ds for k,v in d.items()} def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]: a:List[T] = [] diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index 928b9751..0297ac07 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -141,7 +141,8 @@ class LazyBuffer: srcs.append(root._view(s.base.contiguous_child[1])) else: srcs.append(s) - assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}" + if not all_same(dts:=[x.dtype.scalar() for x in 
(srcs[1:] if op is TernaryOps.WHERE else srcs)]): + raise AssertionError(f"all dtypes must match {dts} on {op}") assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}" if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool" if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool" diff --git a/tinygrad/ops.py b/tinygrad/ops.py index bb810fc2..646c78bf 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -330,7 +330,8 @@ def type_verify(uops): if uop is UOps.ALU: if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}: - assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}" + bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool + assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}" assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}" elif arg is BinaryOps.IDIV: assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}" @@ -340,8 +341,8 @@ def type_verify(uops): assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}" elif arg == TernaryOps.WHERE: - assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \ - f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}" + bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool + assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}" assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype 
mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}" def uop_alu_resolve(u:UOp) -> sint: diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 349957dd..90d633e6 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1658,7 +1658,8 @@ class Tensor: """ def parse_formula(formula: str): lparens, rparens = map(lambda x: [i for i, ch in enumerate(formula.split()) if ch == x], ("(", ")")) - assert len(lparens) == len(rparens) and sorted(flatten(pairs := list(zip(lparens, rparens)))) == flatten(pairs), "bracket mismatch" + pairs = list(zip(lparens, rparens)) + assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch" return [name for name in formula.split() if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)] assert formula.count("->") == 1, 'need exactly one "->" in formula' @@ -1889,7 +1890,7 @@ class Tensor: """ n1, n2 = len(self.shape), len(w.shape) assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})" + if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})") x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1]) w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)