diff --git a/tinygrad/runtime/driver/hsa.py b/tinygrad/runtime/driver/hsa.py index 5aaeaf48..0322bb60 100644 --- a/tinygrad/runtime/driver/hsa.py +++ b/tinygrad/runtime/driver/hsa.py @@ -1,4 +1,4 @@ -import ctypes +import ctypes, collections import tinygrad.runtime.autogen.hsa as hsa from tinygrad.helpers import init_c_var @@ -124,20 +124,17 @@ class AQLQueue: def _alloc_signal(self, reusable=False): return self.device.alloc_signal(reusable=reusable) -def find_agent(typ, device_id): +def scan_agents(): + agents = collections.defaultdict(list) + @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_agent_t, ctypes.c_void_p) - def __filter_agents(agent, data): + def __scan_agents(agent, data): status = hsa.hsa_agent_get_info(agent, hsa.HSA_AGENT_INFO_DEVICE, ctypes.byref(device_type := hsa.hsa_device_type_t())) - if status == 0 and device_type.value == typ: - ret = ctypes.cast(data, ctypes.POINTER(hsa.hsa_agent_t)) - if ret[0].handle == device_id: - ret[0] = agent - return hsa.HSA_STATUS_INFO_BREAK - ret[0].handle = ret[0].handle + 1 + if status == 0: agents[device_type.value].append(agent) return hsa.HSA_STATUS_SUCCESS - hsa.hsa_iterate_agents(__filter_agents, ctypes.byref(agent := hsa.hsa_agent_t())) - return agent + hsa.hsa_iterate_agents(__scan_agents, None) + return agents def find_memory_pool(agent, segtyp=-1, flags=-1, location=-1): @ctypes.CFUNCTYPE(hsa.hsa_status_t, hsa.hsa_amd_memory_pool_t, ctypes.c_void_p) diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index 9ec8ef4c..e6189b3e 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -96,11 +96,6 @@ class HSAGraph(MultiDeviceJITGraph): self.transfers.append((dest._buf, dest_dev.agent, src._buf, src_dev.agent, dest.nbytes, len(wait_signals), (hsa.hsa_signal_t*len(wait_signals))(*wait_signals), sync_signal, hsa.HSA_AMD_SDMA_ENGINE_0, True)) - # Make sure the src buffer can be by other devices. - c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices]) - check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, src._buf)) - check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, dest._buf)) - # Wait for all active signals to finish the graph wait_signals_to_finish: Dict[HSADevice, List[hsa.hsa_signal_t]] = collections.defaultdict(list) for v in dedup_signals([s for s in list(self.w_dependency_map.values())+list(self.r_dependency_map.values()) if isinstance(s, hsa.hsa_signal_t)]): diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py index d5e17bd5..76abf97e 100644 --- a/tinygrad/runtime/ops_hsa.py +++ b/tinygrad/runtime/ops_hsa.py @@ -1,11 +1,11 @@ from __future__ import annotations import ctypes, functools, subprocess, io, atexit -from typing import Tuple, TypeVar, List +from typing import Tuple, TypeVar, List, Dict import tinygrad.runtime.autogen.hsa as hsa from tinygrad.helpers import DEBUG, init_c_var, from_mv, round_up, to_mv, init_c_struct_t from tinygrad.device import Compiled, LRUAllocator from tinygrad.runtime.ops_hip import HIPCompiler -from tinygrad.runtime.driver.hsa import check, find_agent, find_memory_pool, AQLQueue +from tinygrad.runtime.driver.hsa import check, scan_agents, find_memory_pool, AQLQueue HSACompiler = HIPCompiler @@ -61,9 +61,9 @@ class HSAAllocator(LRUAllocator): super().__init__() def _alloc(self, size:int): - c_agents = (hsa.hsa_agent_t * len(HSADevice.devices))(*[dev.agent for dev in HSADevice.devices]) + c_agents = (hsa.hsa_agent_t * len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]))(*HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]) check(hsa.hsa_amd_memory_pool_allocate(self.device.gpu_mempool, size, 0, ctypes.byref(buf := ctypes.c_void_p()))) - check(hsa.hsa_amd_agents_allow_access(len(HSADevice.devices), c_agents, None, buf)) + check(hsa.hsa_amd_agents_allow_access(len(HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU]), c_agents, None, buf)) return buf.value def _free(self, opaque:T): @@ -74,7 +74,7 @@ class HSAAllocator(LRUAllocator): # Async copyin sync model uses barriers on the main hw queue, since barriers are guaranteed to execute in order with all other packets. copy_signal = self.device.alloc_signal(reusable=True) sync_signal = self.device.hw_queue.submit_barrier(need_signal=True) - c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent]) + c_agents = (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent) check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, src.nbytes, 0, ctypes.byref(mem := ctypes.c_void_p()))) check(hsa.hsa_amd_agents_allow_access(2, c_agents, None, mem)) ctypes.memmove(mem, from_mv(src), src.nbytes) @@ -87,7 +87,7 @@ class HSAAllocator(LRUAllocator): sync_signal = self.device.hw_queue.submit_barrier(need_signal=True) if not hasattr(self, 'hb'): - c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent]) + c_agents = (hsa.hsa_agent_t*2)(HSADevice.cpu_agent, self.device.agent) self.hb = [] for _ in range(2): check(hsa.hsa_amd_memory_pool_allocate(HSADevice.cpu_mempool, CHUNK_SIZE, 0, ctypes.byref(mem := ctypes.c_void_p()))) @@ -128,7 +128,7 @@ class HSAAllocator(LRUAllocator): def copyout(self, dest:memoryview, src:T): HSADevice.synchronize_system() copy_signal = self.device.alloc_signal(reusable=True) - c_agents = (hsa.hsa_agent_t*2)(*[HSADevice.cpu_agent, self.device.agent]) + c_agents = (hsa.hsa_agent_t*2)(self.device.agent, HSADevice.cpu_agent) check(hsa.hsa_amd_memory_lock_to_pool(from_mv(dest), dest.nbytes, c_agents, 2, HSADevice.cpu_mempool, 0, ctypes.byref(addr:=ctypes.c_void_p()))) check(hsa.hsa_amd_memory_async_copy(addr, HSADevice.cpu_agent, src, self.device.agent, dest.nbytes, 0, None, copy_signal)) hsa.hsa_signal_wait_scacquire(copy_signal, hsa.HSA_SIGNAL_CONDITION_LT, 1, (1 << 64) - 1, hsa.HSA_WAIT_STATE_ACTIVE) @@ -144,18 +144,20 @@ class HSAAllocator(LRUAllocator): dest_dev.hw_queue.submit_barrier(wait_signals=[copy_signal]) class HSADevice(Compiled): - cpu_agent = None - cpu_mempool = None devices: List[HSADevice] = [] + agents: Dict[int, List[hsa.hsa_agent_t]] = {} + cpu_agent: hsa.hsa_agent_t + cpu_mempool: hsa.hsa_amd_memory_pool_t def __init__(self, device:str=""): - if not HSADevice.cpu_agent: + if not HSADevice.agents: check(hsa.hsa_init()) atexit.register(lambda: hsa.hsa_shut_down()) - HSADevice.cpu_agent = find_agent(hsa.HSA_DEVICE_TYPE_CPU, device_id=0) + HSADevice.agents = scan_agents() + HSADevice.cpu_agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_CPU][0] HSADevice.cpu_mempool = find_memory_pool(HSADevice.cpu_agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_CPU) self.device_id = int(device.split(":")[1]) if ":" in device else 0 - self.agent = find_agent(hsa.HSA_DEVICE_TYPE_GPU, device_id=self.device_id) + self.agent = HSADevice.agents[hsa.HSA_DEVICE_TYPE_GPU][self.device_id] self.gpu_mempool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, location=hsa.HSA_AMD_MEMORY_POOL_LOCATION_GPU) self.kernargs_pool = find_memory_pool(self.agent, segtyp=hsa.HSA_AMD_SEGMENT_GLOBAL, flags=hsa.HSA_AMD_MEMORY_POOL_GLOBAL_FLAG_KERNARG_INIT) self.hw_queue = AQLQueue(self)