mirror of https://github.com/commaai/tinygrad.git
optimize nv sync (#4729)
* optimize nv sync * sdma signal without wfi * nv mockgou support * sep change
This commit is contained in:
parent
8415b14978
commit
c87b066b66
|
@ -100,6 +100,9 @@ class GPFIFO:
|
|||
gx, gy, gz = qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth
|
||||
lx, ly, lz = qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2
|
||||
gpuocelot_lib.ptx_run(ctypes.cast(prg_addr, ctypes.c_char_p), args_cnt+vals_cnt, (ctypes.c_void_p*len(cargs))(*cargs), lx, ly, lz, gx, gy, gz, 0)
|
||||
if qmd.release0_enable:
|
||||
rel0 = to_mv(qmd.release0_address_lower + (qmd.release0_address_upper << 32), 0x8).cast('Q')
|
||||
rel0[0] = qmd.release0_payload_lower + (qmd.release0_payload_upper << 32)
|
||||
|
||||
def execute_cmd(self, cmd) -> SchedResult:
|
||||
if cmd == nv_gpu.NVC56F_SEM_EXECUTE: return self._exec_signal()
|
||||
|
@ -147,12 +150,18 @@ class GPFIFO:
|
|||
ctypes.memmove(addr, cdata, sz)
|
||||
|
||||
def _exec_nvc6b5_dma(self):
|
||||
flags = self._next_dword()
|
||||
if (flags & 0b11) != 0:
|
||||
src = self._state64(nv_gpu.NVC6B5_OFFSET_IN_UPPER)
|
||||
dst = self._state64(nv_gpu.NVC6B5_OFFSET_OUT_UPPER)
|
||||
sz = self._state(nv_gpu.NVC6B5_LINE_LENGTH_IN)
|
||||
flags = self._next_dword()
|
||||
assert flags == 0x182, f"unsupported flags in _exec_nvc6b5_dma: {flags}"
|
||||
ctypes.memmove(dst, src, sz)
|
||||
elif ((flags >> 3) & 0b11) != 0:
|
||||
src = to_mv(self._state64(nv_gpu.NVC6B5_SET_SEMAPHORE_A), 0x4).cast('I')
|
||||
val = self._state(nv_gpu.NVC6B5_SET_SEMAPHORE_PAYLOAD)
|
||||
src[0] = val
|
||||
else: raise RuntimeError("unknown nvc6b5_dma flags")
|
||||
|
||||
class NVGPU(VirtGPU):
|
||||
def __init__(self, gpuid):
|
||||
|
|
|
@ -56,14 +56,17 @@ class HCQGraph(MultiGraphRunner):
|
|||
for j,ji in enumerate(self.jit_cache):
|
||||
if isinstance(ji.prg, CompiledRunner):
|
||||
deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
|
||||
deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])] # remove wait for the same queue as all operations are ordered.
|
||||
|
||||
# NV should wait for the previous kernel to finish
|
||||
deps = [x for x in deps if id(x[0]) != id(self.comp_signal[ji.prg.device])]
|
||||
if ji.prg.device.dname.startswith("NV"): deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
|
||||
self.comp_signal_val[ji.prg.device] = sig_val
|
||||
|
||||
for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
|
||||
|
||||
self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], self.comp_queues[ji.prg.device].ptr())
|
||||
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
|
||||
.signal(self.comp_signal[ji.prg.device], sig_val)
|
||||
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
|
||||
signal=self.comp_signal[ji.prg.device], signal_value=sig_val)
|
||||
elif isinstance(ji.prg, BufferXfer):
|
||||
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
|
||||
Device[src.device]._gpu_map(dest._buf) #type: ignore
|
||||
|
|
|
@ -138,7 +138,7 @@ class HWPM4Queue:
|
|||
amd_gpu.PACKET3_ACQUIRE_MEM_GCR_CNTL_GL2_INV(gl2)]
|
||||
return self
|
||||
|
||||
def exec(self, prg:AMDProgram, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)):
|
||||
def exec(self, prg, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), signal=None, signal_value=0):
|
||||
self.hdp_flush()
|
||||
self.invalidate_cache()
|
||||
|
||||
|
@ -167,6 +167,8 @@ class HWPM4Queue:
|
|||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_SET_SH_REG, 1), regCOMPUTE_RESOURCE_LIMITS, 0]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_DISPATCH_DIRECT, 3), *global_size, CS_W32_EN | FORCE_START_AT_000 | COMPUTE_SHADER_EN]
|
||||
self.q += [amd_gpu.PACKET3(amd_gpu.PACKET3_EVENT_WRITE, 0), amd_gpu.EVENT_TYPE(7) | amd_gpu.EVENT_INDEX(4)]
|
||||
|
||||
if signal is not None: self.signal(signal, signal_value)
|
||||
return self
|
||||
|
||||
def update_exec(self, cmd_ptr, global_size, local_size):
|
||||
|
|
|
@ -92,7 +92,7 @@ class HWQueue:
|
|||
|
||||
def wait(self, signal, value=0):
|
||||
self.q += [nvmethod(0, nv_gpu.NVC56F_SEM_ADDR_LO, 5), *nvdata64_le(ctypes.addressof(from_mv(signal))), *nvdata64_le(value),
|
||||
(3 << 0) | (1 << 12) | (1 << 24)] # ACQUIRE | ACQUIRE_SWITCH_TSG | PAYLOAD_SIZE_64BIT
|
||||
(3 << 0) | (1 << 24)] # ACQUIRE | PAYLOAD_SIZE_64BIT
|
||||
return self
|
||||
|
||||
def signal(self, signal, value=0, timestamp=False):
|
||||
|
@ -130,11 +130,18 @@ class HWComputeQueue(HWQueue):
|
|||
self.q += [nvmethod(1, nv_gpu.NVC6C0_LOAD_INLINE_DATA, len(data), typ=6)] + [x for x in data]
|
||||
return self
|
||||
|
||||
def exec(self, prg, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1)):
|
||||
def exec(self, prg, kernargs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), signal=None, signal_value=0):
|
||||
prg.qmd.cta_raster_width, prg.qmd.cta_raster_height, prg.qmd.cta_raster_depth = global_size
|
||||
prg.qmd.cta_thread_dimension0, prg.qmd.cta_thread_dimension1, prg.qmd.cta_thread_dimension2 = local_size
|
||||
prg.qmd.constant_buffer_addr_lower_0 = kernargs & 0xffffffff
|
||||
prg.qmd.constant_buffer_addr_upper_0 = kernargs >> 32
|
||||
if signal is not None:
|
||||
prg.qmd.release0_address_lower = ctypes.addressof(from_mv(signal)) & 0xffffffff
|
||||
prg.qmd.release0_address_upper = ctypes.addressof(from_mv(signal)) >> 32
|
||||
prg.qmd.release0_payload_lower = signal_value & 0xffffffff
|
||||
prg.qmd.release0_payload_upper = signal_value >> 32
|
||||
prg.qmd.release0_enable = 1
|
||||
else: prg.qmd.release0_enable = 0
|
||||
self.q += [nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0)]
|
||||
self.q += [nvmethod(1, nv_gpu.NVC6C0_SET_INLINE_QMD_ADDRESS_A, 0x42), *nvdata64((kernargs + round_up(prg.constbuf_0_size, 1 << 8)) >> 8)]
|
||||
self.q += [x for x in to_mv(ctypes.addressof(prg.qmd), ctypes.sizeof(prg.qmd)).cast("I")]
|
||||
|
|
Loading…
Reference in New Issue