nv do not modify prg's qmd (#4948)

This commit is contained in:
nimlgen 2024-06-14 01:15:40 +03:00 committed by GitHub
parent 845c10bc28
commit 4bfd1904f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 12 additions and 13 deletions

View File

@ -147,20 +147,19 @@ class HWComputeQueue(HWQueue):
return self
def exec(self, prg, kernargs, global_size=(1,1,1), local_size=(1,1,1), signal=None, signal_value=0, chain_exec_ptr=None):
prg.qmd.cta_raster_width, prg.qmd.cta_raster_height, prg.qmd.cta_raster_depth = global_size
prg.qmd.cta_thread_dimension0, prg.qmd.cta_thread_dimension1, prg.qmd.cta_thread_dimension2 = local_size
prg.qmd.constant_buffer_addr_lower_0 = kernargs & 0xffffffff
prg.qmd.constant_buffer_addr_upper_0 = kernargs >> 32
if signal is not None:
prg.qmd.release0_address_lower = ctypes.addressof(from_mv(signal)) & 0xffffffff
prg.qmd.release0_address_upper = ctypes.addressof(from_mv(signal)) >> 32
prg.qmd.release0_payload_lower = signal_value & 0xffffffff
prg.qmd.release0_payload_upper = signal_value >> 32
prg.qmd.release0_enable = 1
else: prg.qmd.release0_enable = 0
ctypes.memmove(qmd_addr:=(kernargs + round_up(prg.constbuf_0_size, 1 << 8)), ctypes.addressof(prg.qmd), 0x40 * 4)
self.ptr_to_qmd[self.ptr()] = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
self.ptr_to_qmd[self.ptr()] = qmd = qmd_struct_t.from_address(qmd_addr) # Save qmd for later update
qmd.cta_raster_width, qmd.cta_raster_height, qmd.cta_raster_depth = global_size
qmd.cta_thread_dimension0, qmd.cta_thread_dimension1, qmd.cta_thread_dimension2 = local_size
qmd.constant_buffer_addr_lower_0 = kernargs & 0xffffffff
qmd.constant_buffer_addr_upper_0 = kernargs >> 32
if signal is not None:
qmd.release0_address_lower = ctypes.addressof(from_mv(signal)) & 0xffffffff
qmd.release0_address_upper = ctypes.addressof(from_mv(signal)) >> 32
qmd.release0_payload_lower = signal_value & 0xffffffff
qmd.release0_payload_upper = signal_value >> 32
qmd.release0_enable = 1
if chain_exec_ptr is None:
self.q += [nvmethod(1, nv_gpu.NVC6C0_INVALIDATE_SHADER_CACHES_NO_WFI, 1), (1 << 12) | (1 << 4) | (1 << 0)]