2024-03-02 23:37:51 +08:00
|
|
|
import ctypes, collections, time, itertools
|
2024-05-03 05:38:22 +08:00
|
|
|
from typing import List, Any, Dict, cast, Optional, Tuple
|
2024-07-27 07:01:12 +08:00
|
|
|
from tinygrad.helpers import init_c_var, round_up
|
2024-05-11 02:22:31 +08:00
|
|
|
from tinygrad.device import Buffer, BufferOptions
|
2024-05-11 13:43:09 +08:00
|
|
|
from tinygrad.device import Compiled, Device
|
2024-02-29 01:40:53 +08:00
|
|
|
from tinygrad.shape.symbolic import Variable
|
2024-03-14 12:19:22 +08:00
|
|
|
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
|
2024-05-11 13:43:09 +08:00
|
|
|
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
|
2024-07-27 07:01:12 +08:00
|
|
|
from tinygrad.engine.jit import MultiGraphRunner, GraphException
|
2024-02-29 01:40:53 +08:00
|
|
|
import tinygrad.runtime.autogen.hsa as hsa
|
2024-07-13 02:06:42 +08:00
|
|
|
from tinygrad.runtime.support.hsa import check, AQLQueue, AQL_PACKET_SIZE, EMPTY_SIGNAL
|
2024-02-29 01:40:53 +08:00
|
|
|
|
|
|
|
def dedup_signals(signals):
  """Collapse `signals` to one `hsa_signal_t` per distinct handle, skipping entries that are not signals (e.g. `None`)."""
  unique_handles = {sig.handle for sig in signals if isinstance(sig, hsa.hsa_signal_t)}
  return [hsa.hsa_signal_t(handle) for handle in unique_handles]
|
|
|
|
|
|
|
|
class VirtAQLQueue(AQLQueue):
  """Host-memory stand-in for an `AQLQueue`, used to record AQL packets instead of submitting them.

  Packets submitted through the inherited `AQLQueue` helpers are written into a ctypes array of `sz` dispatch-packet-sized
  slots starting at `queue_base`; `packets_count` tracks how many have been recorded. The base-class `__init__` is not
  called — presumably so that no hardware queue is allocated for this recording buffer.
  """
  def __init__(self, device, sz):
    self.device = device
    # Backing storage: `sz` slots, each the size of an hsa_kernel_dispatch_packet_t.
    self.virt_queue = (hsa.hsa_kernel_dispatch_packet_t * sz)()
    self.queue_base = self.write_addr = ctypes.addressof(self.virt_queue)
    self.packets_count = 0
    self.available_packet_slots = sz

  def _wait_queue(self, need_packets=1):
    # The virtual queue is never drained, so running out of slots is fatal. Raise explicitly instead of `assert False`:
    # asserts are stripped under `python -O`, which would make this return normally and presumably let the caller keep
    # writing packets past the end of `virt_queue`.
    raise RuntimeError(f"VirtQueue is too small to handle {self.packets_count+need_packets} packets!")

  def _submit_packet(self):
    # Advance the write cursor to the next slot and update the bookkeeping counters.
    self.write_addr += AQL_PACKET_SIZE
    self.packets_count += 1
    self.available_packet_slots -= 1
