mirror of https://github.com/commaai/tinygrad.git
fix addresses of dispatch packets (#3534)
This commit is contained in:
parent
9268a8b154
commit
b05776ef3e
|
@ -80,10 +80,10 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if isinstance(ji.prg, CompiledASTRunner):
|
||||
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
|
||||
wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=self.packets[j], sync_with_aql_packets=False)
|
||||
wait_signals = self.access_resources(read=ji.rawbufs[1:], write=ji.rawbufs[0:1], new_dependency=j, sync_with_aql_packets=False)
|
||||
for i in range(0, len(wait_signals), 5):
|
||||
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
|
||||
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
|
||||
self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), ctypes.addressof(self.ji_kargs_structs[j])) #type:ignore
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
|
||||
|
@ -99,6 +99,7 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
# Make sure the src buffer can be accessed by other devices.
|
||||
c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices])
|
||||
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf))
|
||||
check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, dest._buf))
|
||||
|
||||
# Wait for all active signals to finish the graph
|
||||
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
|
||||
|
@ -159,11 +160,11 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
|
||||
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
|
||||
if isinstance(dep, hsa.hsa_signal_t): return dep
|
||||
elif sync_with_aql_packets and isinstance(dep, hsa.hsa_kernel_dispatch_packet_t):
|
||||
if dep.completion_signal.handle == EMPTY_SIGNAL.handle:
|
||||
dep.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(dep.completion_signal)
|
||||
return dep.completion_signal
|
||||
elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
|
||||
if packet.completion_signal.handle == EMPTY_SIGNAL.handle:
|
||||
packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(packet.completion_signal)
|
||||
return packet.completion_signal
|
||||
return None
|
||||
|
||||
def access_resources(self, read, write, new_dependency=None, sync_with_aql_packets=False):
|
||||
|
|
Loading…
Reference in New Issue