fix hcq sync (#5062)

* fix hcq sync

* rewrite

* linter + comment

* fix profiler

* no default dict

* correct sync of unjitted transfer

* fix test
This commit is contained in:
nimlgen 2024-06-26 17:50:37 +03:00 committed by GitHub
parent 3604642847
commit 16405b973a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 77 additions and 53 deletions

View File

@ -146,7 +146,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)
@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_copies_2_devs(self):
d0, d1 = Device.DEFAULT, f"{Device.DEFAULT}:1"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(3)]
@ -159,7 +158,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)
@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_copies_after_graph_global(self):
d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
@ -207,7 +205,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)
@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_graph_after_copies_devs(self):
d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]

View File

@ -309,7 +309,12 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.copy(dest.va_addr, src.va_addr, sz) \
.signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value).submit(dest_dev)
src_dev.timeline_value += 1
if src_dev != dest_dev:
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1
def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)

View File

@ -36,90 +36,100 @@ class HCQGraph(MultiGraphRunner):
# NV needs constbuffer to be set
if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
# Build queues.
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the devices
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.comp_signal_val = {dev: 0 for dev in self.devices}
self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.copy_signal_val = {dev: 0 for dev in self.devices}
self.kickoff_signal = self.devices[0]._get_signal(value=0)
self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
self.kickoff_value = 0
self.graph_timeline = {dev: 0 for dev in self.devices}
self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[Tuple]]] = {} # Dict[ji_idx, (deps, output sigval, (prof_st_sig, prof_en_sig))]
self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)
self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
# Schedule dependencies
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
out_signal = self.signals[enqueue_queue]
writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
if isinstance(ji.prg, CompiledRunner):
deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[(dev:=ji.prg.device)], sig_val:=j+1))
if (val:=self.comp_signal_val[dev]) > 0: deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])] + [(self.comp_signal[dev], val)]
# Update signal on compute kernel to depend on the previous kernel.
if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]
# Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
if (dname:=dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(self.comp_signal[dev])):
deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])]
if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
deps = [x for x in deps if id(x[0]) != id(out_signal)]
elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]
self.comp_signal_val[dev] = sig_val
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
deps = self.access_resources([src], [dest], (self.copy_signal[(dev:=Device[src.device])], sig_val:=j+1))
deps = [x for x in deps if id(x[0]) != id(self.copy_signal[Device[src.device]])]
self.copy_signal_val[Device[src.device]] = sig_val
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
for sig, val in deps:
if id(sig) in [id(x) for x in self.signals.values()]:
self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]
# When running compute, set up lazy signals, since no dependencies might be there. Copies always have signals to sync.
prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
prof_info = (dev._get_signal(), dev._get_signal(), dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)) if PROFILE else None
self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else j + 1, prof_info)
for sig, val in deps: self.signal_sched[val - 1] = (self.signal_sched[val - 1][0], val, self.signal_sched[val - 1][2])
prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.last_ji[enqueue_queue] = j
# Build hardware queues.
self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
# Building hardware queues
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value)
self.copy_queues[dev].wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value)
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)
for j,ji in enumerate(self.jit_cache):
deps, signal_value, prof_info = self.signal_sched[j]
enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore
# Encode waits and start profile timestamp (if needed).
for sig, val in deps: enqueue_queue.wait(sig, val)
for sig, val in deps:
enqueue_queue.wait(sig, val)
if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
if prof_info: enqueue_queue.timestamp(prof_info[0])
# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
signal=self.comp_signal[ji.prg.device] if signal_value is not None else None, signal_value=signal_value)
self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], len(self.comp_queues[ji.prg.device]) - 1)
enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
Device[src.device]._gpu_map(dest._buf) #type: ignore
self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
.signal(self.copy_signal[Device[src.device]], signal_value)
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
# Encode finish profile timestamp (if needed).
if prof_info: enqueue_queue.timestamp(prof_info[1])
for dev in self.devices:
if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])
self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
for dev in self.devices:
dev._set_signal(self.comp_signal[dev], 0)
dev._set_signal(self.copy_signal[dev], 0)
self.devices[0]._set_signal(self.kickoff_signal, self.kickoff_value)
for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)
if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
@ -138,10 +148,12 @@ class HCQGraph(MultiGraphRunner):
for dev in self.devices:
self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
.update_signal(3, value=self.kickoff_value) \
.update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)
if self.copy_signal_val[dev] > 0:
self.copy_queues[dev].update_wait(0, dev.timeline_signal, dev.timeline_value - 1).update_wait(1, value=self.kickoff_value).submit(dev)
if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
cp_queue.submit(dev)
self.graph_timeline[dev] = dev.timeline_value
dev.timeline_value += 1
@ -152,14 +164,24 @@ class HCQGraph(MultiGraphRunner):
return time.perf_counter() - st
return None
def access_resources(self, read, write, new_dependency):
deps = self._access_resources(read, write, new_dependency)
return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
def access_resources(self, queue, read, write, new_val):
deps = self._access_resources(read, write, (queue, new_val))
sync_signals = []
for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
for buf in read+write:
if buf.device not in self.save_devs[queue]:
self.save_devs[queue].add(buf.device)
sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]
return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals
def __del__(self):
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
# Graph is destructed. No need to keep signals any more, so return them as part of profiling.
if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore
self.devices[0].signals_pool += [self.kickoff_signal] + list(self.copy_signal.values()) + list(self.comp_signal.values()) # type: ignore
self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore
for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))