|
|
|
|
|
2024-05-02 01:27:13 +08:00
|
|
|
class HSAGraph(MultiGraphRunner):
  """Replays a JIT cache made of HSA kernel launches and buffer transfers as a pre-recorded graph.

  At construction time, kernel dispatches and barrier packets are recorded once into a per-device `VirtAQLQueue`.
  Buffer transfers are stored as argument lists for `hsa_amd_memory_async_copy_on_engine`. On every call the recorded
  packets are blitted to each device's hardware queue and the transfers are re-issued. Ordering between items uses HSA
  signals derived from the read/write dependency tracking provided by `MultiGraphRunner`.
  """
  def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)

    # Check all jit items are compatible: only CompiledRunner / BufferXfer items, and only on HSA devices.
    compiled_devices = set()
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
      elif isinstance(ji.prg, BufferXfer):
        for x in ji.bufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
      else: raise GraphException
    if any(not isinstance(d, HSADevice) for d in compiled_devices): raise GraphException

    self.devices: List[HSADevice] = list(compiled_devices) #type:ignore

    # Allocate kernel args: one device allocation per device, each kernel's arg struct padded to a 16-byte boundary.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}

    # Fill initial arguments. The structs are views over the allocation, so later updates in __call__ write in place.
    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
      kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
      for i,buf in enumerate(ji.bufs): setattr(self.ji_kargs_structs[j], f'f{i}', cast(Buffer, buf)._buf)
      for i,v in enumerate(ji.prg.p.vars): setattr(self.ji_kargs_structs[j], f'v{i}', var_vals[v])

    # Build queues.
    self.virt_aql_queues: Dict[Compiled, VirtAQLQueue] = {dev:VirtAQLQueue(dev, 2*len(self.jit_cache)+16) for dev in self.devices}
    self.packets: Dict[int, hsa.hsa_kernel_dispatch_packet_t] = {}  # jit_cache index -> recorded dispatch packet
    self.transfers: List[List[Any]] = []  # argument lists for hsa_amd_memory_async_copy_on_engine
    self.ji_to_transfer: Dict[int, int] = {} # faster to store transfers as list and update using this mapping table.
    self.signals_to_reset: List[hsa.hsa_signal_t] = []  # re-armed to 1 at the start of every __call__
    self.signals_to_devices: Dict[int, List[HSADevice]] = {}  # signal handle -> devices whose final barrier waits on it
    self.profile_info: Dict[Compiled, List[Tuple[Any, ...]]] = collections.defaultdict(list)

    # Special packet to wait for the world.
    self.kickoff_signals: Dict[HSADevice, hsa.hsa_signal_t] = {dev:self.alloc_signal(reset_on_start=True) for dev in self.devices}
    for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        wait_signals = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], new_dependency=j, sync_with_aql_packets=False)
        # An HSA barrier-AND packet holds at most 5 dependency signals, so wait in chunks of 5.
        for i in range(0, len(wait_signals), 5):
          self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
        # Keep a view of the packet slot about to be written so launch dims / completion signal can be patched later.
        self.packets[j] = hsa.hsa_kernel_dispatch_packet_t.from_address(self.virt_aql_queues[ji.prg.device].write_addr)

        sync_signal = self.alloc_signal(reset_on_start=True) if PROFILE else None
        self.virt_aql_queues[ji.prg.device].submit_kernel(ji.prg.clprg, *ji.prg.p.launch_dims(var_vals), #type:ignore
                                                          ctypes.addressof(self.ji_kargs_structs[j]), completion_signal=sync_signal)
        if PROFILE: self.profile_info[ji.prg.device].append((sync_signal, ji.prg.clprg.name, False))
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        dest_dev, src_dev = cast(HSADevice, Device[dest.device]), cast(HSADevice, Device[src.device])
        sync_signal = self.alloc_signal(reset_on_start=True, wait_on=[dest_dev, src_dev])

        wait_signals = self.access_resources(read=[src], write=[dest], new_dependency=sync_signal, sync_with_aql_packets=True)
        self.transfers.append([dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals),
                               (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True])
        self.ji_to_transfer[j] = len(self.transfers) - 1
        if PROFILE: self.profile_info[src_dev].append((sync_signal, f"transfer: HSA:{src_dev.device_id} -> HSA:{dest_dev.device_id}", True))

    # Wait for all active signals to finish the graph
    wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list)
    for v in dedup_signals(list(self.w_dependency_map.values()) + list(itertools.chain.from_iterable(self.r_dependency_map.values()))):
      for dev in self.signals_to_devices[v.handle]:
        wait_signals_to_finish[dev].append(v)

    # The last barrier on each device carries finish_signal; __call__ arms it to len(self.devices) before each launch.
    self.finish_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
    for dev in self.devices:
      wait_signals = wait_signals_to_finish[dev]
      # max(1, ...) guarantees at least one (possibly empty) barrier, so every device signals completion.
      for i in range(0, max(1, len(wait_signals)), 5):
        self.virt_aql_queues[dev].submit_barrier(wait_signals[i:i+5], completion_signal=self.finish_signal if i+5>=len(wait_signals) else None)

    # Zero signals to allow graph to start and execute.
    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 0)
    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, 0)

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    """Launch the recorded graph with new input buffers and variable values.

    Blocks until the previous launch has finished, re-arms signals, patches kernel args / transfer endpoints / launch dims,
    then blits the recorded packets to each device's hardware queue and issues the transfers.
    Returns the seconds spent waiting for `finish_signal` after submission when `wait` is True, otherwise None.
    """
    # Wait and restore signals
    hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
    for sig in self.signals_to_reset: hsa.hsa_signal_silent_store_relaxed(sig, 1)
    hsa.hsa_signal_silent_store_relaxed(self.finish_signal, len(self.devices))

    # Update rawbuffers: kernel args are patched in place; transfers get new dest (slot 0) / src (slot 2) pointers.
    for (j,i),input_idx in self.input_replace.items():
      if j in self.ji_kargs_structs: setattr(self.ji_kargs_structs[j], f'f{i}', input_rawbuffers[input_idx]._buf)
      elif i == 0: self.transfers[self.ji_to_transfer[j]][0] = input_rawbuffers[input_idx]._buf # dest
      elif i == 1: self.transfers[self.ji_to_transfer[j]][2] = input_rawbuffers[input_idx]._buf # src

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        setattr(self.ji_kargs_structs[j], f'v{i}', var_vals[v])

    # Update launch dims: the packet stores total grid size, i.e. global dims multiplied by workgroup size.
    for j in self.jc_idx_with_updatable_launch_dims:
      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).p.launch_dims(var_vals)
      self.packets[j].workgroup_size_x = lc[0]
      self.packets[j].workgroup_size_y = lc[1]
      self.packets[j].workgroup_size_z = lc[2]
      self.packets[j].grid_size_x = gl[0] * lc[0]
      self.packets[j].grid_size_y = gl[1] * lc[1]
      self.packets[j].grid_size_z = gl[2] * lc[2]

    for dev in self.devices:
      dev.flush_hdp()
      dev.hw_queue.blit_packets(self.virt_aql_queues[dev].queue_base, self.virt_aql_queues[dev].packets_count)

    for transfer_data in self.transfers:
      check(hsa.hsa_amd_memory_async_copy_on_engine(*transfer_data))

    et = None
    if wait:
      st = time.perf_counter()
      hsa.hsa_signal_wait_scacquire(self.finish_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE)
      et = time.perf_counter() - st

    for profdev,profdata in self.profile_info.items(): Profiler.tracked_signals[profdev] += profdata
    return et

  def alloc_signal(self, reset_on_start=False, wait_on=None):
    """Create an HSA signal with initial value 1.

    If `reset_on_start`, the signal is added to `signals_to_reset`: zeroed at the end of __init__ and re-armed to 1 on
    every __call__. If `wait_on` is given (a list of devices), those devices' final barriers will wait on this signal.
    """
    sync_signal = init_c_var(hsa.hsa_signal_t(), lambda x: check(hsa.hsa_amd_signal_create(1, 0, None, 0, ctypes.byref(x))))
    if reset_on_start: self.signals_to_reset.append(sync_signal)
    if wait_on is not None: self.signals_to_devices[sync_signal.handle] = wait_on
    return sync_signal

  def dependency_as_signal(self, dep, sync_with_aql_packets) -> Optional[hsa.hsa_signal_t]:
    """Translate a dependency (an HSA signal or a jit_cache index) into a signal that can be waited on, or None.

    When `sync_with_aql_packets` is set and `dep` indexes a recorded dispatch packet, the packet's completion signal is
    returned, allocating (and assigning) one first if the packet has none.
    """
    if isinstance(dep, hsa.hsa_signal_t): return dep
    elif sync_with_aql_packets and isinstance(packet := self.packets.get(dep), hsa.hsa_kernel_dispatch_packet_t):
      if packet.completion_signal.handle == EMPTY_SIGNAL.handle: packet.completion_signal = self.alloc_signal(reset_on_start=True)
      return packet.completion_signal
    return None

  def access_resources(self, read, write, new_dependency, sync_with_aql_packets=False):
    """Record a read/write access and return the deduplicated signals the accessing item must wait on.

    `read` and `write` must support `+` (lists). With `sync_with_aql_packets`, the kickoff signal of every accessed
    buffer's device is also included, so the access cannot start before that device's recorded queue begins executing.
    """
    rdeps = self._access_resources(read, write, new_dependency)
    wait_signals = [self.dependency_as_signal(dep, sync_with_aql_packets=sync_with_aql_packets) for dep in rdeps]
    if sync_with_aql_packets: wait_signals += [self.kickoff_signals[cast(HSADevice, Device[rawbuf.device])] for rawbuf in read+write]
    return dedup_signals(wait_signals)
|