fix hsa sync issue (#3847)

* fix hsa sync issue

* linter
This commit is contained in:
nimlgen 2024-03-21 04:00:30 +03:00 committed by GitHub
parent f271cd682b
commit 85691c8e20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 23 additions and 3 deletions

View File

@ -1,9 +1,10 @@
import ctypes, unittest
from tinygrad.helpers import init_c_struct_t
from tinygrad.device import Device, Buffer
from tinygrad.device import Device, Buffer, BufferXfer
from tinygrad.dtype import dtypes
from tinygrad.runtime.driver.hsa import AQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue
from tinygrad.runtime.graph.hsa import VirtAQLQueue, HSAGraph
from tinygrad.features.jit import JitItem
def get_hsa_inc_prog(dev, inc=1):
prg = f"""
@ -91,6 +92,25 @@ class TestHSADriver(unittest.TestCase):
assert test_buf.as_buffer().cast('I')[0] == expected, f"{test_buf.as_buffer().cast('I')[0]} != {expected}, all packets executed?"
del queue, clprogs
def test_hsa_copies_sync(self):
d0, d1 = f"{Device.DEFAULT}:0", f"{Device.DEFAULT}:1"
test_buf0 = Buffer(d0, 1, dtypes.int)
test_buf1 = Buffer(d0, 1, dtypes.int)
test_buf2 = Buffer(d1, 1, dtypes.int)
test_buf0.copyin(memoryview(bytearray(1*4)))
test_buf1.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(1*4)))
jit_cache = [JitItem(BufferXfer(), [test_buf0, test_buf2]), JitItem(BufferXfer(), [test_buf2, test_buf1])]
graph = HSAGraph(jit_cache, [], {})
for i in range(10000):
test_buf0.copyin(memoryview(bytearray(1*4)))
test_buf2.copyin(memoryview(bytearray(int.to_bytes(4, length=1*4, byteorder='little'))))
graph([], {})
assert test_buf0.as_buffer().cast('I')[0] == 4
assert test_buf2.as_buffer().cast('I')[0] == 0
if __name__ == '__main__':
unittest.main()

View File

@ -187,7 +187,7 @@ class HSAGraph(MultiDeviceJITGraph):
# When synchronizing to aql packets, we only need to sync to the latest one, as they are executed in order.
signal_deps, aql_deps = [x for x in rdeps if isinstance(x, hsa.hsa_signal_t)], [x for x in rdeps if isinstance(x, int)]
deps = signal_deps + [max(aql_deps)] if aql_deps else []
deps = signal_deps + ([max(aql_deps)] if len(aql_deps) > 0 else [])
for dep in deps: wait_signals.append(self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets))
if new_dependency is not None: