mirror of https://github.com/commaai/tinygrad.git
restore hcq graph (#4513)
* Reapply "hcq graph (#4380)" (#4512)
This reverts commit 06c1e7498e.
* bring back hcq graph
parent 06c1e7498e
commit 58e7256ce9
tinygrad/buffer.py
@@ -10,6 +10,7 @@ from tinygrad.dtype import DType, ImageDType
 class BufferOptions:
   image: Optional[ImageDType] = None
   uncached: bool = False
+  cpu_access: bool = False
   host: bool = False
   nolru: bool = False
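The added cpu_access flag is what HCQGraph uses below to obtain kernel-argument memory it can write from the host. A minimal usage sketch (the "AMD" device string and the explicit allocate() call are illustrative assumptions, not part of this diff):

from tinygrad.dtype import dtypes
from tinygrad.buffer import Buffer, BufferOptions

# request a 4 KiB GPU allocation that stays mapped for CPU reads/writes
buf = Buffer("AMD", 4096, dtypes.uint8, options=BufferOptions(cpu_access=True)).allocate()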
tinygrad/engine/jit.py
@@ -41,12 +41,14 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
-    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]
 
+    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
-    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
-      (isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore
+    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
+                       type(ji_graph_dev) == type(current_device))
+    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
 
     if can_be_graphed: current_batch.append(ji)
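The reshaped condition reads as a standalone rule: an item joins the current batch only if it can be graphed at all, the batch has room, and it can share a graph with the batch's device. A minimal sketch of that flush loop (the callables and item encoding are stand-ins, not tinygrad API):

def batch_for_graphs(items, can_graph, same_graph, max_batch=32):
  batches, batch = [], []
  for it in items:
    can_extend = can_graph(it) and len(batch) < max_batch and (not batch or same_graph(batch[-1], it))
    if not can_extend and batch: batches.append(batch); batch = []  # flush_batch()
    if can_graph(it): batch.append(it)
  if batch: batches.append(batch)
  return batches

# e.g. with can_graph = lambda it: it != "cpu" and same_graph = lambda a, b: a[0] == b[0]:
# batch_for_graphs(["nv0", "nv1", "cpu", "amd0"], ...) -> [["nv0", "nv1"], ["amd0"]]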
tinygrad/runtime/graph/hcq.py (new file)
@@ -0,0 +1,139 @@
import ctypes, collections, array, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import GraphException, round_up, to_mv
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer
from tinygrad.engine.jit import MultiGraphRunner
class HCQGraph(MultiGraphRunner):
  def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t

    # Check all jit items are compatible.
    self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
    if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException
    # Allocate kernel args.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16) if isinstance(ji.prg, CompiledRunner) else 0
    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.kargs_addrs: Dict[int, int] = {}
    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)

      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])

      # NV needs constbuffer to be set
      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)
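    # Note: each CompiledRunner gets a 16-byte-aligned slice of the cpu_access kernarg
    # page; args_struct_t overlays it with pointer fields f0..fN (buffers) and int
    # fields v0..vM (symbolic variables), so later updates are single struct writes.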
    # Build queues.
    self.queue_list: List[Tuple[Any, ...]] = []

    self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
    self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.comp_signal_val = {dev: 0 for dev in self.devices}

    self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
    self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.copy_signal_val = {dev: 0 for dev in self.devices}

    self.kickoff_signal = self.devices[0]._get_signal(value=0)
    self.kickoff_value = 0
    self.graph_timeline = {dev: 0 for dev in self.devices}

    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}
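    # The (signal, value) protocol: jit item j signals value j+1 on its device's
    # comp/copy signal, and consumers wait for exact values, so cross-queue ordering
    # comes from these pairs rather than from submission order.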
    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
        deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
        self.comp_signal_val[ji.prg.device] = sig_val

        # Rebuild runners with dynamic launch dims online.
        if j in self.jc_idx_with_updatable_launch_dims:
          if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device))
          self.queue_list.append((j, deps))
        else:
          for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
          self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
                                         .signal(self.comp_signal[ji.prg.device], sig_val)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        Device[src.device]._gpu_map(dest._buf) #type: ignore

        deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
        deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
        self.copy_signal_val[Device[src.device]] = sig_val

        for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
        self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
                                            .signal(self.copy_signal[Device[src.device]], sig_val)
        self.copy_to_devs[Device[dest.device]].add(Device[src.device])

    for dev in self.devices:
      if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
      for dep_dev in self.copy_to_devs[dev]: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])

      self.queue_list.append((self.comp_queues.pop(dev), dev))
      if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev))
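  # At this point queue_list holds prebuilt hardware queues plus (j, deps) placeholders
  # for kernels with updatable launch dims, which __call__ rebuilds on every run.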
  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    for dev in self.devices:
      dev._set_signal(self.comp_signal[dev], 0)
      dev._set_signal(self.copy_signal[dev], 0)
      dev._set_signal(self.kickoff_signal, self.kickoff_value)

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items():
      self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])

    for dev in self.devices:
      self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
      self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
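    # Every queue submitted below first waited on the device timeline and the shared
    # kickoff signal above, so all devices begin this invocation together.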
    for entry in self.queue_list:
      if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry
      else:
        # Kernel with dynamic launch bounds, rebuild it.
        j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device
        queue = self.comp_hcq_t()
        for sig, val in deps: queue.wait(sig, val)
        queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
             .signal(self.comp_signal[dev], value=j+1)
      queue.submit(dev)

    for dev in self.devices:
      self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
      return time.perf_counter() - st
    return None

  def access_resources(self, read, write, new_dependency):
    deps = self._access_resources(read, write, new_dependency)
    return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
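A note on access_resources: the comprehension collapses duplicate waits on the same signal into a single wait for the highest value. Since signal values only increase, waiting for the maximum subsumes all earlier values. A self-contained sketch of the same reduction:

def dedup_deps(deps):
  # keep one (signal, value) pair per signal, at the maximum value seen;
  # signal values are monotonic, so the max wait covers the rest
  best = {}
  for sig, val in deps:
    if id(sig) not in best or best[id(sig)][1] < val: best[id(sig)] = (sig, val)
  return list(best.values())

sig_a, sig_b = object(), object()
assert sorted(v for _, v in dedup_deps([(sig_a, 3), (sig_b, 1), (sig_a, 7)])) == [1, 7]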
tinygrad/runtime/ops_amd.py
@@ -388,7 +388,7 @@ class AMDAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferOptions):
     try:
       if options.host: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
-      else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=True)
+      else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
     except OSError as e:
       if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory") from e
       else: raise
@@ -584,7 +584,9 @@ class AMDDevice(Compiled):
     self.pm4_write_pointer = to_mv(self.pm4_queue.write_pointer_address, 8).cast("Q")
     self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")
 
-    super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self))
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
+                     functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))
 
   def synchronize(self):
     AMDDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)
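Both backends wire the graph the same way: Compiled now receives a factory with the device class and its two hardware queue types pre-bound, and the JIT later supplies only jit_cache, input_rawbuffers, and var_vals. A runnable sketch of that partial-application shape (FakeGraph is a stand-in, not tinygrad API):

import functools

class FakeGraph:  # mirrors HCQGraph's (device_t, comp_hcq_t, copy_hcq_t, ...) signature
  def __init__(self, device_t, comp_q_t, copy_q_t, jit_cache, input_rawbuffers, var_vals):
    self.bound = (device_t, comp_q_t, copy_q_t)

factory = functools.partial(FakeGraph, "AMDDevice", "HWPM4Queue", "HWCopyQueue")
g = factory([], [], {})  # the JIT passes only the last three arguments
assert g.bound == ("AMDDevice", "HWPM4Queue", "HWCopyQueue")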
tinygrad/runtime/ops_nv.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time
+import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
 from typing import Tuple, List, Any, cast
 from dataclasses import replace
 from tinygrad.device import Compiled, Compiler, CompilerOptions
@@ -116,7 +116,7 @@ class HWComputeQueue:
   def submit(self, dev:NVDevice):
     if len(self.q) == 0: return
     assert len(self.q) < (1 << 21)
-    for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
+    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
     fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries
     dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41)
     dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries
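The loop-to-slice change replaces len(q) interpreted per-element stores with one bulk buffer copy. A standalone sketch of the difference (sizes are arbitrary):

import array, time

buf = memoryview(bytearray(1 << 20)).cast('I')  # uint32 view, like dev.cmdq
q = list(range(1 << 16))

st = time.perf_counter()
for i, packet in enumerate(q): buf[i] = packet          # old path: per-element writes
t_loop = time.perf_counter() - st

st = time.perf_counter()
buf[:len(q)] = array.array('I', q)                      # new path: single bulk copy
t_bulk = time.perf_counter() - st
print(f"loop: {t_loop*1e3:.2f} ms, bulk: {t_bulk*1e3:.2f} ms")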
@@ -146,7 +146,7 @@ class HWCopyQueue:
   def submit(self, dev:NVDevice):
     if len(self.q) == 0: return
-    for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
+    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
     fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries
     dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42)
     dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries
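Both submits pack the ring entry the same way; only the compute queue sets bit 41. A sketch of the 64-bit entry layout implied by the shifts (field names are inferred from this diff, not from NVIDIA documentation):

def gpfifo_entry(cmd_addr: int, num_words: int, sync: bool) -> int:
  entry = (cmd_addr // 4) << 2     # word-aligned pushbuffer address
  entry |= num_words << 42         # segment length in 32-bit words
  if sync: entry |= 1 << 41        # set by the compute queue, not the copy queue
  return entry

assert gpfifo_entry(0x1000, 8, False) == (0x1000 >> 2) << 2 | 8 << 42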
@@ -235,6 +235,9 @@ class NVProgram:
   def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
     if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requested for launch")
+    if not hasattr(self, "args_struct_t"):
+      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
+                                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
 
     if self.device.kernargs_ptr >= (self.device.kernargs_page.base + self.device.kernargs_page.length - self.kernargs_segment_size):
       self.device.kernargs_ptr = self.device.kernargs_page.base
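NVProgram now builds its args struct lazily on first call, sized to the buffer and value counts; HCQGraph reuses the same struct type to patch arguments in place. A sketch of the equivalent ctypes construction (make_args_struct inlines what tinygrad's init_c_struct_t helper does):

import ctypes

def make_args_struct(num_bufs: int, num_vals: int) -> type:
  fields = [(f'f{i}', ctypes.c_void_p) for i in range(num_bufs)] + \
           [(f'v{i}', ctypes.c_int) for i in range(num_vals)]
  return type('ArgsStruct', (ctypes.Structure,), {'_pack_': 1, '_fields_': fields})

args_t = make_args_struct(2, 1)
args = args_t()                        # over real kernarg memory: args_t.from_address(addr)
args.f0, args.f1, args.v0 = 0x1000, 0x2000, 42
assert args.v0 == 42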
@@ -265,7 +268,7 @@ class NVAllocator(LRUAllocator):
 
   def _alloc(self, size:int, options:BufferOptions):
     if options.host: return self.device._gpu_host_alloc(size)
-    else: return self.device._gpu_alloc(size)
+    else: return self.device._gpu_alloc(size, map_to_cpu=options.cpu_access)
 
   def _free(self, gpumem, options:BufferOptions):
     NVDevice.synchronize_system()
@@ -491,7 +494,9 @@ class NVDevice(Compiled):
 
     self.arch: str = 'sm_89' # TODO: fix
 
-    super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self))
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self),
+                     functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))
 
     self._cmdq_setup_compute_gpfifo()
     self._cmdq_setup_dma_gpfifo()