clean up hsa driver (#3818)

* clean up driver

* remove returns
This commit is contained in:
nimlgen 2024-03-20 00:17:41 +03:00 committed by GitHub
parent 99cbc24390
commit 2d54e4d747
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 47 deletions

View File

@ -46,9 +46,8 @@ class AQLQueue:
def __del__(self): def __del__(self):
if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue)) if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False): def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
if self.available_packet_slots == 0: self._wait_queue() if self.available_packet_slots == 0: self._wait_queue()
signal = self._alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr) packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
packet.workgroup_size_x = local_size[0] packet.workgroup_size_x = local_size[0]
@ -63,17 +62,14 @@ class AQLQueue:
packet.kernel_object = prg.handle packet.kernel_object = prg.handle
packet.kernarg_address = kernargs packet.kernarg_address = kernargs
packet.reserved2 = 0 packet.reserved2 = 0
packet.completion_signal = signal packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
packet.setup = DISPATCH_KERNEL_SETUP packet.setup = DISPATCH_KERNEL_SETUP
packet.header = DISPATCH_KERNEL_HEADER packet.header = DISPATCH_KERNEL_HEADER
self._submit_packet() self._submit_packet()
return signal def submit_barrier(self, wait_signals=None, completion_signal=None):
def submit_barrier(self, wait_signals=None, need_signal=False, completion_signal=None):
assert wait_signals is None or len(wait_signals) <= 5 assert wait_signals is None or len(wait_signals) <= 5
if self.available_packet_slots == 0: self._wait_queue() if self.available_packet_slots == 0: self._wait_queue()
signal = (completion_signal or self._alloc_signal(reusable=True)) if need_signal else EMPTY_SIGNAL
packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr) packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
packet.reserved0 = 0 packet.reserved0 = 0
@ -81,12 +77,10 @@ class AQLQueue:
for i in range(5): for i in range(5):
packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
packet.reserved2 = 0 packet.reserved2 = 0
packet.completion_signal = signal packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
packet.header = BARRIER_HEADER packet.header = BARRIER_HEADER
self._submit_packet() self._submit_packet()
return signal
def blit_packets(self, packet_addr, packet_cnt): def blit_packets(self, packet_addr, packet_cnt):
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt) if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
@ -98,8 +92,8 @@ class AQLQueue:
self._submit_packet(packet_cnt) self._submit_packet(packet_cnt)
def wait(self): def wait(self):
signal = self.submit_barrier(need_signal=True) self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
def _wait_queue(self, need_packets=1): def _wait_queue(self, need_packets=1):
@ -117,8 +111,6 @@ class AQLQueue:
if self.write_addr > self.write_addr_end: if self.write_addr > self.write_addr_end:
self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)
def scan_agents(): def scan_agents():
agents = collections.defaultdict(list) agents = collections.defaultdict(list)

View File

