mirror of https://github.com/commaai/tinygrad.git
parent
99cbc24390
commit
2d54e4d747
|
@ -46,9 +46,8 @@ class AQLQueue:
|
|||
def __del__(self):
|
||||
if hasattr(self, 'hw_queue'): check(hsa.hsa_queue_destroy(self.hw_queue))
|
||||
|
||||
def submit_kernel(self, prg, global_size, local_size, kernargs, need_signal=False):
|
||||
def submit_kernel(self, prg, global_size, local_size, kernargs, completion_signal=None):
|
||||
if self.available_packet_slots == 0: self._wait_queue()
|
||||
signal = self._alloc_signal(reusable=True) if need_signal else EMPTY_SIGNAL
|
||||
|
||||
packet = hsa.hsa_kernel_dispatch_packet_t.from_address(self.write_addr)
|
||||
packet.workgroup_size_x = local_size[0]
|
||||
|
@ -63,17 +62,14 @@ class AQLQueue:
|
|||
packet.kernel_object = prg.handle
|
||||
packet.kernarg_address = kernargs
|
||||
packet.reserved2 = 0
|
||||
packet.completion_signal = signal
|
||||
packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
|
||||
packet.setup = DISPATCH_KERNEL_SETUP
|
||||
packet.header = DISPATCH_KERNEL_HEADER
|
||||
self._submit_packet()
|
||||
|
||||
return signal
|
||||
|
||||
def submit_barrier(self, wait_signals=None, need_signal=False, completion_signal=None):
|
||||
def submit_barrier(self, wait_signals=None, completion_signal=None):
|
||||
assert wait_signals is None or len(wait_signals) <= 5
|
||||
if self.available_packet_slots == 0: self._wait_queue()
|
||||
signal = (completion_signal or self._alloc_signal(reusable=True)) if need_signal else EMPTY_SIGNAL
|
||||
|
||||
packet = hsa.hsa_barrier_and_packet_t.from_address(self.write_addr)
|
||||
packet.reserved0 = 0
|
||||
|
@ -81,12 +77,10 @@ class AQLQueue:
|
|||
for i in range(5):
|
||||
packet.dep_signal[i] = wait_signals[i] if wait_signals and len(wait_signals) > i else EMPTY_SIGNAL
|
||||
packet.reserved2 = 0
|
||||
packet.completion_signal = signal
|
||||
packet.completion_signal = completion_signal if completion_signal else EMPTY_SIGNAL
|
||||
packet.header = BARRIER_HEADER
|
||||
self._submit_packet()
|
||||
|
||||
return signal
|
||||
|
||||
def blit_packets(self, packet_addr, packet_cnt):
|
||||
if self.available_packet_slots < packet_cnt: self._wait_queue(packet_cnt)
|
||||
|
||||
|
@ -98,8 +92,8 @@ class AQLQueue:
|
|||
self._submit_packet(packet_cnt)
|
||||
|
||||
def wait(self):
|
||||
signal = self.submit_barrier(need_signal=True)
|
||||
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
self.submit_barrier([], finish_signal := self.device.alloc_signal(reusable=True))
|
||||
hsa.hsa_signal_wait_scacquire(finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
self.available_packet_slots = self.queue_size // AQL_PACKET_SIZE
|
||||
|
||||
def _wait_queue(self, need_packets=1):
|
||||
|
@ -117,8 +111,6 @@ class AQLQueue:
|
|||
if self.write_addr > self.write_addr_end:
|
||||
self.write_addr = self.queue_base + (self.write_addr - self.queue_base) % self.queue_size
|
||||
|
||||
def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable)
|
||||
|
||||
def scan_agents():
|
||||
agents = collections.defaultdict(list)
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ class VirtAQLQueue(AQLQueue):
|
|||
self.write_addr += AQL_PACKET_SIZE
|
||||
self.packets_count += 1
|
||||
self.available_packet_slots -= 1
|
||||
def _alloc_signal(self, reusable=False): return init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_signal_create(1, 0, None, ctypes.byref(x))))
|
||||
|
||||
class HSAGraph(MultiDeviceJITGraph):
|
||||
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
|
||||
|
@ -67,31 +66,28 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
self.signals_to_reset: List[hsa.hsa_signal_t] = []
|
||||
self.w_dependency_map: Dict[Any, Union[hsa.hsa_signal_t, int]] = {}
|
||||
self.r_dependency_map: Dict[Any, List[Union[hsa.hsa_signal_t, int]]] = collections.defaultdict(list)
|
||||
signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
|
||||
self.signals_to_devices: Dict[ctypes.c_uint64, List[HSADevice]] = {}
|
||||
self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)
|
||||
|
||||
# Special packet to wait for the world.
|
||||
self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {}
|
||||
for dev in self.devices: self.kickoff_signals[dev] = self.virt_aql_queues[dev].submit_barrier(need_signal=True)
|
||||
self.signals_to_reset += list(self.kickoff_signals.values())
|
||||
self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
|
||||
for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
if isinstance(ji.prg, CompiledASTRunner):
|
||||
wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
|
||||
for i in range(0, len(wait_signals), 5):
|
||||
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals=wait_signals[i:i+5])
|
||||
self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
|
||||
self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)
|
||||
sync_signal = self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
|
||||
ctypes.addressof(self.ji_kargs_structs[j]), need_signal=PROFILE)
|
||||
if PROFILE:
|
||||
self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
|
||||
self.signals_to_reset.append(sync_signal)
|
||||
|
||||
sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
|
||||
self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.launch_dims(var_vals), #type:ignore
|
||||
ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
|
||||
if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.rawbufs[0:2]]
|
||||
dest_dev, src_dev = cast(HSADevice, dest.d), cast(HSADevice, src.d)
|
||||
sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(sync_signal)
|
||||
signals_to_devices[sync_signal.handle] = [dest_dev, src_dev]
|
||||
sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])
|
||||
|
||||
wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
|
||||
self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
|
||||
|
@ -102,14 +98,14 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
# Wait for all active signals to finish the graph
|
||||
wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
|
||||
for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
|
||||
for dev in signals_to_devices[v.handle]:
|
||||
for dev in self.signals_to_devices[v.handle]:
|
||||
wait_signals_to_finish[dev].append(v)
|
||||
|
||||
self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
for dev in self.devices:
|
||||
wait_signals = wait_signals_to_finish[dev]
|
||||
for i in range(0, max(1, len(wait_signals)), 5):
|
||||
self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], need_signal=(i+5>=len(wait_signals)), completion_signal=self.finish_signal)
|
||||
self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)
|
||||
|
||||
# Zero signals to allow graph to start and execute.
|
||||
for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
|
||||
|
@ -162,12 +158,16 @@ class HSAGraph(MultiDeviceJITGraph):
|
|||
jit=jit, num_kernels=len(self.jit_cache), device="HSA")
|
||||
return et
|
||||
|
||||
def alloc_signal(self, reset_on_start=False, wait_on=None):
|
||||
sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
if reset_on_start: self.signals_to_reset.append(sync_signal)
|
||||
if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
|
||||
return sync_signal
|
||||
|
||||
def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
|
||||
if isinstance(dep, hsa.hsa_signal_t): return dep
|
||||
elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
|
||||
if packet.completion_signal.handle == EMPTY_SIGNAL.handle:
|
||||
packet.completion_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
|
||||
self.signals_to_reset.append(packet.completion_signal)
|
||||
if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
|
||||
return packet.completion_signal
|
||||
return None
|
||||
|
||||
|
|
|
@ -89,7 +89,8 @@ class HSAProgram:
|
|||
for i in range(len(vals)): args_st.__setattr__(f'v{i}', vals[i])
|
||||
self.device.flush_hdp()
|
||||
|
||||
signal = self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, need_signal=(wait or PROFILE))
|
||||
signal = self.device.alloc_signal(reusable=True) if wait or PROFILE else None
|
||||
self.device.hw_queue.submit_kernel(self, global_size, local_size, kernargs, completion_signal=signal)
|
||||
if PROFILE: Profiler.track(signal, self.device, self.name)
|
||||
if wait:
|
||||
hsa.hsa_signal_wait_scacquire(signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
|
||||
|
@ -122,18 +123,17 @@ class HSAAllocator(LRUAllocator):
|
|||
|
||||
def copyin(self, dest:T, src: memoryview):
|
||||
# Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets.
|
||||
copy_signal = self.device.alloc_signal(reusable=True)
|
||||
sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
|
||||
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
|
||||
mem = self._alloc_with_options(src.nbytes, BufferOptions(host=True))
|
||||
ctypes.memmove(mem, from_mv(src), src.nbytes)
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes,
|
||||
1, ctypes.byref(sync_signal), copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
self.device.hw_queue.submit_barrier(wait_signals=[copy_signal])
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, self.device.agent, mem, HSADevice.cpu_agent, src.nbytes, 1, ctypes.byref(sync_signal),
|
||||
copy_signal := self.device.alloc_signal(reusable=True), hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
self.device.hw_queue.submit_barrier([copy_signal])
|
||||
self.device.delayed_free.append(mem)
|
||||
if PROFILE: Profiler.track(copy_signal, self.device, f"copyin: CPU -> HSA:{self.device.device_id}", is_copy=True)
|
||||
|
||||
def copy_from_fd(self, dest, fd, offset, size):
|
||||
sync_signal = self.device.hw_queue.submit_barrier(need_signal=True)
|
||||
self.device.hw_queue.submit_barrier([], sync_signal := self.device.alloc_signal(reusable=True))
|
||||
|
||||
if not hasattr(self, 'hb'):
|
||||
self.hb = [self._alloc_with_options(CHUNK_SIZE, BufferOptions(host=True)) for _ in range(2)]
|
||||
|
@ -167,7 +167,7 @@ class HSAAllocator(LRUAllocator):
|
|||
|
||||
wait_signals = [self.hb_signals[self.hb_polarity - 1]]
|
||||
if copies_called > 1: wait_signals.append(self.hb_signals[self.hb_polarity])
|
||||
self.device.hw_queue.submit_barrier(wait_signals=wait_signals)
|
||||
self.device.hw_queue.submit_barrier(wait_signals)
|
||||
|
||||
def copyout(self, dest:memoryview, src:T):
|
||||
HSADevice.synchronize_system()
|
||||
|
@ -180,13 +180,13 @@ class HSAAllocator(LRUAllocator):
|
|||
if PROFILE: Profiler.track(copy_signal, self.device, f"copyout: HSA:{self.device.device_id} -> CPU", is_copy=True)
|
||||
|
||||
def transfer(self, dest:T, src:T, sz:int, src_dev=None, dest_dev=None):
|
||||
copy_signal = dest_dev.alloc_signal(reusable=False)
|
||||
sync_signal_1 = src_dev.hw_queue.submit_barrier(need_signal=True)
|
||||
sync_signal_2 = dest_dev.hw_queue.submit_barrier(need_signal=True)
|
||||
src_dev.hw_queue.submit_barrier([], sync_signal_1 := src_dev.alloc_signal(reusable=True))
|
||||
dest_dev.hw_queue.submit_barrier([], sync_signal_2 := dest_dev.alloc_signal(reusable=True))
|
||||
c_wait_signal = (hsa.hsa_signal_t*2)(sync_signal_1, sync_signal_2)
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal, copy_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) # noqa: E501
|
||||
src_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
|
||||
dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal])
|
||||
check(hsa.hsa_amd_memory_async_copy_on_engine(dest, dest_dev.agent, src, src_dev.agent, sz, 2, c_wait_signal,
|
||||
copy_signal := dest_dev.alloc_signal(reusable=False), hsa.HSA_AMD_SDMA_ENGINE_0, True))
|
||||
src_dev.hw_queue.submit_barrier([copy_signal])
|
||||
dest_dev.hw_queue.submit_barrier([copy_signal])
|
||||
if PROFILE: Profiler.track(copy_signal, src_dev, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", is_copy=True)
|
||||
|
||||
class HSADevice(Compiled):
|
||||
|
|
Loading…
Reference in New Issue