mirror of https://github.com/commaai/tinygrad.git
cuda p2p enable when available (#4153)
This commit is contained in:
parent
380f27d629
commit
5a57b48134
|
@ -46,7 +46,7 @@ class CUDAGraph(MultiDeviceJITGraph):
|
|||
node_from = cuda.CUgraphNode()
|
||||
deps = self.access_resources(read=[src], write=[dest], new_dependency=node_from)
|
||||
c_deps = (cuda.CUgraphNode*len(deps))(*deps) if deps else None
|
||||
if getenv("CUDA_P2P"):
|
||||
if getenv("CUDA_P2P", CUDADevice.peer_access):
|
||||
cp_params = cuda.CUDA_MEMCPY3D_v2(srcMemoryType=cuda.CU_MEMORYTYPE_DEVICE, srcDevice=src._buf, srcPitch=src.nbytes, srcHeight=1,
|
||||
dstMemoryType=cuda.CU_MEMORYTYPE_DEVICE, dstDevice=dest._buf, dstPitch=dest.nbytes, dstHeight=1,
|
||||
WidthInBytes=dest.nbytes, Height=1, Depth=1)
|
||||
|
|
|
@ -150,15 +150,25 @@ class CUDAAllocator(LRUAllocator):
|
|||
|
||||
class CUDADevice(Compiled):
|
||||
devices: List[CUDADevice] = []
|
||||
peer_access = False
|
||||
|
||||
def __init__(self, device:str):
|
||||
device_id = int(device.split(":")[1]) if ":" in device else 0
|
||||
if not CUDACPU:
|
||||
check(cuda.cuInit(0))
|
||||
check(cuda.cuDeviceGet(ctypes.byref(cu_device := cuda.CUdevice()), device_id))
|
||||
self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, cu_device)))
|
||||
self.cu_device = init_c_var(cuda.CUdevice(), lambda x: check(cuda.cuDeviceGet(ctypes.byref(x), device_id)))
|
||||
self.context = init_c_var(cuda.CUcontext(), lambda x: check(cuda.cuCtxCreate_v2(ctypes.byref(x), 0, self.cu_device)))
|
||||
check(cuda.cuDeviceComputeCapability(ctypes.byref(major := ctypes.c_int()), ctypes.byref(minor := ctypes.c_int()), device_id))
|
||||
|
||||
for dev in CUDADevice.devices:
|
||||
check(cuda.cuDeviceCanAccessPeer(ctypes.byref(val := ctypes.c_int()), self.cu_device, dev.cu_device))
|
||||
if val.value != 1: continue
|
||||
check(cuda.cuCtxSetCurrent(dev.context))
|
||||
check(cuda.cuCtxEnablePeerAccess(self.context, 0))
|
||||
check(cuda.cuCtxSetCurrent(self.context))
|
||||
check(cuda.cuCtxEnablePeerAccess(dev.context, 0))
|
||||
CUDADevice.peer_access = True
|
||||
|
||||
self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
|
||||
self.pending_copyin: List[Tuple[int, int, Optional[BufferOptions]]] = []
|
||||
CUDADevice.devices.append(self)
|
||||
|
|
Loading…
Reference in New Issue