mirror of https://github.com/commaai/tinygrad.git
fix hcq sync (#5062)
* fix hcq sync
* rewrite
* linter + comment
* fix profiler
* no default dict
* correct sync of unjitted transfer
* fix test
parent 3604642847
commit 16405b973a
@@ -146,7 +146,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)

@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_copies_2_devs(self):
d0, d1 = Device.DEFAULT, f"{Device.DEFAULT}:1"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(3)]
@@ -159,7 +158,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)

@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_copies_after_graph_global(self):
d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
@@ -207,7 +205,6 @@ class TestGraph(unittest.TestCase):
helper_test_graphs(Device[d0].graph, graphs)

@unittest.skipUnless(Device.DEFAULT in {"CUDA", "NV", "AMD"}, "mutidevice graph required")
@unittest.skipIf(Device.DEFAULT in {"NV", "AMD"}, "Broken #5050")
def test_graph_after_copies_devs(self):
d0, d1, d2, d3 = Device.DEFAULT, f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3"
b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(8)]
@@ -309,7 +309,12 @@ class HCQCompatAllocator(LRUAllocator): # pylint: disable=abstract-method
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.copy(dest.va_addr, src.va_addr, sz) \
.signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value).submit(dest_dev)
src_dev.timeline_value += 1

if src_dev != dest_dev:
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1

def offset(self, buf, size:int, offset:int): return type(buf)(base=buf.base + offset, va_addr=buf.va_addr + offset, length=size, size=size)
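Aside (not part of the commit): a minimal, self-contained sketch of the transfer handshake the new allocator code encodes. Signal and Queue here are simplified stand-ins for tinygrad's HCQ signal and hw_*_queue_t objects, not the real API. The copy is queued on the source device behind both devices' timelines; the destination device's compute queue is then chained behind the copy via the source timeline signal.

```python
class Signal:  # stand-in for a per-device timeline signal (a monotonically increasing counter)
  def __init__(self, value=0): self.value = value

class Queue:  # stand-in for hw_copy_queue_t / hw_compute_queue_t; runs synchronously on submit
  def __init__(self): self.cmds = []
  def wait(self, sig, val): self.cmds.append(("wait", sig, val)); return self
  def copy(self, dest_va, src_va, sz): self.cmds.append(("copy", dest_va, src_va, sz)); return self
  def signal(self, sig, val): self.cmds.append(("signal", sig, val)); return self
  def submit(self, dev):
    for c in self.cmds:
      if c[0] == "wait": assert c[1].value >= c[2], f"{dev} would stall here on real hardware"
      elif c[0] == "signal": c[1].value = c[2]
      else: print(f"{dev}: copy {c[3]} bytes")
    return self

src_tl, dest_tl = Signal(), Signal()   # per-device timeline signals
src_val, dest_val = 1, 1               # next timeline values to be signaled

# source device's copy queue: order behind both devices, copy, then bump the source timeline
Queue().wait(src_tl, src_val - 1).wait(dest_tl, dest_val - 1) \
       .copy("dest_va", "src_va", 4096).signal(src_tl, src_val).submit("SRC")
src_val += 1

# destination device's compute queue: chain behind the copy (and its own prior work),
# then bump the destination timeline so later work on the destination sees the copied data
Queue().wait(src_tl, src_val - 1).wait(dest_tl, dest_val - 1) \
       .signal(dest_tl, dest_val).submit("DEST")
dest_val += 1
```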
@@ -36,90 +36,100 @@ class HCQGraph(MultiGraphRunner):
# NV needs constbuffer to be set
if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

# Build queues.
# Schedule Dependencies.
# There are two types of queues on each device: copy and compute. Both must synchronize with all external operations before launching any
# graph-related tasks. This synchronization uses a global timeline signal per device. Within the graph, the compute queue coordinates with
# global operations and sets a kickoff signal. Any queue accessing a buffer from another device waits for this signal from the device’s
# compute queue to ensure exclusive access. The compute queue signals the completion of the graph, synchronizing with the device's copy queue.
self.comp_queues: Dict[Compiled, Any] = {dev: dev.hw_compute_queue_t() for dev in self.devices}
self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.comp_signal_val = {dev: 0 for dev in self.devices}

self.copy_queues: Dict[Compiled, Any] = {dev: dev.hw_copy_queue_t() for dev in self.devices}
self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
self.copy_signal_val = {dev: 0 for dev in self.devices}

self.kickoff_signal = self.devices[0]._get_signal(value=0)
self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[List]]] = {} # Dict[ji_idx, (deps, sigval, prof_info)]
self.signals: Dict[Any, Any] = {q: self.devices[0]._get_signal(value=0) for q in list(self.comp_queues.values())+list(self.copy_queues.values())}
self.dev_kickoff_signal = {dev: self.devices[0]._get_signal(value=0) for dev in self.devices + ['CPU']} # Dict[dev, signal]
self.kickoff_value = 0
self.graph_timeline = {dev: 0 for dev in self.devices}

self.signal_sched: Dict[int, Tuple[List, Optional[int], Optional[Tuple]]] = {} # Dict[ji_idx, (deps, output sigval, (prof_st_sig, prof_en_sig))]
self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.save_devs: Dict[Any, Set] = {q: set() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}
for dev in self.devices: self.save_devs[self.comp_queues[dev]].add(dev)

self.graph_timeline = {dev: 0 for dev in self.devices} # Dict[dev, last graph sigval]
self.last_ji: Dict[Any, Any] = {q: None for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

# Schedule dependencies
for j,ji in enumerate(self.jit_cache):
enqueue_dev = ji.prg.device if isinstance(ji.prg, CompiledRunner) else Device[ji.bufs[1].device] #type:ignore
enqueue_queue = self.comp_queues[enqueue_dev] if isinstance(ji.prg, CompiledRunner) else self.copy_queues[enqueue_dev]
out_signal = self.signals[enqueue_queue]
writable_buffers = ji.prg.p.outcount if isinstance(ji.prg, CompiledRunner) else 1
deps = self.access_resources(enqueue_queue, ji.bufs[writable_buffers:], ji.bufs[:writable_buffers], j + 1)
if isinstance(ji.prg, CompiledRunner):
deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[(dev:=ji.prg.device)], sig_val:=j+1))
if (val:=self.comp_signal_val[dev]) > 0: deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])] + [(self.comp_signal[dev], val)]
# Update signal on compute kernel to depend on the previous kernel.
if (last_j:=self.last_ji[enqueue_queue]) is not None: deps = [x for x in deps if id(x[0]) != id(out_signal)] + [(out_signal, last_j + 1)]

# Remove self-dependency for AMD or NV with only 1 same-queue dep, since NV chains 2+ execs in this case, eliminating dep need.
if (dname:=dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(self.comp_signal[dev])):
deps = [x for x in deps if id(x[0]) != id(self.comp_signal[dev])]
if (dname:=enqueue_dev.dname.split(":", 1)[0]) == "AMD" or (dname == "NV" and len(deps) == 1 and id(deps[0][0]) == id(out_signal)):
deps = [x for x in deps if id(x[0]) != id(out_signal)]
elif isinstance(ji.prg, BufferXfer): deps = [x for x in deps if id(x[0]) != id(out_signal)]

self.comp_signal_val[dev] = sig_val
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
deps = self.access_resources([src], [dest], (self.copy_signal[(dev:=Device[src.device])], sig_val:=j+1))
deps = [x for x in deps if id(x[0]) != id(self.copy_signal[Device[src.device]])]
self.copy_signal_val[Device[src.device]] = sig_val
self.copy_to_devs[Device[dest.device]].add(Device[src.device])
# Go through all dependencies and, if we need the signal from that ji, enable it by setting the signal value in the signal schedule.
for sig, val in deps:
if id(sig) in [id(x) for x in self.signals.values()]:
self.signal_sched[val - 1] = self.signal_sched[val - 1][:1] + (val,) + self.signal_sched[val - 1][2:]

# When running compute, set up lazy signals, since no dependencies might be there. Copies always have signals to sync.
prof_ji_desc = ji.prg.clprg.name if isinstance(ji.prg, CompiledRunner) else f"{ji.bufs[1].device} -> {ji.bufs[0].device}" # type: ignore
prof_info = (dev._get_signal(), dev._get_signal(), dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)) if PROFILE else None
self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else j + 1, prof_info)
for sig, val in deps: self.signal_sched[val - 1] = (self.signal_sched[val - 1][0], val, self.signal_sched[val - 1][2])
prof_info = ([enqueue_dev._get_signal() for _ in range(2)] + [enqueue_dev, prof_ji_desc, isinstance(ji.prg, BufferXfer)]) if PROFILE else None
self.signal_sched[j] = (deps, None if isinstance(ji.prg, CompiledRunner) else (j + 1), prof_info)
self.last_ji[enqueue_queue] = j
# Build hardware queues.
self.exec_ptrs: Dict[int, Tuple[Any, int]] = {}
self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
self.kickoff_wait_cmds: Dict[Any, List] = {q: list() for q in list(self.comp_queues.values()) + list(self.copy_queues.values())}

# Building hardware queues
for dev in self.devices:
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value)
self.copy_queues[dev].wait(dev.timeline_signal, dev.timeline_value - 1).wait(self.kickoff_signal, self.kickoff_value)
self.comp_queues[dev].memory_barrier().wait(dev.timeline_signal, dev.timeline_value - 1) \
.wait(self.dev_kickoff_signal['CPU'], self.kickoff_value).signal(self.dev_kickoff_signal[dev], self.kickoff_value)

for j,ji in enumerate(self.jit_cache):
deps, signal_value, prof_info = self.signal_sched[j]
enqueue_queue = self.copy_queues[Device[ji.bufs[1].device]] if isinstance(ji.prg, BufferXfer) else self.comp_queues[ji.prg.device] #type:ignore

# Encode waits and start profile timestamp (if needed).
for sig, val in deps: enqueue_queue.wait(sig, val)
for sig, val in deps:
enqueue_queue.wait(sig, val)
if id(sig) in [id(x) for x in self.dev_kickoff_signal.values()]: self.kickoff_wait_cmds[enqueue_queue].append(len(enqueue_queue) - 1)
if prof_info: enqueue_queue.timestamp(prof_info[0])

# Encode main commands based on ji type.
if isinstance(ji.prg, CompiledRunner):
self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
signal=self.comp_signal[ji.prg.device] if signal_value is not None else None, signal_value=signal_value)
self.exec_ptrs[j] = (self.comp_queues[ji.prg.device], len(self.comp_queues[ji.prg.device]) - 1)
enqueue_queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals),
signal=self.signals[enqueue_queue] if signal_value is not None else None, signal_value=signal_value)
self.exec_ptrs[j] = (enqueue_queue, len(enqueue_queue) - 1)
elif isinstance(ji.prg, BufferXfer):
dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
Device[src.device]._gpu_map(dest._buf) #type: ignore
self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
.signal(self.copy_signal[Device[src.device]], signal_value)
enqueue_queue.copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes).signal(self.signals[enqueue_queue], signal_value)
self.copy_to_devs[Device[dest.device]].add(Device[src.device])

# Encode finish profile timestamp (if needed).
if prof_info: enqueue_queue.timestamp(prof_info[1])
for dev in self.devices:
if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])
for dep_dev in list(self.copy_to_devs[dev]) + [dev]:
if (last_j:=self.last_ji[self.copy_queues[dep_dev]]) is None: continue
self.comp_queues[dev].wait(self.signals[self.copy_queues[dep_dev]], self.signal_sched[last_j][1])

self.comp_queues[dev].signal(dev.timeline_signal, dev.timeline_value)
if hasattr(self.comp_queues[dev], 'bind'): self.comp_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.copy_signal_val[dev] > 0: self.copy_queues[dev].bind(dev)
if hasattr(self.copy_queues[dev], 'bind') and self.last_ji[self.copy_queues[dev]] is not None: self.copy_queues[dev].bind(dev)

def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
# Wait and restore signals
self.kickoff_value += 1
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
for dev in self.devices:
dev._set_signal(self.comp_signal[dev], 0)
dev._set_signal(self.copy_signal[dev], 0)
self.devices[0]._set_signal(self.kickoff_signal, self.kickoff_value)
for queue in self.comp_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
for queue in self.copy_queues.values(): self.devices[0]._set_signal(self.signals[queue], 0)
self.devices[0]._set_signal(self.dev_kickoff_signal['CPU'], self.kickoff_value)

if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): #type: ignore
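Aside (not part of the commit): a rough sketch of the per-run kickoff protocol described in the comment near the top of this file's diff, using hypothetical names and plain integers instead of real HCQ signals. Each run bumps a kickoff value; the CPU publishes it, every device's compute queue waits for it and then publishes its own per-device kickoff signal, and any queue that touches another device's buffer waits on that device's kickoff before proceeding.

```python
kickoff_value = 0
dev_kickoff_signal = {d: 0 for d in ("CPU", "D0", "D1")}  # per-device kickoff signal values

def launch_graph():
  global kickoff_value
  kickoff_value += 1                           # stale signals from earlier runs can't satisfy this run
  dev_kickoff_signal["CPU"] = kickoff_value    # host releases the graph
  for d in ("D0", "D1"):
    # each device's compute queue: wait(CPU kickoff), then publish its own kickoff
    assert dev_kickoff_signal["CPU"] >= kickoff_value
    dev_kickoff_signal[d] = kickoff_value
  # a queue on D0 that reads a D1 buffer must first see D1's kickoff, which guarantees
  # D1's compute queue has already synchronized with everything outside the graph
  assert dev_kickoff_signal["D1"] >= kickoff_value

launch_graph()
launch_graph()  # the counter keeps increasing across runs, so waits never see old values
```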
@@ -138,10 +148,12 @@ class HCQGraph(MultiGraphRunner):

for dev in self.devices:
self.comp_queues[dev].update_wait(1, dev.timeline_signal, dev.timeline_value - 1).update_wait(2, value=self.kickoff_value) \
.update_signal(3, value=self.kickoff_value) \
.update_signal(len(self.comp_queues[dev]) - 1, dev.timeline_signal, dev.timeline_value).submit(dev)

if self.copy_signal_val[dev] > 0:
self.copy_queues[dev].update_wait(0, dev.timeline_signal, dev.timeline_value - 1).update_wait(1, value=self.kickoff_value).submit(dev)
if self.last_ji[(cp_queue:=self.copy_queues[dev])] is not None:
for cmd_idx in self.kickoff_wait_cmds[cp_queue]: cp_queue.update_wait(cmd_idx, value=self.kickoff_value)
cp_queue.submit(dev)

self.graph_timeline[dev] = dev.timeline_value
dev.timeline_value += 1
@@ -152,14 +164,24 @@ class HCQGraph(MultiGraphRunner):
return time.perf_counter() - st
return None

def access_resources(self, read, write, new_dependency):
deps = self._access_resources(read, write, new_dependency)
return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
def access_resources(self, queue, read, write, new_val):
deps = self._access_resources(read, write, (queue, new_val))

sync_signals = []
for dep_queue,_ in deps: self.save_devs[queue].update(self.save_devs[dep_queue])
for buf in read+write:
if buf.device not in self.save_devs[queue]:
self.save_devs[queue].add(buf.device)
sync_signals += [(self.dev_kickoff_signal[Device[buf.device]], self.kickoff_value)]

return [(self.signals[k], max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()] + sync_signals

def __del__(self):
for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])

# Graph is destructed. No need to keep signals any more, so return them as part of profiling.
if PROFILE and self.kickoff_value > 1:
for _,_,(st,en,dev,desc,is_cp) in self.signal_sched.values(): dev.sig_prof_records += [(st, en, desc, is_cp)] #type: ignore

self.devices[0].signals_pool += [self.kickoff_signal] + list(self.copy_signal.values()) + list(self.comp_signal.values()) # type: ignore
self.devices[0].signals_pool += list(self.dev_kickoff_signal.values()) + list(self.signals.values()) # type: ignore
for dev, buf in self.kernargs_bufs.items(): dev.allocator._free(buf, BufferOptions(cpu_access=True))
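Aside (not part of the commit): the return expression in access_resources above collapses duplicate dependencies so each signal is waited on only once, at the highest value required. A tiny standalone illustration, with Sig as a stand-in for a signal object:

```python
class Sig:  # stand-in for a signal object; identity (id) is what matters here, not equality
  def __init__(self, name): self.name = name

a, b = Sig("a"), Sig("b")
deps = [(a, 3), (b, 1), (a, 7), (b, 2)]

# same shape as the comprehension in the diff: one entry per unique signal, max value wins
merged = [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
print([(s.name, v) for s, v in merged])  # -> [('a', 7), ('b', 2)]
```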