mirror of https://github.com/commaai/tinygrad.git
parent
f271cd682b
commit
85691c8e20
|
@ -1,9 +1,10 @@
|
|||
import ctypes, unittest
|
||||
from tinygrad.helpers import init_c_struct_t
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.device import Device, Buffer, BufferXfer
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.runtime.driver.hsa import AQLQueue
|
||||
from tinygrad.runtime.graph.hsa import VirtAQLQueue
|
||||
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
|
||||
from tinygrad.features.jit import JitItem
|
||||
|
||||
def get_hsa_inc_prog(dev, inc=1):
|
||||
prg = f"""
|
||||
|
@ -91,6 +92,25 @@ class TestHSADriver(unittest.TestCase):
|
|||
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
|
||||
del queue, clprogs
|
||||
|
||||
def test_hsa_copies_sync(self):
|
||||
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
|
||||
|
||||
test_buf0 = Buffer(d0, 1, dtypes.int)
|
||||
test_buf1 = Buffer(d0, 1, dtypes.int)
|
||||
test_buf2 = Buffer(d1, 1, dtypes.int)
|
||||
test_buf0.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf1.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf2.copyin(memoryview(bytearray(1*4)))
|
||||
|
||||
jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
|
||||
graph = HSAGraph(jit_cache, [], {})
|
||||
|
||||
for i in range(10000):
|
||||
test_buf0.copyin(memoryview(bytearray(1*4)))
|
||||
test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
|
||||
graph([], {})
|
||||
assert test_buf0.as_buffer().cast('I')[0] == 4
|
||||
assert test_buf2.as_buffer().cast('I')[0] == 0
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -187,7 +187,7 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
|
||||
# When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
|
||||
signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)]
|
||||
deps = signal_deps + [max(aql_deps)] if aql_deps else []
|
||||
deps = signal_deps + ([max(aql_deps)] if len(aql_deps) > 0 else [])
|
||||
for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))
|
||||
|
||||
if new_dependency is not None:
|
||||
|
|
Loading…
Reference in New Issue