From 16405b973a14f7c385bb00d76b7ed19b0144ac73 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:50:37 +0300 Subject: [PATCH] fix hcq sync (#5062) * fix hcq sync * rewrite * linter + comment * fix profiler * no default dict * correct sync of unjitted transfer * fix test --- test/test_graph.py | 3 - tinygrad/device.py | 7 +- tinygrad/runtime/graph/hcq.py | 120 ++++++++++++++++++++-------------- 3 files changed, 77 insertions(+), 53 deletions(-) diff --git a/test/test_graph.py b/test/test_graph.py index d8de3619..64e553c0 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -146,7 +146,6 @@ class TestGraph(unittest.TestCase): helper_test_graphs(Device[d0].graph, graphs) @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required") - @unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050") def test_copies_2_devs(self): d0, d1 = Device.DEFAULT, f"{Device.DEFAULT}:1" b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(3)] @@ -159,7 +158,6 @@ class TestGraph(unittest.TestCase): helper_test_graphs(Device[d0].graph, graphs) @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required") - @unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050") def test_copies_after_graph_global(self): d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3" b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)] @@ -207,7 +205,6 @@ class TestGraph(unittest.TestCase): helper_test_graphs(Device[d0].graph, graphs) @unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required") - @unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050") def test_graph_after_copies_devs(self): d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3" b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)] diff --git a/tinygrad/device.py b/tinygrad/device.py index 324ba9c2..b4e1c091 100644 --- a/tinygrad/device.py +++ b/tinygrad/device.py @@ -309,7 +309,12 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \ .copy(dest.va_addr, src.va_addr, sz) \ .signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev) - dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value).submit(dest_dev) src_dev.timeline_value += 1 + if src_dev != dest_dev: + dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \ + .wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \ + .signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev) + dest_dev.timeline_value += 1 + def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size) diff --git a/tinygrad/runtime/graph/hcq.py b/tinygrad/runtime/graph/hcq.py index fd8819bf..0db56be3 100644 --- a/tinygrad/runtime/graph/hcq.py +++ b/tinygrad/runtime/graph/hcq.py @@ -36,90 +36,100 @@ class HCQGraph(MultiGraphRunner): # NV needs constbuffer to be set if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0) - # Build queues. + # Schedule Dependencies. + # There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any + # graph-related tasks. 
This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with + # global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s + # compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue. self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices} - self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices} - self.comp_signal_val = {dev: 0 for dev in self.devices} - self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices} - self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices} - self.copy_signal_val = {dev: 0 for dev in self.devices} - self.kickoff_signal = self.devices[0]._get_signal(value=0) + self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)] + self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())} + self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal] self.kickoff_value = 0 - self.graph_timeline = {dev: 0 for dev in self.devices} - self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[Tuple]]] = {} # Dict[ji_idx, (deps, output sigval, (prof_st_sig, prof_en_sig))] - self.exec_ptrs: Dict[int, Tuple[Any, int]] = {} - self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices} + self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())} + for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev) + + self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval] + self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())} - # Schedule dependencies for j,ji in enumerate(self.jit_cache): + enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore + enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev] + out_signal = self.signals[enqueue_queue] + writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1 + deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1) + if isinstance(ji.prg, CompiledRunner): - deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[(dev:=ji.prg.device)], sig_val:=j+1)) - if (val:=self.comp_signal_val[dev]) > 0: deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])] + [(self.comp_signal[dev], val)] + # Update signal on compute kernel to depend on the previous kernel. + if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)] # Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need. 
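# --- Illustrative sketch (not part of the patch, not tinygrad API) ---
# A minimal model of the per-queue signal chaining described above: each queue owns
# one signal, job j raises it to value j+1 on completion, and a dependency on job j
# is expressed as the pair (signal, j+1). SignalModel/enqueue are hypothetical names.
from collections import defaultdict

class SignalModel:
  def __init__(self):
    self.last_job = {}              # queue -> index of the last job enqueued on it
    self.writers = {}               # buffer -> (queue, job index) of its last writer
    self.deps = defaultdict(list)   # job index -> list of (queue, value) waits

  def enqueue(self, queue, job_idx, reads, writes):
    # Wait for the last writer of every touched buffer (same-queue ordering is
    # already implied by the chaining below, so those waits are dropped).
    for buf in reads + writes:
      if buf in self.writers and self.writers[buf][0] != queue:
        q, j = self.writers[buf]
        self.deps[job_idx].append((q, j + 1))
    # Chain to the previous job on the same queue to keep in-order execution.
    if queue in self.last_job: self.deps[job_idx].append((queue, self.last_job[queue] + 1))
    for buf in writes: self.writers[buf] = (queue, job_idx)
    self.last_job[queue] = job_idx

m = SignalModel()
m.enqueue("gpu0_compute", 0, reads=["a"], writes=["b"])
m.enqueue("gpu0_copy", 1, reads=["b"], writes=["c"])   # waits for ("gpu0_compute", 1)
print(dict(m.deps))                                    # {1: [('gpu0_compute', 1)]}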
- if (dname:=dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(self.comp_signal[dev])): - deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])] + if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)): + deps = [x for x in deps if id(x[0]) != id(out_signal)] + elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)] - self.comp_signal_val[dev] = sig_val - elif isinstance(ji.prg, BufferXfer): - dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] - deps = self.access_resources([src], [dest], (self.copy_signal[(dev:=Device[src.device])], sig_val:=j+1)) - deps = [x for x in deps if id(x[0]) != id(self.copy_signal[Device[src.device]])] - self.copy_signal_val[Device[src.device]] = sig_val - self.copy_to_devs[Device[dest.device]].add(Device[src.device]) + # Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule. + for sig, val in deps: + if id(sig) in [id(x) for x in self.signals.values()]: + self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:] - # When running compute, set up lazy signals, since no dependencies might be there. Copies always have signals to sync. prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore - prof_info = (dev._get_signal(), dev._get_signal(), dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)) if PROFILE else None - self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else j + 1, prof_info) - for sig, val in deps: self.signal_sched[val - 1] = (self.signal_sched[val - 1][0], val, self.signal_sched[val - 1][2]) + prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None + self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info) + self.last_ji[enqueue_queue] = j + + # Build hardware queues. + self.exec_ptrs: Dict[int, Tuple[Any, int]] = {} + self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices} + self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())} - # Building hardware queues for dev in self.devices: - self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value) - self.copy_queues[dev].wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value) + self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \ + .wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value) for j,ji in enumerate(self.jit_cache): deps, signal_value, prof_info = self.signal_sched[j] enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore # Encode waits and start profile timestamp (if needed). 
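# --- Illustrative sketch (not tinygrad API) of the kickoff handshake above ---
# Per run the CPU bumps a single kickoff value; each device's compute queue waits for
# it after syncing with outside work, then raises that device's own kickoff signal,
# and any queue touching another device's buffer waits on that signal before running.
# Signal and ToyQueue are hypothetical stand-ins for the real hardware queue classes.
class Signal:
  def __init__(self): self.value = 0

class ToyQueue:
  def __init__(self, name): self.name, self.cmds = name, []
  def wait(self, sig, val): self.cmds.append(("wait", sig, val)); return self
  def signal(self, sig, val): self.cmds.append(("signal", sig, val)); return self

cpu_kickoff, dev_kickoff = Signal(), {"D0": Signal(), "D1": Signal()}
kickoff_value = 1

# D0's compute queue: sync with external work, then open the graph for device D0.
d0_compute = ToyQueue("D0 compute").wait(cpu_kickoff, kickoff_value) \
                                   .signal(dev_kickoff["D0"], kickoff_value)
# D1's copy queue reads a D0 buffer, so it first waits for D0's kickoff signal.
d1_copy = ToyQueue("D1 copy").wait(dev_kickoff["D0"], kickoff_value)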
- for sig, val in deps: enqueue_queue.wait(sig, val) + for sig, val in deps: + enqueue_queue.wait(sig, val) + if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1) if prof_info: enqueue_queue.timestamp(prof_info[0]) # Encode main commands based on ji type. if isinstance(ji.prg, CompiledRunner): - self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals), - signal=self.comp_signal[ji.prg.device] if signal_value is not None else None, signal_value=signal_value) - self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], len(self.comp_queues[ji.prg.device]) - 1) + enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals), + signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value) + self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1) elif isinstance(ji.prg, BufferXfer): dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]] Device[src.device]._gpu_map(dest._buf) #type: ignore - self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \ - .signal(self.copy_signal[Device[src.device]], signal_value) + enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value) + self.copy_to_devs[Device[dest.device]].add(Device[src.device]) # Encode finish profile timestamp (if needed). if prof_info: enqueue_queue.timestamp(prof_info[1]) for dev in self.devices: - if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev]) - for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev]) + for dep_dev in list(self.copy_to_devs[dev]) + [dev]: + if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue + self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1]) self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value) if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev) - if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev) + if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev) def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]: # Wait and restore signals self.kickoff_value += 1 for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev]) - for dev in self.devices: - dev._set_signal(self.comp_signal[dev], 0) - dev._set_signal(self.copy_signal[dev], 0) - self.devices[0]._set_signal(self.kickoff_signal, self.kickoff_value) + for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0) + for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0) + self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value) if PROFILE and self.kickoff_value > 1: for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore @@ -138,10 +148,12 @@ class HCQGraph(MultiGraphRunner): for dev in self.devices: self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \ + .update_signal(3, value=self.kickoff_value) \ .update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, 
dev.timeline_value).submit(dev) - if self.copy_signal_val[dev] > 0: - self.copy_queues[dev].update_wait(0, dev.timeline_signal, dev.timeline_value - 1).update_wait(1, value=self.kickoff_value).submit(dev) + if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None: + for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value) + cp_queue.submit(dev) self.graph_timeline[dev] = dev.timeline_value dev.timeline_value += 1 @@ -152,14 +164,24 @@ class HCQGraph(MultiGraphRunner): return time.perf_counter() - st return None - def access_resources(self, read, write, new_dependency): - deps = self._access_resources(read, write, new_dependency) - return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + def access_resources(self, queue, read, write, new_val): + deps = self._access_resources(read, write, (queue, new_val)) + + sync_signals = [] + for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue]) + for buf in read+write: + if buf.device not in self.save_devs[queue]: + self.save_devs[queue].add(buf.device) + sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)] + + return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals def __del__(self): + for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev]) + # Graph is destructed. No need to keep signals any more, so return them as part of profiling. if PROFILE and self.kickoff_value > 1: for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore - self.devices[0].signals_pool += [self.kickoff_signal] + list(self.copy_signal.values()) + list(self.comp_signal.values()) # type: ignore + self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
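The snippet below is an illustrative sketch, not tinygrad code: it models the re-submission pattern used in __call__ above, where the hardware queues are encoded once and every later run only patches the values of previously recorded wait/signal commands (update_wait/update_signal) before submitting the same queue again. PatchableQueue and its methods are hypothetical.

class PatchableQueue:
  def __init__(self): self.cmds = []
  def wait(self, sig, val): self.cmds.append(["wait", sig, val]); return self
  def update_wait(self, idx, sig=None, value=None):
    if sig is not None: self.cmds[idx][1] = sig
    if value is not None: self.cmds[idx][2] = value
    return self
  def submit(self, dev): print(f"submit to {dev}: {self.cmds}")

timeline_value, kickoff_value = 5, 1
q = PatchableQueue().wait("timeline", timeline_value - 1).wait("kickoff", kickoff_value)

# Next run: nothing is re-encoded; only the values of commands 0 and 1 are patched.
timeline_value, kickoff_value = 6, 2
q.update_wait(0, value=timeline_value - 1).update_wait(1, value=kickoff_value).submit("D0")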