RUF018 assignment-in-assert [run_process_replay] (#6172)

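Assertions should not have side effects: running Python with `-O` strips `assert` statements, so any assignment made inside one disappears and code relying on it breaks.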

Initially I just wanted to fix the one in `rearrange`, but applying the rule everywhere also shortened some overly long lines.
This commit is contained in:
chenyu 2024-08-19 00:34:52 -04:00 committed by GitHub
parent 9c60a27ece
commit b36a7273c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 70 additions and 38 deletions

View File

@ -30,6 +30,7 @@ lint.select = [
"A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
"SIM105", # suppressible-exception
"FURB110",# if-exp-instead-of-or-operator
"RUF018", # assignment-in-assert
]
line-length = 150

View File

@ -20,7 +20,8 @@ class TestAMD(unittest.TestCase):
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf,
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
assert (val:=TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]) == 4000.0, f"got val {val}"
val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 4000.0, f"got val {val}"
if __name__ == "__main__":
unittest.main()

View File

@ -66,7 +66,8 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_1000_times(self):
temp_signal = TestHCQ.d0._alloc_signal(value=0)
@ -81,7 +82,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
def test_run_to_3(self):
temp_signal = TestHCQ.d0._alloc_signal(value=0)
@ -94,7 +96,8 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 3.0, f"got val {val}"
def test_update_exec(self):
q = TestHCQ.compute_queue()
@ -104,8 +107,10 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_bind_run(self):
@ -122,7 +127,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 2000.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_update_exec_binded(self):
@ -136,8 +142,10 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
@unittest.skipIf(CI, "Can't handle async update on CPU")
def test_wait_signal(self):
@ -167,7 +175,8 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_submit_empty_queues(self):
TestHCQ.compute_queue().submit(TestHCQ.d0)
@ -198,7 +207,8 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_copy_1000_times(self):
q = TestHCQ.copy_queue()
@ -212,7 +222,8 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy(self):
q = TestHCQ.copy_queue()
@ -221,7 +232,8 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
def test_bind_copy(self):
@ -237,7 +249,8 @@ class TestHCQ(unittest.TestCase):
# confirm the signal didn't exceed the put value
with self.assertRaises(RuntimeError):
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}"
def test_copy_bandwidth(self):
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
@ -276,7 +289,8 @@ class TestHCQ(unittest.TestCase):
q.submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_cross_device_signal(self):
d1 = Device[f"{Device.DEFAULT}:1"]
@ -306,7 +320,8 @@ class TestHCQ(unittest.TestCase):
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
if __name__ == "__main__":
unittest.main()

View File

@ -46,7 +46,8 @@ class TestNV(unittest.TestCase):
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
temp_runner = get_runner(TestNV.d0.dname, (ast,))
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
assert abs((val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) - 0.80647) < 0.001, f"got val {val}"
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
assert abs(val - 0.80647) < 0.001, f"got val {val}"
def test_kernargs_no_oob_access(self):
kernargs_start = TestNV.d0._gpu_alloc((2 << 20), map_to_cpu=True).va_addr
@ -59,7 +60,8 @@ class TestNV(unittest.TestCase):
q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0)
TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value)
TestNV.d0.timeline_value += 1
assert (val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
if __name__ == "__main__":
unittest.main()

View File

@ -57,10 +57,12 @@ class TestAssign(unittest.TestCase):
x.realize()
x = Tensor([0])
f(x)
assert (out:=x.item()) == 1, f"expected 1, got {out}"
out = x.item()
assert out == 1, f"expected 1, got {out}"
x = Tensor([0])
f(x)
assert (out:=x.item()) == 1, f"expected 1, got {out}"
out = x.item()
assert out == 1, f"expected 1, got {out}"
def test_assign_add_jit(self):
@TinyJit

View File

@ -107,7 +107,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
def test_exec_2_kernels_100_times(self):
q = TestHCQ.d0.hw_compute_queue_t()
@ -120,7 +121,8 @@ class TestHCQ(unittest.TestCase):
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 200.0, f"got val {val}"
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 200.0, f"got val {val}"
def test_exec_update(self):
q = TestHCQ.d0.hw_compute_queue_t()
@ -132,8 +134,10 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
assert val == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 0.0, f"got val {val}, should not be updated"
def test_exec_update_fuzz(self):
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
@ -178,7 +182,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_copy_long(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
@ -211,7 +216,8 @@ class TestHCQ(unittest.TestCase):
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
TestHCQ.d0.timeline_value += 1
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
assert val == 1.0, f"got val {val}"
def test_update_copy_long(self):
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")

View File

@ -1532,7 +1532,8 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[]
for opt in opts:
k.apply_opt(opt)
if expected_color_size is not None:
assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}"
cs = list(zip(k.colors(), k.full_shape))
assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
prg = get_prg(k)
for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
prg.exec(real_bufs)

View File

@ -598,8 +598,8 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
# NOTE: for PTX you have to propagate through some of the calculations to determine if it is a store to DEFINE_LOCAL
def _islocalbuf(u: UOp): return u.op is UOps.DEFINE_LOCAL or any(_islocalbuf(x) for x in u.src if u.op in [UOps.ALU, UOps.CAST])
assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])]) \
== len(dedup(all_stores)), "repeated stores in uops"
all_stores = [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])]
assert len(all_stores) == len(dedup(all_stores)), "repeated stores in uops"
except AssertionError as e:
print_uops(_uops)
if not CI:

View File

@ -19,8 +19,8 @@ class _Device:
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix:str) -> Compiled:
assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
f"can only open device {ix} from parent, not {cpn}"
cpn = multiprocessing.current_process().name
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
x = ix.split(":")[0].upper()
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")

View File

@ -39,7 +39,8 @@ def round_up(num, amt:int): return (num+amt-1)//amt * amt
def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
kvs = set([(k,v) for d in ds for k,v in d.items()])
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
return {k:v for d in ds for k,v in d.items()}
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
a:List[T] = []

View File

@ -141,7 +141,8 @@ class LazyBuffer:
srcs.append(root._view(s.base.contiguous_child[1]))
else:
srcs.append(s)
assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
if not all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]):
raise AssertionError(f"all dtypes must match {dts} on {op}")
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"

View File

@ -330,7 +330,8 @@ def type_verify(uops):
if uop is UOps.ALU:
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}"
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg is BinaryOps.IDIV:
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
@ -340,8 +341,8 @@ def type_verify(uops):
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
elif arg == TernaryOps.WHERE:
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
def uop_alu_resolve(u:UOp) -> sint:

View File

@ -1658,7 +1658,8 @@ class Tensor:
"""
def parse_formula(formula: str):
lparens, rparens = map(lambda x: [i for i, ch in enumerate(formula.split()) if ch == x], ("(", ")"))
assert len(lparens) == len(rparens) and sorted(flatten(pairs := list(zip(lparens, rparens)))) == flatten(pairs), "bracket mismatch"
pairs = list(zip(lparens, rparens))
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
return [name for name in formula.split() if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
assert formula.count("->") == 1, 'need exactly one "->" in formula'
@ -1889,7 +1890,7 @@ class Tensor:
"""
n1, n2 = len(self.shape), len(w.shape)
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)