restore hcq graph (#4513)

* Reapply "hcq graph (#4380)" (#4512)

This reverts commit 06c1e7498e.

* bring back hcq graph
George Hotz 2024-05-10 07:45:05 -07:00 committed by GitHub
parent 06c1e7498e
commit 58e7256ce9
5 changed files with 159 additions and 10 deletions

tinygrad/buffer.py

@@ -10,6 +10,7 @@ from tinygrad.dtype import DType, ImageDType
 class BufferOptions:
   image: Optional[ImageDType] = None
   uncached: bool = False
+  cpu_access: bool = False
   host: bool = False
   nolru: bool = False
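
The new cpu_access flag marks allocations the host needs to write directly; the HCQ graph added below requests it for the kernel-argument pages it patches on every replay (BufferOptions(cpu_access=True)). A minimal standalone sketch of the idea, with hypothetical Toy* classes standing in for tinygrad's BufferOptions and backend allocators:

from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class ToyBufferOptions:            # hypothetical stand-in for tinygrad's BufferOptions
  image: Optional[object] = None
  uncached: bool = False
  cpu_access: bool = False         # new flag: host wants to read/write this allocation directly
  host: bool = False
  nolru: bool = False

class ToyAllocator:                # hypothetical allocator showing how the flag changes the alloc path
  def _alloc(self, size: int, options: ToyBufferOptions):
    if options.host: return ("host_mem", size)                # pinned system memory
    if options.cpu_access: return ("vram_cpu_mapped", size)   # device memory, also mapped into the CPU address space
    return ("vram", size)                                     # plain device-local memory

print(ToyAllocator()._alloc(4096, ToyBufferOptions(cpu_access=True)))

In this commit the flag is plumbed through as public=options.cpu_access on AMD and map_to_cpu=options.cpu_access on NV (see the last two files below).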

tinygrad/engine/jit.py

@@ -41,12 +41,14 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
     if ji.prg.__class__ in {EmptyOp, ViewOp}: continue
     ji_graph_dev: Optional[Compiled] = None # device on which the ji will be graphed. Not graphed if None.
     if isinstance(ji.prg, CompiledRunner): ji_graph_dev = ji.prg.device
-    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA"}:
+    elif isinstance(ji.prg, BufferXfer) and ji.bufs[0] and ji.bufs[0].device.split(":", 1)[0] in {"HSA", "CUDA", "NV", "AMD"}:
       ji_graph_dev = Device[ji.bufs[0].device]
+    graph_class = (ji_graph_dev.graph.func if isinstance(ji_graph_dev.graph, functools.partial) else ji_graph_dev.graph) if ji_graph_dev else None #type: ignore
     can_be_graphed = ji_graph_dev and ji_graph_dev.graph
-    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and (ji_graph_dev == current_device or
-      (isinstance(ji_graph_dev.graph, type) and issubclass(ji_graph_dev.graph, MultiGraphRunner) and type(ji_graph_dev) == type(current_device))) #type:ignore
+    can_share_graph = (ji_graph_dev == current_device or (isinstance(graph_class, type) and issubclass(graph_class, MultiGraphRunner)) and
+                       type(ji_graph_dev) == type(current_device))
+    can_extend_graph_batch = can_be_graphed and len(current_batch) < max_batch_size and can_share_graph
     if not can_extend_graph_batch and len(current_batch) > 0: flush_batch()
     if can_be_graphed: current_batch.append(ji)
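
The rework above splits the old compound condition into graph_class (unwrapping a functools.partial graph factory), can_share_graph, and can_extend_graph_batch. A toy, self-contained model of the flush-or-extend decision; Dev and batch here are illustrative stand-ins, not tinygrad's classes:

from typing import List, Optional

class MultiGraphRunner: pass   # stands in for tinygrad's MultiGraphRunner

class Dev:
  def __init__(self, name, graph): self.name, self.graph = name, graph
  def __repr__(self): return self.name

def batch(items: List[Optional[Dev]], max_batch_size=8):
  batches, current, current_device = [], [], None
  def flush():
    nonlocal current, current_device
    if current: batches.append((current_device, current))
    current, current_device = [], None
  for i, dev in enumerate(items):
    can_be_graphed = dev is not None and dev.graph is not None
    # same device, or a multi-device-capable graph class on the same device type
    can_share_graph = dev == current_device or (can_be_graphed and issubclass(dev.graph, MultiGraphRunner)
                                                and type(dev) is type(current_device))
    can_extend = can_be_graphed and len(current) < max_batch_size and can_share_graph
    if not can_extend and current: flush()
    if can_be_graphed:
      if current_device is None: current_device = dev
      current.append(i)
  flush()
  return batches

a, b, c = Dev("NV:0", MultiGraphRunner), Dev("NV:1", MultiGraphRunner), Dev("CLANG", None)
print(batch([a, a, b, c, a, a]))   # -> [(NV:0, [0, 1, 2]), (NV:0, [4, 5])]: the ungraphable item forces a flush

Items on different devices stay in one batch only when the graph class is a MultiGraphRunner subclass and the devices are of the same type, which is exactly what HCQGraph (a MultiGraphRunner) enables for NV and AMD.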

tinygrad/runtime/graph/hcq.py

@@ -0,0 +1,139 @@
import ctypes, collections, array, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import GraphException, round_up, to_mv
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Compiled, CompiledRunner, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer
from tinygrad.engine.jit import MultiGraphRunner

class HCQGraph(MultiGraphRunner):
  def __init__(self, device_t, comp_hcq_t, copy_hcq_t, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
    super().__init__(jit_cache, input_rawbuffers, var_vals)
    self.device_t, self.comp_hcq_t, self.copy_hcq_t = device_t, comp_hcq_t, copy_hcq_t

    # Check all jit items are compatible.
    self.devices = list(set(cast(self.device_t, d) for ji in jit_cache for d in [Device[cast(Buffer, x).device] for x in ji.bufs])) #type: ignore
    if any(not isinstance(d, self.device_t) for d in self.devices): raise GraphException

    # Allocate kernel args.
    kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
    for ji in self.jit_cache:
      kernargs_size[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16) if isinstance(ji.prg, CompiledRunner) else 0
    kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions(cpu_access=True)).va_addr for dev,sz in kernargs_size.items()}

    # Fill initial arguments.
    self.kargs_addrs: Dict[int, int] = {}
    self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
    for j,ji in enumerate(self.jit_cache):
      if not isinstance(ji.prg, CompiledRunner): continue
      self.kargs_addrs[j] = kernargs_ptrs[ji.prg.device]
      kernargs_ptrs[ji.prg.device] += round_up(ji.prg.clprg.kernargs_segment_size, 16)

      self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(self.kargs_addrs[j] + ji.prg.clprg.kernargs_offset)
      for i in range(len(ji.bufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.bufs[i])._buf.va_addr)
      for i in range(len(ji.prg.p.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.p.vars[i]])

      # NV needs constbuffer to be set
      if ji.prg.device.dname.startswith("NV"): to_mv(self.kargs_addrs[j], 0x160).cast('I')[:] = array.array('I', ji.prg.clprg.constbuffer_0)

    # Build queues.
    self.queue_list: List[Tuple[Any, ...]] = []

    self.comp_queues: Dict[Compiled, Any] = collections.defaultdict(self.comp_hcq_t)
    self.comp_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.comp_signal_val = {dev: 0 for dev in self.devices}

    self.copy_queues: Dict[Compiled, Any] = collections.defaultdict(self.copy_hcq_t)
    self.copy_signal = {dev: dev._get_signal(value=0) for dev in self.devices}
    self.copy_signal_val = {dev: 0 for dev in self.devices}

    self.kickoff_signal = self.devices[0]._get_signal(value=0)
    self.kickoff_value = 0
    self.graph_timeline = {dev: 0 for dev in self.devices}
    self.copy_to_devs: Dict[Compiled, Set[Compiled]] = {dev: set() for dev in self.devices}

    for j,ji in enumerate(self.jit_cache):
      if isinstance(ji.prg, CompiledRunner):
        deps = self.access_resources(ji.bufs[(outs:=ji.prg.p.outcount):], ji.bufs[:outs], (self.comp_signal[ji.prg.device], sig_val:=j+1))
        deps.append((self.comp_signal[ji.prg.device], self.comp_signal_val[ji.prg.device]))
        self.comp_signal_val[ji.prg.device] = sig_val

        # Rebuilt runners with dynamic launch dims online.
        if j in self.jc_idx_with_updatable_launch_dims:
          if ji.prg.device in self.comp_queues: self.queue_list.append((self.comp_queues.pop(ji.prg.device), ji.prg.device))
          self.queue_list.append((j, deps))
        else:
          for sig, val in deps: self.comp_queues[ji.prg.device].wait(sig, val)
          self.comp_queues[ji.prg.device].exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.p.launch_dims(var_vals)) \
                                         .signal(self.comp_signal[ji.prg.device], sig_val)
      elif isinstance(ji.prg, BufferXfer):
        dest, src = [cast(Buffer, x) for x in ji.bufs[0:2]]
        Device[src.device]._gpu_map(dest._buf) #type: ignore

        deps = self.access_resources([src], [dest], (self.copy_signal[Device[src.device]], sig_val:=j+1))
        deps.append((self.copy_signal[Device[src.device]], self.copy_signal_val[Device[src.device]]))
        self.copy_signal_val[Device[src.device]] = sig_val

        for sig,val in deps: self.copy_queues[Device[src.device]].wait(sig, val)
        self.copy_queues[Device[src.device]].copy(dest._buf.va_addr, src._buf.va_addr, dest.nbytes) \
                                            .signal(self.copy_signal[Device[src.device]], sig_val)
        self.copy_to_devs[Device[dest.device]].add(Device[src.device])

    for dev in self.devices:
      if self.copy_signal_val[dev] > 0: self.comp_queues[dev].wait(self.copy_signal[dev], self.copy_signal_val[dev])
      for dep_dev in self.copy_to_devs: self.comp_queues[dev].wait(self.copy_signal[dep_dev], self.copy_signal_val[dep_dev])

      self.queue_list.append((self.comp_queues.pop(dev), dev))
      if self.copy_signal_val[dev] > 0: self.queue_list.append((self.copy_queues.pop(dev), dev))

  def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
    # Wait and restore signals
    self.kickoff_value += 1
    for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
    for dev in self.devices:
      dev._set_signal(self.comp_signal[dev], 0)
      dev._set_signal(self.copy_signal[dev], 0)
      dev._set_signal(self.kickoff_signal, self.kickoff_value)

    # Update rawbuffers
    for (j,i),input_idx in self.input_replace.items():
      self.ji_kargs_structs[j].__setattr__(f'f{i}', input_rawbuffers[input_idx]._buf.va_addr)

    # Update var_vals
    for j in self.jc_idx_with_updatable_var_vals:
      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).p.vars):
        self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])

    for dev in self.devices:
      self.comp_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)
      self.copy_hcq_t().wait(dev.timeline_signal, dev.timeline_value - 1) \
                       .wait(self.kickoff_signal, self.kickoff_value).submit(dev)

    for entry in self.queue_list:
      if isinstance(entry[0], self.comp_hcq_t) or isinstance(entry[0], self.copy_hcq_t): queue, dev = entry
      else:
        # Kernel with dynamic launch bounds, rebuild it.
        j, ji, deps, dev = entry[0], self.jit_cache[entry[0]], entry[1], self.jit_cache[entry[0]].prg.device
        queue = self.comp_hcq_t()
        for sig, val in deps: queue.wait(sig, val)
        queue.exec(ji.prg.clprg, self.kargs_addrs[j], *ji.prg.launch_dims(var_vals)) \
             .signal(self.comp_signal[dev], value=j+1)
      queue.submit(dev)

    for dev in self.devices:
      self.comp_hcq_t().signal(dev.timeline_signal, dev.timeline_value).submit(dev)
      self.graph_timeline[dev] = dev.timeline_value
      dev.timeline_value += 1

    if wait:
      st = time.perf_counter()
      for dev in self.devices: dev._wait_signal(dev.timeline_signal, self.graph_timeline[dev])
      return time.perf_counter() - st
    return None

  def access_resources(self, read, write, new_dependency):
    deps = self._access_resources(read, write, new_dependency)
    return [(k, max(v for x, v in deps if id(x) == idk)) for idk, k in {id(x[0]): x[0] for x in deps}.items()]
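
Most of __init__ above pre-fills one argument struct per kernel inside a cpu_access kernel-argument page, so a replay only has to poke the f{i} pointer fields for swapped input buffers and the v{i} fields for symbolic variables. A standalone ctypes sketch of that layout, assuming a simple "pointers then ints" convention and using a bytearray in place of the mapped page:

import ctypes

def make_args_struct(num_bufs: int, num_vars: int):
  # one ctypes.Structure with f0..fN buffer addresses and v0..vM integer variables,
  # mirroring the f{i}/v{i} fields HCQGraph writes above
  fields = [(f'f{i}', ctypes.c_void_p) for i in range(num_bufs)] + [(f'v{i}', ctypes.c_int) for i in range(num_vars)]
  return type("ArgsStruct", (ctypes.Structure,), {"_fields_": fields})

ArgsStruct = make_args_struct(num_bufs=2, num_vars=1)
kernargs_page = bytearray(ctypes.sizeof(ArgsStruct))   # stands in for the cpu_access allocation
args = ArgsStruct.from_buffer(kernargs_page)

args.f0, args.f1, args.v0 = 0x1000, 0x2000, 16         # capture time: fill everything once
args.f0, args.v0 = 0x3000, 32                          # replay time: only inputs and var_vals change
print(bytes(kernargs_page).hex())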

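The synchronization scheme in __init__/__call__: each device keeps a monotonically increasing timeline signal, each graph owns per-device comp/copy signals that are reset to zero before every replay, and one shared kickoff signal gates all queues until that reset has happened. A toy host-side model of the protocol using plain integers (no GPU; names are illustrative, not the real signal API):

class Signal:
  # a host-visible counter a queue can signal (write) or wait on (block until >= value)
  def __init__(self, value=0): self.value = value
  def set(self, value): self.value = value
  def wait(self, value): assert self.value >= value, f"would block: have {self.value}, need {value}"

class ToyDev:
  def __init__(self, name): self.name, self.timeline_signal, self.timeline_value = name, Signal(0), 1

devs = [ToyDev("AMD:0"), ToyDev("AMD:1")]
kickoff_signal, kickoff_value = Signal(0), 0
comp_signal = {d: Signal(0) for d in devs}
graph_timeline = {d: 0 for d in devs}

def replay():
  global kickoff_value
  kickoff_value += 1
  # 1. make sure the previous replay of this graph has fully retired on every device
  for d in devs: d.timeline_signal.wait(graph_timeline[d])
  # 2. reset the per-graph signals, then open the kickoff gate
  for d in devs: comp_signal[d].set(0)
  kickoff_signal.set(kickoff_value)
  # 3. each device's captured queues wait on the gate, run, and signal their progress
  for d in devs:
    kickoff_signal.wait(kickoff_value)
    comp_signal[d].set(1)                      # pretend the last kernel signalled
  # 4. finally bump the device timeline so the next replay (or synchronize) can wait on it
  for d in devs:
    d.timeline_signal.set(d.timeline_value)
    graph_timeline[d] = d.timeline_value
    d.timeline_value += 1

replay(); replay()
print({d.name: graph_timeline[d] for d in devs})   # {'AMD:0': 2, 'AMD:1': 2}

On real hardware steps 3 and 4 are executed by the GPU queues captured in queue_list; the host only performs the waits and resets of steps 1 and 2, which is what keeps a replay cheap.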
tinygrad/runtime/ops_amd.py

@@ -388,7 +388,7 @@ class AMDAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferOptions):
     try:
       if options.host: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_USERPTR, public=True)
-      else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=True)
+      else: return self.device._gpu_alloc(size, kfd.KFD_IOC_ALLOC_MEM_FLAGS_VRAM, public=options.cpu_access)
     except OSError as e:
       if e.errno == errno.ENOMEM: raise MemoryError("Cannot allocate memory") from e
       else: raise
@@ -584,7 +584,9 @@ class AMDDevice(Compiled):
     self.pm4_write_pointer = to_mv(self.pm4_queue.write_pointer_address, 8).cast("Q")
     self.pm4_doorbell = to_mv(self.doorbells + self.pm4_queue.doorbell_offset - self.doorbells_base, 8).cast("Q")

-    super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self))
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    super().__init__(device, AMDAllocator(self), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
+                     functools.partial(HCQGraph, AMDDevice, HWPM4Queue, HWCopyQueue))

   def synchronize(self):
     AMDDevice._wait_signal(self.timeline_signal, self.timeline_value - 1)
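
Both device constructors now hand the Compiled base a graph factory built with functools.partial(HCQGraph, <DeviceClass>, <ComputeQueueClass>, <CopyQueueClass>); that partial is also why apply_graph_to_jit unwraps .func to recover the graph class. A minimal standalone sketch of the wiring with toy classes:

import functools

class ToyHCQGraph:
  # mirrors HCQGraph's signature: device class and queue types first, then the jit payload
  def __init__(self, device_t, comp_queue_t, copy_queue_t, jit_cache):
    self.device_t, self.comp_queue_t, self.copy_queue_t, self.jit_cache = device_t, comp_queue_t, copy_queue_t, jit_cache

class HWPM4Queue: pass     # illustrative stand-ins for the per-backend queue types
class HWCopyQueue: pass

class ToyCompiled:
  def __init__(self, graph=None): self.graph = graph   # the JIT reads .graph and calls it

class ToyAMDDevice(ToyCompiled):
  def __init__(self): super().__init__(graph=functools.partial(ToyHCQGraph, ToyAMDDevice, HWPM4Queue, HWCopyQueue))

dev = ToyAMDDevice()
g = dev.graph(jit_cache=["kernel0", "kernel1"])   # JIT side: only the cache (plus buffers/var_vals) is passed
print(type(g).__name__, g.comp_queue_t.__name__, len(g.jit_cache))
print(dev.graph.func is ToyHCQGraph)              # what jit.py's graph_class unwrapping relies on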

tinygrad/runtime/ops_nv.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time
+import os, ctypes, pathlib, re, fcntl, functools, mmap, struct, tempfile, hashlib, subprocess, time, array
 from typing import Tuple, List, Any, cast
 from dataclasses import replace
 from tinygrad.device import Compiled, Compiler, CompilerOptions
@@ -116,7 +116,7 @@ class HWComputeQueue:
   def submit(self, dev:NVDevice):
     if len(self.q) == 0: return
     assert len(self.q) < (1 << 21)
-    for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
+    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
     fifo_entry = dev.compute_put_value % dev.compute_gpfifo_entries
     dev.compute_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42) | (1 << 41)
     dev.compute_gpu_ring_controls.GPPut = (dev.compute_put_value + 1) % dev.compute_gpfifo_entries
@@ -146,7 +146,7 @@ class HWCopyQueue:
   def submit(self, dev:NVDevice):
     if len(self.q) == 0: return
-    for i,packet in enumerate(self.q): dev.cmdq[dev.cmdq_wptr//4 + i] = packet
+    dev.cmdq[dev.cmdq_wptr//4:dev.cmdq_wptr//4+len(self.q)] = array.array('I', self.q)
     fifo_entry = dev.dma_put_value % dev.dma_gpfifo_entries
     dev.dma_gpu_ring[fifo_entry] = ((dev.cmdq_page.base+dev.cmdq_wptr)//4 << 2) | (len(self.q) << 42)
     dev.dma_gpu_ring_controls.GPPut = (dev.dma_put_value + 1) % dev.dma_gpfifo_entries
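
Both submit paths now build an array.array('I') once and copy it into the mapped command queue with a single slice assignment instead of one Python-level store per packet, which moves the copy into C. A standalone sketch of the same pattern, with a rough timing comparison on a plain buffer:

import array, time

cmdq = memoryview(bytearray(1 << 20)).cast('I')   # stands in for the mapped command queue
packets = list(range(50_000))                     # pretend these are 32-bit command words
wptr = 0

t0 = time.perf_counter()
for i, packet in enumerate(packets): cmdq[wptr + i] = packet        # old: per-packet stores
t_loop = time.perf_counter() - t0

t0 = time.perf_counter()
cmdq[wptr:wptr + len(packets)] = array.array('I', packets)          # new: one bulk copy
t_bulk = time.perf_counter() - t0

print(f"loop {t_loop*1e3:.2f} ms  bulk {t_bulk*1e3:.2f} ms")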
@@ -235,6 +235,9 @@ class NVProgram:
   def __call__(self, *args, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
     if prod(local_size) > 1024 or self.max_threads < prod(local_size): raise RuntimeError("Too many resources requsted for launch")
+    if not hasattr(self, "args_struct_t"):
+      self.args_struct_t = init_c_struct_t(tuple([(f'f{i}', ctypes.c_void_p) for i in range(len(args))] +
+                                                 [(f'v{i}', ctypes.c_int) for i in range(len(vals))]))
     if self.device.kernargs_ptr >= (self.device.kernargs_page.base + self.device.kernargs_page.length - self.kernargs_segment_size):
       self.device.kernargs_ptr = self.device.kernargs_page.base
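
NVProgram only learns how many buffer pointers and integer vals a launch takes on the first call, so the ctypes argument struct is built lazily and cached on the instance; HCQGraph then reuses the same args_struct_t to patch arguments in place. A standalone approximation of the pattern (type(...) stands in for tinygrad's init_c_struct_t helper, and a ctypes buffer stands in for the kernargs page):

import ctypes

class ToyProgram:
  def args_struct(self, nargs, nvals):
    # built once per program, on first call, when the argument counts are finally known
    if not hasattr(self, "args_struct_t"):
      self.args_struct_t = type("ArgsStruct", (ctypes.Structure,),
                                {"_fields_": [(f'f{i}', ctypes.c_void_p) for i in range(nargs)] +
                                             [(f'v{i}', ctypes.c_int) for i in range(nvals)]})
    return self.args_struct_t

kernargs = ctypes.create_string_buffer(256)    # stands in for the region at device.kernargs_ptr
prg = ToyProgram()
st = prg.args_struct(nargs=2, nvals=1).from_address(ctypes.addressof(kernargs))
st.f0, st.f1, st.v0 = 0xdead0000, 0xbeef0000, 7

assert prg.args_struct(2, 1) is type(st)       # second call reuses the cached struct type
print(kernargs.raw[:ctypes.sizeof(type(st))].hex())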
@@ -265,7 +268,7 @@ class NVAllocator(LRUAllocator):
   def _alloc(self, size:int, options:BufferOptions):
     if options.host: return self.device._gpu_host_alloc(size)
-    else: return self.device._gpu_alloc(size)
+    else: return self.device._gpu_alloc(size, map_to_cpu=options.cpu_access)

   def _free(self, gpumem, options:BufferOptions):
     NVDevice.synchronize_system()
@@ -491,7 +494,9 @@ class NVDevice(Compiled):
     self.arch: str = 'sm_89' # TODO: fix

-    super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self))
+    from tinygrad.runtime.graph.hcq import HCQGraph
+    super().__init__(device, NVAllocator(self), NVCompiler(self.arch), functools.partial(NVProgram, self),
+                     functools.partial(HCQGraph, NVDevice, HWComputeQueue, HWCopyQueue))

     self._cmdq_setup_compute_gpfifo()
     self._cmdq_setup_dma_gpfifo()
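
With the graph factory wired into NVDevice and AMDDevice, a TinyJit function whose kernels and transfers all land on these backends is captured into one HCQGraph and replayed on later calls. A hedged usage sketch; it assumes a tinygrad checkout from around this commit with a working NV or AMD device selected as the default backend:

from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x.T).relu().sum()

for i in range(4):   # first call runs eagerly, the JIT then captures, later calls replay the HCQ graph
  out = step(Tensor.rand(256, 256).realize())
  print(i, out.item())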