@ -23,7 +23,6 @@ class VirtAQLQueue(AQLQueue):
self.write_addr += AQL_PACKET_SIZE self.write_addr += AQL_PACKET_SIZE
self.packets_count += 1 self.packets_count += 1
self.available_packet_slots -= 1 self.available_packet_slots -= 1
def _alloc_signal(self, reusable=False): return init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x))))
class HSAGraph(MultiDeviceJITGraph): class HSAGraph(MultiDeviceJITGraph):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]): def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@ -67,31 +66,28 @@ class HSAGraph(MultiDeviceJITGraph):
self.signals_to_reset: List[hsa.hsa_signal_t] = [] self.signals_to_reset: List[hsa.hsa_signal_t] = []
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {} self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list) self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {} self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list) self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
# Special packet to wait for the world. # Special packet to wait for the world.
self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {} self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
for dev in self.devices: self.kickoff_signals[dev] = self.virt_aql_queues[dev].submit_barrier(need_signal=True) for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
self.signals_to_reset += list(self.kickoff_signals.values())
for j,ji in enumerate(self.jit_cache): for j,ji in enumerate(self.jit_cache):
if isinstance(ji.prg, CompiledASTRunner): if isinstance(ji.prg, CompiledASTRunner):
wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False) wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
for i in range(0, len(wait_signals), 5): for i in range(0, len(wait_signals), 5):
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5]) self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr) self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
sync_signal = self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
ctypes.addressof(self.ji_kargs_structs[j]), need_signal=PROFILE) sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
if PROFILE: self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False)) ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
self.signals_to_reset.append(sync_signal) if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
elif isinstance(ji.prg, BufferXfer): elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]] dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d) dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
self.signals_to_reset.append(sync_signal)
signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True) wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
@ -102,14 +98,14 @@ class HSAGraph(MultiDeviceJITGraph):
# Wait for all active signals to finish the graph # Wait for all active signals to finish the graph
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))): for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
for dev in signals_to_devices[v.handle]: for dev in self.signals_to_devices[v.handle]:
wait_signals_to_finish[dev].append(v) wait_signals_to_finish[dev].append(v)
self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x)))) self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
for dev in self.devices: for dev in self.devices:
wait_signals = wait_signals_to_finish[dev] wait_signals = wait_signals_to_finish[dev]
for i in range(0, max(1, len(wait_signals)), 5): for i in range(0, max(1, len(wait_signals)), 5):
self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], need_signal=(i+5>=len(wait_signals)), completion_signal=self.finish_signal) self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
# Zero signals to allow graph to start and execute. # Zero signals to allow graph to start and execute.
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0) for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
@ -162,12 +158,16 @@ class HSAGraph(MultiDeviceJITGraph):
jit=jit, num_kernels=len(self.jit_cache), device="HSA") jit=jit, num_kernels=len(self.jit_cache), device="HSA")
return et return et
def alloc_signal(self, reset_on_start=False, wait_on=None):
    """Create a fresh HSA signal and record it in this graph's bookkeeping.

    The signal is created with ``hsa.hsa_amd_signal_create(1, 0, None, 0, ...)``
    and wrapped by ``init_c_var``.

    Args:
        reset_on_start: When true, the signal is appended to
            ``self.signals_to_reset``.
        wait_on: When not ``None``, stored in ``self.signals_to_devices``
            under the new signal's ``handle``.

    Returns:
        The newly created ``hsa.hsa_signal_t``.
    """
    def _create(raw_signal):
        # Same call the original made inline; the result is passed through.
        return check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(raw_signal)))

    new_signal = init_c_var(hsa.hsa_signal_t(), _create)

    if reset_on_start:
        self.signals_to_reset.append(new_signal)

    if wait_on is not None:
        self.signals_to_devices[new_signal.handle] = wait_on

    return new_signal
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]: def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
if isinstance(dep, hsa.hsa_signal_t): return dep if isinstance(dep, hsa.hsa_signal_t): return dep
elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t): elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
if packet.completion_signal.handle == EMPTY_SIGNAL.handle: if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
self.signals_to_reset.append(packet.completion_signal)
return packet.completion_signal return packet.completion_signal
return None return None

View File

@ -89,7 +89,8 @@ class HSAProgram:
for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i]) for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
self.device.flush_hdp() self.device.flush_hdp()
signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=(wait or PROFILE)) signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
if PROFILE: Profiler.track(signal, self.device, self.name) if PROFILE: Profiler.track(signal, self.device, self.name)
if wait: if wait:
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
@ -122,18 +123,17 @@ class HSAAllocator(LRUAllocator):
def copyin(self, dest:T, src: memoryview): def copyin(self, dest:T, src: memoryview):
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets. # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
copy_signal = self.device.alloc_signal(reusable=True) self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
mem = self._alloc_with_options(src.nbytes, BufferOptions(host=True)) mem = self._alloc_with_options(src.nbytes, BufferOptions(host=True))
ctypes.memmove(mem, from_mv(src), src.nbytes) ctypes.memmove(mem, from_mv(src), src.nbytes)
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
self.device.hw_queue.submit_barrier(wait_signals=[copy_signal]) self.device.hw_queue.submit_barrier([copy_signal])
self.device.delayed_free.append(mem) self.device.delayed_free.append(mem)
if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True) if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)
def copy_from_fd(self, dest, fd, offset, size): def copy_from_fd(self, dest, fd, offset, size):
sync_signal = self.device.hw_queue.submit_barrier(need_signal=True) self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
if not hasattr(self, 'hb'): if not hasattr(self, 'hb'):
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)] self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
@ -167,7 +167,7 @@ class HSAAllocator(LRUAllocator):
wait_signals = [self.hb_signals[self.hb_polarity - 1]] wait_signals = [self.hb_signals[self.hb_polarity - 1]]
if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity]) if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
self.device.hw_queue.submit_barrier(wait_signals=wait_signals) self.device.hw_queue.submit_barrier(wait_signals)
def copyout(self, dest:memoryview, src:T): def copyout(self, dest:memoryview, src:T):
HSADevice.synchronize_system() HSADevice.synchronize_system()
@ -180,13 +180,13 @@ class HSAAllocator(LRUAllocator):
if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True) if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)
def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None): def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
copy_signal = dest_dev.alloc_signal(reusable=False) src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
sync_signal_1 = src_dev.hw_queue.submit_barrier(need_signal=True) dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
sync_signal_2 = dest_dev.hw_queue.submit_barrier(need_signal=True)
c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2) c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501 check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal]) copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal]) src_dev.hw_queue.submit_barrier([copy_signal])
dest_dev.hw_queue.submit_barrier([copy_signal])
if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True) if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)
class HSADevice(Compiled): class HSADevice(Compiled):