fix addresses of dispatch packets (#3534)

This commit is contained in:
nimlgen 2024-02-29 16:43:55 +03:00 committed by GitHub
parent 9268a8b154
commit b05776ef3e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 8 additions and 7 deletions

View File

@ -80,10 +80,10 @@ class HSAGraph(MultiDeviceJITGraph):
for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledASTRunner):
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=self.packets[j], sync_with_aql_packets=False)
wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=j, sync_with_aql_packets=False)
for i in range(0, len(wait_signals), 5):
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
@ -99,6 +99,7 @@ class HSAGraph(MultiDeviceJITGraph):
# Make sure the src buffer can be by other devices.
c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf))
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, dest._buf))
# Wait for all active signals to finish the graph
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
@ -159,11 +160,11 @@ class HSAGraph(MultiDeviceJITGraph):
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
if isinstance(dep, hsa.hsa_signal_t): return dep
elif sync_with_aql_packets and isinstance(dep, hsa.hsa_kernel_dispatch_packet_t):
if dep.completion_signal.handle == EMPTY_SIGNAL.handle:
dep.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
self.signals_to_reset.append(dep.completion_signal)
return dep.completion_signal
elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
if packet.completion_signal.handle == EMPTY_SIGNAL.handle:
packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
self.signals_to_reset.append(packet.completion_signal)
return packet.completion_signal
return None
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):