mirror of https://github.com/commaai/tinygrad.git
RUF018 assignment-in-assert [run_process_replay] (#6172)
Assertions should not have side effects; otherwise running Python with `-O` (which strips `assert` statements) breaks the code. Initially I just wanted to fix the one in rearrange, but the change also made some long lines shorter.
This commit is contained in:
parent
9c60a27ece
commit
b36a7273c6
|
@ -30,6 +30,7 @@ lint.select = [
|
||||||
"A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
|
"A", # builtin-variable-shadowing, builtin-argument-shadowing, builtin-attribute-shadowing
|
||||||
"SIM105", # suppressible-exception
|
"SIM105", # suppressible-exception
|
||||||
"FURB110",# if-exp-instead-of-or-operator
|
"FURB110",# if-exp-instead-of-or-operator
|
||||||
|
"RUF018", # assignment-in-assert
|
||||||
]
|
]
|
||||||
|
|
||||||
line-length = 150
|
line-length = 150
|
||||||
|
|
|
@ -20,7 +20,8 @@ class TestAMD(unittest.TestCase):
|
||||||
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
|
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
|
||||||
TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf,
|
TestAMD.d0_runner.clprg(TestAMD.a.lazydata.buffer._buf, TestAMD.b.lazydata.buffer._buf,
|
||||||
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
|
global_size=TestAMD.d0_runner.global_size, local_size=TestAMD.d0_runner.local_size)
|
||||||
assert (val:=TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]) == 4000.0, f"got val {val}"
|
val = TestAMD.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 4000.0, f"got val {val}"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -66,7 +66,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.submit(TestHCQ.d0)
|
q.submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
|
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 2000.0, f"got val {val}"
|
||||||
|
|
||||||
def test_run_1000_times(self):
|
def test_run_1000_times(self):
|
||||||
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
||||||
|
@ -81,7 +82,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
|
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 2000.0, f"got val {val}"
|
||||||
|
|
||||||
def test_run_to_3(self):
|
def test_run_to_3(self):
|
||||||
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
temp_signal = TestHCQ.d0._alloc_signal(value=0)
|
||||||
|
@ -94,7 +96,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 3.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 3.0, f"got val {val}"
|
||||||
|
|
||||||
def test_update_exec(self):
|
def test_update_exec(self):
|
||||||
q = TestHCQ.compute_queue()
|
q = TestHCQ.compute_queue()
|
||||||
|
@ -104,8 +107,10 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
|
assert val == 1.0, f"got val {val}"
|
||||||
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 0.0, f"got val {val}, should not be updated"
|
||||||
|
|
||||||
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
||||||
def test_bind_run(self):
|
def test_bind_run(self):
|
||||||
|
@ -122,7 +127,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
TestHCQ.compute_queue().signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 2000.0, f"got val {val}"
|
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 2000.0, f"got val {val}"
|
||||||
|
|
||||||
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
||||||
def test_update_exec_binded(self):
|
def test_update_exec_binded(self):
|
||||||
|
@ -136,8 +142,10 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.submit(TestHCQ.d0)
|
q.submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
|
assert val == 1.0, f"got val {val}"
|
||||||
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 0.0, f"got val {val}, should not be updated"
|
||||||
|
|
||||||
@unittest.skipIf(CI, "Can't handle async update on CPU")
|
@unittest.skipIf(CI, "Can't handle async update on CPU")
|
||||||
def test_wait_signal(self):
|
def test_wait_signal(self):
|
||||||
|
@ -167,7 +175,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_submit_empty_queues(self):
|
def test_submit_empty_queues(self):
|
||||||
TestHCQ.compute_queue().submit(TestHCQ.d0)
|
TestHCQ.compute_queue().submit(TestHCQ.d0)
|
||||||
|
@ -198,7 +207,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.submit(TestHCQ.d0)
|
q.submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_copy_1000_times(self):
|
def test_copy_1000_times(self):
|
||||||
q = TestHCQ.copy_queue()
|
q = TestHCQ.copy_queue()
|
||||||
|
@ -212,7 +222,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
# confirm the signal didn't exceed the put value
|
# confirm the signal didn't exceed the put value
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 0.0, f"got val {val}"
|
||||||
|
|
||||||
def test_copy(self):
|
def test_copy(self):
|
||||||
q = TestHCQ.copy_queue()
|
q = TestHCQ.copy_queue()
|
||||||
|
@ -221,7 +232,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.submit(TestHCQ.d0)
|
q.submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
@unittest.skipUnless(Device.DEFAULT == "NV", "Only NV supports bind")
|
||||||
def test_bind_copy(self):
|
def test_bind_copy(self):
|
||||||
|
@ -237,7 +249,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
# confirm the signal didn't exceed the put value
|
# confirm the signal didn't exceed the put value
|
||||||
with self.assertRaises(RuntimeError):
|
with self.assertRaises(RuntimeError):
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value + 1, timeout=50)
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 0.0, f"got val {val}"
|
||||||
|
|
||||||
def test_copy_bandwidth(self):
|
def test_copy_bandwidth(self):
|
||||||
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
|
# THEORY: the bandwidth is low here because it's only using one SDMA queue. I suspect it's more stable like this at least.
|
||||||
|
@ -276,7 +289,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.submit(TestHCQ.d0)
|
q.submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_cross_device_signal(self):
|
def test_cross_device_signal(self):
|
||||||
d1 = Device[f"{Device.DEFAULT}:1"]
|
d1 = Device[f"{Device.DEFAULT}:1"]
|
||||||
|
@ -306,7 +320,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
q.signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
TestHCQ.d0._wait_signal(TestHCQ.d0.timeline_signal, TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -46,7 +46,8 @@ class TestNV(unittest.TestCase):
|
||||||
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501
|
||||||
temp_runner = get_runner(TestNV.d0.dname, (ast,))
|
temp_runner = get_runner(TestNV.d0.dname, (ast,))
|
||||||
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
|
temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={})
|
||||||
assert abs((val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) - 0.80647) < 0.001, f"got val {val}"
|
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert abs(val - 0.80647) < 0.001, f"got val {val}"
|
||||||
|
|
||||||
def test_kernargs_no_oob_access(self):
|
def test_kernargs_no_oob_access(self):
|
||||||
kernargs_start = TestNV.d0._gpu_alloc((2 << 20), map_to_cpu=True).va_addr
|
kernargs_start = TestNV.d0._gpu_alloc((2 << 20), map_to_cpu=True).va_addr
|
||||||
|
@ -59,7 +60,8 @@ class TestNV(unittest.TestCase):
|
||||||
q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0)
|
q.signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value).submit(TestNV.d0)
|
||||||
TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value)
|
TestNV.d0._wait_signal(TestNV.d0.timeline_signal, TestNV.d0.timeline_value)
|
||||||
TestNV.d0.timeline_value += 1
|
TestNV.d0.timeline_value += 1
|
||||||
assert (val:=TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
@ -57,10 +57,12 @@ class TestAssign(unittest.TestCase):
|
||||||
x.realize()
|
x.realize()
|
||||||
x = Tensor([0])
|
x = Tensor([0])
|
||||||
f(x)
|
f(x)
|
||||||
assert (out:=x.item()) == 1, f"expected 1, got {out}"
|
out = x.item()
|
||||||
|
assert out == 1, f"expected 1, got {out}"
|
||||||
x = Tensor([0])
|
x = Tensor([0])
|
||||||
f(x)
|
f(x)
|
||||||
assert (out:=x.item()) == 1, f"expected 1, got {out}"
|
out = x.item()
|
||||||
|
assert out == 1, f"expected 1, got {out}"
|
||||||
|
|
||||||
def test_assign_add_jit(self):
|
def test_assign_add_jit(self):
|
||||||
@TinyJit
|
@TinyJit
|
||||||
|
|
|
@ -107,7 +107,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
|
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_exec_2_kernels_100_times(self):
|
def test_exec_2_kernels_100_times(self):
|
||||||
q = TestHCQ.d0.hw_compute_queue_t()
|
q = TestHCQ.d0.hw_compute_queue_t()
|
||||||
|
@ -120,7 +121,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
q.update_wait(0, value=TestHCQ.d0.timeline_value - 1).update_signal(3, value=TestHCQ.d0.timeline_value).submit(TestHCQ.d0)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
|
|
||||||
assert (val:=TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]) == 200.0, f"got val {val}"
|
val = TestHCQ.a.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
|
assert val == 200.0, f"got val {val}"
|
||||||
|
|
||||||
def test_exec_update(self):
|
def test_exec_update(self):
|
||||||
q = TestHCQ.d0.hw_compute_queue_t()
|
q = TestHCQ.d0.hw_compute_queue_t()
|
||||||
|
@ -132,8 +134,10 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
|
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[0]
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
|
assert val == 1.0, f"got val {val}"
|
||||||
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 0.0, f"got val {val}, should not be updated"
|
||||||
|
|
||||||
def test_exec_update_fuzz(self):
|
def test_exec_update_fuzz(self):
|
||||||
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
|
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
|
||||||
|
@ -178,7 +182,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
|
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_copy_long(self):
|
def test_copy_long(self):
|
||||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||||
|
@ -211,7 +216,8 @@ class TestHCQ(unittest.TestCase):
|
||||||
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
TestHCQ.d0.timeline_signal.wait(TestHCQ.d0.timeline_value)
|
||||||
TestHCQ.d0.timeline_value += 1
|
TestHCQ.d0.timeline_value += 1
|
||||||
|
|
||||||
assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 1.0, f"got val {val}"
|
val = TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]
|
||||||
|
assert val == 1.0, f"got val {val}"
|
||||||
|
|
||||||
def test_update_copy_long(self):
|
def test_update_copy_long(self):
|
||||||
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
if TestHCQ.d0.hw_copy_queue_t is None: self.skipTest("device does not support copy queue")
|
||||||
|
|
|
@ -1532,7 +1532,8 @@ def _helper_linearizer_opt_ast(realized_ast:UOp, real_bufs:List[Buffer], opts=[]
|
||||||
for opt in opts:
|
for opt in opts:
|
||||||
k.apply_opt(opt)
|
k.apply_opt(opt)
|
||||||
if expected_color_size is not None:
|
if expected_color_size is not None:
|
||||||
assert (cs:=list(zip(k.colors(), k.full_shape))) == expected_color_size, f"expected={expected_color_size} got={cs}"
|
cs = list(zip(k.colors(), k.full_shape))
|
||||||
|
assert cs == expected_color_size, f"expected={expected_color_size} got={cs}"
|
||||||
prg = get_prg(k)
|
prg = get_prg(k)
|
||||||
for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
for buf,_ in outbufs: buf.copyin(np.zeros((buf.size, ), dtype=_to_np_dtype(buf.dtype)).data) # Zero to check that all values are filled
|
||||||
prg.exec(real_bufs)
|
prg.exec(real_bufs)
|
||||||
|
|
|
@ -598,8 +598,8 @@ def linearize_uop(sink_in:Union[UOp, List[UOp]], opts:Optional[Renderer]=None, s
|
||||||
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
|
# NOTE: multiple identical stores to DEFINE_LOCAL is okay
|
||||||
# NOTE: for PTX you have to propagate through some of the calculations to determine if it is a store to DEFINE_LOCAL
|
# NOTE: for PTX you have to propagate through some of the calculations to determine if it is a store to DEFINE_LOCAL
|
||||||
def _islocalbuf(u: UOp): return u.op is UOps.DEFINE_LOCAL or any(_islocalbuf(x) for x in u.src if u.op in [UOps.ALU, UOps.CAST])
|
def _islocalbuf(u: UOp): return u.op is UOps.DEFINE_LOCAL or any(_islocalbuf(x) for x in u.src if u.op in [UOps.ALU, UOps.CAST])
|
||||||
assert len(all_stores := [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])]) \
|
all_stores = [x.src[0:2]+x.src[3:] for x in _uops if x.op is UOps.STORE and not _islocalbuf(x.src[0])]
|
||||||
== len(dedup(all_stores)), "repeated stores in uops"
|
assert len(all_stores) == len(dedup(all_stores)), "repeated stores in uops"
|
||||||
except AssertionError as e:
|
except AssertionError as e:
|
||||||
print_uops(_uops)
|
print_uops(_uops)
|
||||||
if not CI:
|
if not CI:
|
||||||
|
|
|
@ -19,8 +19,8 @@ class _Device:
|
||||||
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
def __getitem__(self, ix:str) -> Compiled: return self.__get_canonicalized_item(self.canonicalize(ix))
|
||||||
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
|
||||||
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
def __get_canonicalized_item(self, ix:str) -> Compiled:
|
||||||
assert ((cpn:=multiprocessing.current_process().name) == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], \
|
cpn = multiprocessing.current_process().name
|
||||||
f"can only open device {ix} from parent, not {cpn}"
|
assert (cpn == "MainProcess") or ix.split(":")[0] in ["DISK", "NPY"], f"can only open device {ix} from parent, not {cpn}"
|
||||||
x = ix.split(":")[0].upper()
|
x = ix.split(":")[0].upper()
|
||||||
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
|
ret = [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "device") and x in self._devices][0](ix) # noqa: E501
|
||||||
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
|
if DEBUG >= 1: print(f"opened device {ix} from pid:{os.getpid()}")
|
||||||
|
|
|
@ -39,7 +39,8 @@ def round_up(num, amt:int): return (num+amt-1)//amt * amt
|
||||||
def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
|
def data64(data: int) -> Tuple[int, int]: return (data >> 32, data & 0xFFFFFFFF)
|
||||||
def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
|
def data64_le(data: int) -> Tuple[int, int]: return (data & 0xFFFFFFFF, data >> 32)
|
||||||
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
|
def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
|
||||||
assert len(kvs:=set([(k,v) for d in ds for k,v in d.items()])) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key" # noqa: E501
|
kvs = set([(k,v) for d in ds for k,v in d.items()])
|
||||||
|
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||||
return {k:v for d in ds for k,v in d.items()}
|
return {k:v for d in ds for k,v in d.items()}
|
||||||
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
|
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
|
||||||
a:List[T] = []
|
a:List[T] = []
|
||||||
|
|
|
@ -141,7 +141,8 @@ class LazyBuffer:
|
||||||
srcs.append(root._view(s.base.contiguous_child[1]))
|
srcs.append(root._view(s.base.contiguous_child[1]))
|
||||||
else:
|
else:
|
||||||
srcs.append(s)
|
srcs.append(s)
|
||||||
assert all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]), f"all dtypes must match {dts} on {op}"
|
if not all_same(dts:=[x.dtype.scalar() for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]):
|
||||||
|
raise AssertionError(f"all dtypes must match {dts} on {op}")
|
||||||
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
|
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
|
||||||
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
|
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
|
||||||
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
|
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
|
||||||
|
|
|
@ -330,7 +330,8 @@ def type_verify(uops):
|
||||||
if uop is UOps.ALU:
|
if uop is UOps.ALU:
|
||||||
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||||
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
|
||||||
assert dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), f"{arg} output dtype mismatch {dtype=} != {bd=}"
|
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
|
||||||
|
assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}"
|
||||||
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||||
elif arg is BinaryOps.IDIV:
|
elif arg is BinaryOps.IDIV:
|
||||||
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
|
assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
|
||||||
|
@ -340,8 +341,8 @@ def type_verify(uops):
|
||||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||||
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
|
||||||
elif arg == TernaryOps.WHERE:
|
elif arg == TernaryOps.WHERE:
|
||||||
assert src[0].dtype == (bd:=(dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool)), \
|
bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
|
||||||
f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
|
assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
|
||||||
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||||
|
|
||||||
def uop_alu_resolve(u:UOp) -> sint:
|
def uop_alu_resolve(u:UOp) -> sint:
|
||||||
|
|
|
@ -1658,7 +1658,8 @@ class Tensor:
|
||||||
"""
|
"""
|
||||||
def parse_formula(formula: str):
|
def parse_formula(formula: str):
|
||||||
lparens, rparens = map(lambda x: [i for i, ch in enumerate(formula.split()) if ch == x], ("(", ")"))
|
lparens, rparens = map(lambda x: [i for i, ch in enumerate(formula.split()) if ch == x], ("(", ")"))
|
||||||
assert len(lparens) == len(rparens) and sorted(flatten(pairs := list(zip(lparens, rparens)))) == flatten(pairs), "bracket mismatch"
|
pairs = list(zip(lparens, rparens))
|
||||||
|
assert len(lparens) == len(rparens) and sorted(flatten(pairs)) == flatten(pairs), "bracket mismatch"
|
||||||
return [name for name in formula.split() if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
return [name for name in formula.split() if name not in ("(", ")")], [(s - 2*i, e - 1 - 2*i) for i, (s, e) in enumerate(pairs)]
|
||||||
|
|
||||||
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
assert formula.count("->") == 1, 'need exactly one "->" in formula'
|
||||||
|
@ -1889,7 +1890,7 @@ class Tensor:
|
||||||
"""
|
"""
|
||||||
n1, n2 = len(self.shape), len(w.shape)
|
n1, n2 = len(self.shape), len(w.shape)
|
||||||
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D"
|
||||||
assert (L:=self.shape[-1]) == (R:=w.shape[-min(n2, 2)]), f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})"
|
if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})")
|
||||||
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1])
|
||||||
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2))
|
||||||
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype)
|
||||||
|
|
Loading…
Reference in New Issue