move hcq device to runtime [pr] (#6879)

* things that are only used in one place don't belong in helpers [pr]

* start moving hcq device [pr]

* fix paths
George Hotz 2024-10-04 22:26:50 +08:00 committed by GitHub
parent 5be2bd18a6
commit 6b063450df
GPG Key ID: B5690EEEBB952194
9 changed files with 563 additions and 554 deletions

View File

@@ -6,7 +6,7 @@ The main aspect of HCQ-compatible runtimes is how they interact with devices. In
### Command Queues
-To interact with devices, there are 2 types of queues: `HWComputeQueue` and `HWCopyQueue`. Commands which are defined in a base `HWCommandQueue` class should be supported by both queues. These methods are timestamp and synchronization methods like [signal](#tinygrad.device.HWCommandQueue.signal) and [wait](#tinygrad.device.HWCommandQueue.wait).
+To interact with devices, there are 2 types of queues: `HWComputeQueue` and `HWCopyQueue`. Commands which are defined in a base `HWCommandQueue` class should be supported by both queues. These methods are timestamp and synchronization methods like [signal](#tinygrad.runtime.support.hcq.HWCommandQueue.signal) and [wait](#tinygrad.runtime.support.hcq.HWCommandQueue.wait).
For example, the following Python code enqueues a wait, execute, and signal command on the HCQ-compatible device:
```python
@@ -18,7 +18,7 @@ HWComputeQueue().wait(signal_to_wait, value_to_wait) \
Each runtime should implement the required functions that are defined in the `HWCommandQueue`, `HWComputeQueue`, and `HWCopyQueue` classes.
-::: tinygrad.device.HWCommandQueue
+::: tinygrad.runtime.support.hcq.HWCommandQueue
options:
members: [
"signal",
@@ -31,7 +31,7 @@ Each runtime should implement the required functions that are defined in the `HW
]
show_source: false
-::: tinygrad.device.HWComputeQueue
+::: tinygrad.runtime.support.hcq.HWComputeQueue
options:
members: [
"memory_barrier",
@@ -40,7 +40,7 @@ Each runtime should implement the required functions that are defined in the `HW
]
show_source: false
-::: tinygrad.device.HWCopyQueue
+::: tinygrad.runtime.support.hcq.HWCopyQueue
options:
members: [
"copy",
@@ -52,7 +52,7 @@ Each runtime should implement the required functions that are defined in the `HW
To implement custom commands in the queue, use the @hcq_command decorator for your command implementations.
-::: tinygrad.device.hcq_command
+::: tinygrad.runtime.support.hcq.hcq_command
options:
members: [
"copy",
@@ -64,7 +64,7 @@ To implement custom commands in the queue, use the @hcq_command decorator for yo
The `HCQCompiled` class defines the API for HCQ-compatible devices. This class serves as an abstract base class that device-specific implementations should inherit from and implement.
-::: tinygrad.device.HCQCompiled
+::: tinygrad.runtime.support.hcq.HCQCompiled
options:
show_source: false
@@ -72,7 +72,7 @@ The `HCQCompiled` class defines the API for HCQ-compatible devices. This class s
Signals are device-dependent structures used for synchronization and timing in HCQ-compatible devices. They should be designed to record both a `value` and a `timestamp` within the same signal. HCQ-compatible backend implementations should use `HCQSignal` as a base class.
-::: tinygrad.device.HCQSignal
+::: tinygrad.runtime.support.hcq.HCQSignal
options:
members: [value, timestamp, wait]
show_source: false
@@ -99,7 +99,7 @@ Each HCQ-compatible device must allocate two signals for global synchronization
The `HCQAllocator` base class simplifies allocator logic by leveraging [command queues](#command-queues) abstractions. This class efficiently handles copy and transfer operations, leaving only the alloc and free functions to be implemented by individual backends.
-::: tinygrad.device.HCQAllocator
+::: tinygrad.runtime.support.hcq.HCQAllocator
options:
members: [
"_alloc",
@@ -111,7 +111,7 @@ The `HCQAllocator` base class simplifies allocator logic by leveraging [command
Backends must adhere to the `HCQBuffer` protocol when returning allocation results.
-::: tinygrad.device.HCQBuffer
+::: tinygrad.runtime.support.hcq.HCQBuffer
options:
members: true
show_source: false
@@ -120,7 +120,7 @@ Backends must adhere to the `HCQBuffer` protocol when returning allocation resul
`HCQProgram` is a base class for defining programs compatible with HCQ-enabled devices. It provides a flexible framework for handling different argument layouts (see `HCQArgsState`).
-::: tinygrad.device.HCQProgram
+::: tinygrad.runtime.support.hcq.HCQProgram
options:
members: true
show_source: false
@@ -129,7 +129,7 @@ Backends must adhere to the `HCQBuffer` protocol when returning allocation resul
`HCQArgsState` is a base class for managing the argument state for HCQ programs. Backend implementations should create a subclass of `HCQArgsState` to manage arguments for the given program.
-::: tinygrad.device.HCQArgsState
+::: tinygrad.runtime.support.hcq.HCQArgsState
options:
members: true
show_source: false
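The signal contract documented above can be illustrated with a standalone sketch (a toy stand-in, not tinygrad's actual classes): a backend keeps the value somewhere device-visible and implements `_get_value`/`_set_value`, while the base supplies the `value` property and a polling `wait` with a millisecond timeout.

```python
import time

# Toy signal following the HCQSignal contract described above.
# Real backends store the value in device-visible memory; here it is a plain int.
class ToySignal:
    def __init__(self, value: int = 0):
        self._set_value(value)

    def _get_value(self) -> int:
        return self._value

    def _set_value(self, new_value: int):
        self._value = new_value

    @property
    def value(self) -> int:
        return self._get_value()

    @value.setter
    def value(self, new_value: int):
        self._set_value(new_value)

    def wait(self, value: int, timeout: int = 10000):
        # Spin until the signal reaches `value`, or raise after `timeout` ms.
        start = time.time() * 1000
        while time.time() * 1000 - start < timeout:
            if self.value >= value: return
        raise RuntimeError(f"wait timeout: signal is {self.value}, wanted {value}")

sig = ToySignal()
sig.value = 3
sig.wait(2)  # returns immediately since 3 >= 2
```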

View File

@@ -1,7 +1,8 @@
import unittest, ctypes, struct
from tinygrad import Device, Tensor, dtypes
from tinygrad.helpers import CI, getenv
-from tinygrad.device import Buffer, BufferOptions, HCQCompiled
+from tinygrad.device import Buffer, BufferOptions
+from tinygrad.runtime.support.hcq import HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner, CompiledRunner
from tinygrad.codegen.kernel import Kernel, Opt, OptOps

View File

@@ -1,7 +1,8 @@
import unittest, struct, contextlib, tempfile, pathlib, json, time, atexit, random
from tinygrad import Device, Tensor, dtypes, TinyJit
from tinygrad.helpers import CI, getenv, Context
-from tinygrad.device import Buffer, BufferOptions, ProfileLogger, HCQCompiled
+from tinygrad.device import Buffer, BufferOptions
+from tinygrad.runtime.support.hcq import ProfileLogger, HCQCompiled
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import get_runner

View File

@@ -1,10 +1,9 @@
from __future__ import annotations
-import multiprocessing, decimal, statistics, random, json
from dataclasses import dataclass, replace
from collections import defaultdict
-from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Iterator, Union
-import importlib, inspect, functools, pathlib, os, ctypes, atexit, time, contextlib, array
-from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv, PROFILEPATH, PROFILE
+from typing import List, Optional, Dict, Tuple, Any, Iterator
+import multiprocessing, importlib, inspect, functools, pathlib, os, ctypes, contextlib
+from tinygrad.helpers import SAVE_SCHEDULE, getenv, diskcache_get, diskcache_put, DEBUG, GlobalCounters, flat_mv, from_mv
from tinygrad.dtype import DType, ImageDType
from tinygrad.renderer import Renderer
@@ -200,533 +199,3 @@ class Compiled:
This method ensures that all previously queued operations on the device have been completed before proceeding.
"""
# override this in your device implementation
# **************** for HCQ Compatible Devices ****************
def hcq_command(func):
"""
Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
For example:
```python
@hcq_command
def command_method(self, ...): ...
```
"""
def __wrapper(self, *args, **kwargs):
self.cmds_offset.append(len(self.q))
func(self, *args, **kwargs)
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
self.cmds_meta.append(func.__name__)
return self
return __wrapper
class HWCommandQueue:
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
Both compute and copy queues should have the following commands implemented.
"""
def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
def __len__(self): return len(self.cmds_offset)
def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
def _cur_cmd_idx(self) -> int:
"""
Returns the index of the command currently being enqueued.
Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
"""
return len(self) - 1
@hcq_command
def signal(self, signal:HCQSignal, value:int):
"""
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
Args:
signal: The signal to set
value: The value to set the signal to
"""
self._signal(signal, value)
def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
@hcq_command
def wait(self, signal:HCQSignal, value:int):
"""
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
Args:
signal: The signal to wait on
value: The value to wait for
"""
self._wait(signal, value)
def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
@hcq_command
def timestamp(self, signal:HCQSignal):
"""
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
Args:
signal: The signal to store the timestamp
"""
self._timestamp(signal)
def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
"""
Updates a previously queued signal command.
Args:
cmd_idx: Index of the signal command to update
signal: New signal to set (if None, keeps the original)
value: New value to set (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
self._update_signal(cmd_idx, signal, value)
return self
def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
"""
Updates a previously queued wait command.
Args:
cmd_idx: Index of the wait command to update
signal: New signal to wait on (if None, keeps the original)
value: New value to wait for (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
self._update_wait(cmd_idx, signal, value)
return self
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
def bind(self, device:HCQCompiled):
"""
Associates the queue with a specific device for optimized execution.
This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
the need to copy queues into the device, thereby enhancing performance.
Args:
device: The target device for queue optimization.
Note:
Implementing this method is optional but recommended for performance gains.
"""
def submit(self, device:HCQCompiled):
"""
Submits the command queue to a specific device for execution.
Args:
device: The device to submit the queue to
"""
if self.q: self._submit(device)
return self
def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
class HWComputeQueue(HWCommandQueue):
@hcq_command
def memory_barrier(self):
"""
Enqueues a memory barrier command to ensure memory coherence between agents.
"""
self._memory_barrier()
def _memory_barrier(self): pass
@hcq_command
def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
"""
Enqueues an execution command for a kernel program.
Args:
prg: The program to execute
args_state: The args state to execute program with
global_size: The global work size
local_size: The local work size
"""
self._exec(prg, args_state, global_size, local_size)
def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
"""
Updates a previously queued execution command.
Args:
cmd_idx: Index of the execution command to update
global_size: New global work size (if None, keeps the original)
local_size: New local work size (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
self._update_exec(cmd_idx, global_size, local_size)
return self
def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
class HWCopyQueue(HWCommandQueue):
@hcq_command
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
"""
Enqueues a copy command to transfer data.
Args:
dest: The destination of the copy
src: The source of the copy
copy_size: The size of data to copy
"""
self._copy(dest, src, copy_size)
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
"""
Updates a previously queued copy command.
Args:
cmd_idx: Index of the copy command to update
dest: New destination of the copy (if None, keeps the original)
src: New source of the copy (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on a copy command")
self._update_copy(cmd_idx, dest, src)
return self
def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
class HCQSignal:
def __init__(self, value:int=0): self._set_value(value)
@property
def value(self) -> int: return self._get_value()
@value.setter
def value(self, new_value:int): self._set_value(new_value)
def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
@property
def timestamp(self) -> decimal.Decimal:
"""
Get the timestamp field of the signal.
This property provides read-only access to the signal's timestamp.
Returns:
The timestamp in microseconds.
"""
return self._get_timestamp()
def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
def wait(self, value:int, timeout:int=10000):
"""
Waits until the signal is greater than or equal to a specific value.
Args:
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
"""
start_time = time.time() * 1000
while time.time() * 1000 - start_time < timeout:
if self.value >= value: return
raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
@contextlib.contextmanager
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
if enabled and queue is not None: queue.timestamp(st)
elif enabled:
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
try: yield (st, en)
finally:
if enabled and queue is not None: queue.timestamp(en)
elif enabled:
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
class HCQArgsState:
def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
class HCQProgram:
def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
self.args_state_t, self.device, self.name, self.kernargs_alloc_size = args_state_t, device, name, kernargs_alloc_size
def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
"""
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
Args:
bufs: Buffers to be written to kernel arguments.
vals: Values to be written to kernel arguments.
kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
Returns:
Arguments state with the given buffers and values set for the program.
"""
return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
"""
Enqueues the program for execution with the given arguments and dimensions.
Args:
bufs: Buffer arguments to execute the kernel with.
global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
vals: Value arguments to execute the kernel with.
wait: If True, waits for the kernel to complete execution.
Returns:
Execution time of the kernel if 'wait' is True, otherwise None.
"""
kernargs = self.fill_kernargs(bufs, vals)
q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
q.exec(self, kernargs, global_size, local_size)
q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.device.timeline_value += 1
if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
class ProfileLogger:
writers: int = 0
mjson: List[Dict] = []
actors: Dict[Union[str, Tuple[str, str]], int] = {}
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
def _ensure_actor(self, actor_name, subactor_name):
if actor_name not in self.actors:
self.actors[actor_name] = (pid:=len(self.actors))
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
self.actors[subactor_key] = (tid:=len(self.actors))
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
return self.actors[actor_name], self.actors.get(subactor_key, -1)
def __del__(self):
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
for name, st, et, actor_name, subactor_name, args in self.events:
pid, tid = self._ensure_actor(actor_name,subactor_name)
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
pid, tid = self._ensure_actor(actor_name,subactor_name)
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
ProfileLogger.writers -= 1
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
class HCQCompiled(Compiled):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
devices: List[HCQCompiled] = []
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]], timeline_signals:Tuple[HCQSignal, HCQSignal]):
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
self.timeline_value:int = 1
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
self.kernargs_ptr:int = self.kernargs_page.va_addr
self.devices.append(self)
def synchronize(self):
self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE:
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
self.sig_prof_records = []
def _alloc_kernargs(self, alloc_size:int) -> int:
"""
Allocates space for arguments passed to the kernel.
"""
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
return res
def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
def _sync_cpu_queue(d, q_t):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
d.timeline_value += 1
st = time.perf_counter_ns()
d.timeline_signal.wait(d.timeline_value - 1) # average of the two
et = time.perf_counter_ns()
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp
# randomly sample the timing from GPU to CPU
choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
for _ in range(100*len(self.devices)):
d,q,l = random.choice(choices)
l.append(_sync_cpu_queue(d,q))
for d,q,l in choices:
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
d1.timeline_value += 2
d2.timeline_value += 2
d1.timeline_signal.wait(d1.timeline_value - 1)
d2.timeline_signal.wait(d2.timeline_value - 1)
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
# then test it by timing the GPU to GPU times
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
for i1, d1 in enumerate(self.devices):
for i2, d2 in enumerate(self.devices):
if d1 == d2: continue
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
"""
Translates local gpu time (timestamp) into global cpu time.
"""
self._ensure_shared_time_base()
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
def _prof_setup(self):
if hasattr(self, 'profile_logger'): return
atexit.register(self._prof_finalize)
self.profile_logger = ProfileLogger()
def _prof_finalize(self):
qname = ["COMPUTE", "DMA"]
# Sync to be sure all events on the device are recorded.
self.synchronize()
for st, en, name, is_cp, args in self.raw_prof_records:
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
# Perfetto connects nodes based on timing data, ensuring every choice is valid by averaging times to a midpoint.
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
self.raw_prof_records, self.dep_prof_records = [], []
# Remove the logger, this flushes all data written by the device.
del self.profile_logger
def _wrap_timeline_signal(self):
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
self.timeline_signal.value = 0
cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
# Protocol for HCQ-compatible allocators: allocated buffers must expose a VA address and its size.
class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
"""
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
"""
def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
self.device:Any = device
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
self.b_timeline, self.b_next = [0] * len(self.b), 0
super().__init__()
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def copyin(self, dest:HCQBuffer, src:memoryview):
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
for i in range(0, src.nbytes, self.b[0].size):
self.b_next = (self.b_next + 1) % len(self.b)
self.device.timeline_signal.wait(self.b_timeline[self.b_next])
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.b_timeline[self.b_next] = self.device.timeline_value
self.device.timeline_value += 1
def copy_from_disk(self, dest:HCQBuffer, src, size):
def _get_temp_buf():
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
return (self.b[self.b_next].va_addr, self.b_next)
return None
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.b_timeline[batch_info[1]] = self.device.timeline_value
self.device.timeline_value += 1
def copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
for i in range(0, dest.nbytes, self.b[0].size):
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.device.timeline_signal.wait(self.device.timeline_value)
self.device.timeline_value += 1
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
src_dev.allocator.map(dest)
with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.copy(dest.va_addr, src.va_addr, sz) \
.signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
src_dev.timeline_value += 1
if src_dev != dest_dev:
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1
def map(self, buf:HCQBuffer): pass
def offset(self, buf, size:int, offset:int) -> HCQBuffer:
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
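The bookkeeping that `@hcq_command` performs in the removed code above (recording each command's offset, length, and name so the `update_*` methods can `_patch` its words in place) can be shown with a self-contained sketch; `ToyQueue` and its fake `[opcode, value]` encoding are invented for illustration and are not tinygrad's real command format.

```python
import array

# Minimal re-creation of the @hcq_command bookkeeping: remember where each
# command's words start in the queue, how many it appended, and its name.
def hcq_command(func):
    def wrapper(self, *args, **kwargs):
        self.cmds_offset.append(len(self.q))
        func(self, *args, **kwargs)
        self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
        self.cmds_meta.append(func.__name__)
        return self  # returning self enables method chaining
    return wrapper

class ToyQueue:
    def __init__(self):
        self.q, self.cmds_offset, self.cmds_len, self.cmds_meta = [], [], [], []

    def _patch(self, cmd_idx, offset, data):
        # Overwrite words of a previously enqueued command in place.
        st = self.cmds_offset[cmd_idx] + offset
        self.q[st:st+len(data)] = array.array('I', data)

    @hcq_command
    def signal(self, value):           # fake encoding: [opcode, value]
        self.q.extend([0xAA, value])

    def update_signal(self, cmd_idx, value):
        assert self.cmds_meta[cmd_idx] == "signal"
        self._patch(cmd_idx, 1, [value])
        return self

q = ToyQueue().signal(1).signal(2)
q.update_signal(0, 7)
# q.q == [0xAA, 7, 0xAA, 2]
```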

View File

@@ -1,8 +1,8 @@
import collections, time
from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
-from tinygrad.device import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState, \
-Buffer, BufferOptions, Compiled, Device
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
+from tinygrad.device import Buffer, BufferOptions, Compiled, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner

View File

@@ -2,8 +2,8 @@ from __future__ import annotations
from typing import Tuple, List, Any
import os, ctypes, ctypes.util, functools, pathlib, mmap, errno, time, array, contextlib, decimal
from dataclasses import dataclass
-from tinygrad.device import HCQCompiled, HCQAllocator, HCQBuffer, HWComputeQueue, HWCopyQueue, HCQArgsState, \
-HCQSignal, HCQProgram, BufferOptions
+from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWComputeQueue, HWCopyQueue, HCQArgsState, HCQSignal, HCQProgram
+from tinygrad.device import BufferOptions
from tinygrad.helpers import getenv, to_mv, round_up, data64_le, DEBUG, mv_address
from tinygrad.renderer.cstyle import AMDRenderer
from tinygrad.runtime.autogen import kfd, hsa, amd_gpu, libc


@@ -2,8 +2,9 @@ from __future__ import annotations
import os, ctypes, contextlib, re, fcntl, functools, mmap, struct, array, decimal
from typing import Tuple, List, Any, cast, Union, Dict, Type
from dataclasses import dataclass
from tinygrad.device import HCQCompiled, HCQAllocator, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, hcq_command, \
HCQArgsState, HCQProgram, HCQSignal, BufferOptions
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, hcq_command
from tinygrad.runtime.support.hcq import HCQArgsState, HCQProgram, HCQSignal
from tinygrad.device import BufferOptions
from tinygrad.helpers import getenv, mv_address, init_c_struct_t, to_mv, round_up, data64, data64_le, DEBUG, prod
from tinygrad.renderer.assembly import PTXRenderer
from tinygrad.renderer.cstyle import NVRenderer


@@ -2,7 +2,8 @@ from __future__ import annotations
import os, ctypes, functools, mmap, struct, array, decimal, math
from types import SimpleNamespace
from typing import Tuple, List, Any, cast
from tinygrad.device import BufferOptions, HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
from tinygrad.device import BufferOptions
from tinygrad.runtime.support.hcq import HCQBuffer, HWComputeQueue, HCQProgram, HCQCompiled, HCQSignal, HCQAllocator, HCQArgsState
from tinygrad.runtime.autogen import kgsl, adreno, libc
from tinygrad.runtime.ops_gpu import CLCompiler, CLDevice
from tinygrad.renderer.cstyle import QCOMRenderer


@@ -0,0 +1,536 @@
from __future__ import annotations
from typing import List, Optional, Dict, Tuple, Any, cast, Protocol, Type, Union
import contextlib, decimal, statistics, random, json, atexit, time, array, ctypes
from tinygrad.helpers import PROFILEPATH, PROFILE, from_mv
from tinygrad.renderer import Renderer
from tinygrad.device import BufferOptions, Allocator, Compiler, Compiled, LRUAllocator
# **************** for HCQ Compatible Devices ****************
def hcq_command(func):
"""
Decorator for HWCommandQueue commands. Enables command indexing and stores metadata for command updates.
For example:
```python
@hcq_command
def command_method(self, ...): ...
```
"""
def __wrapper(self, *args, **kwargs):
self.cmds_offset.append(len(self.q))
func(self, *args, **kwargs)
self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
self.cmds_meta.append(func.__name__)
return self
return __wrapper
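The bookkeeping the decorator performs can be exercised standalone. The sketch below (hypothetical `ToyQueue` with fake packet words, not a real backend) shows how `cmds_offset`, `cmds_len`, and `cmds_meta` line up after two commands:

```python
def hcq_command(func):
    # same bookkeeping as the real decorator: record where each command's
    # words start, how many words it emitted, and the command's name
    def __wrapper(self, *args, **kwargs):
        self.cmds_offset.append(len(self.q))
        func(self, *args, **kwargs)
        self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
        self.cmds_meta.append(func.__name__)
        return self
    return __wrapper

class ToyQueue:
    def __init__(self): self.q, self.cmds_offset, self.cmds_len, self.cmds_meta = [], [], [], []
    @hcq_command
    def signal(self, sig_id, value): self.q += [0x10, sig_id, value]   # 3 fake words
    @hcq_command
    def wait(self, sig_id, value): self.q += [0x20, sig_id, value, 0]  # 4 fake words

q = ToyQueue().signal(1, 5).wait(2, 7)
print(q.cmds_offset, q.cmds_len, q.cmds_meta)  # [0, 3] [3, 4] ['signal', 'wait']
```

The per-command offsets are what make the `update_*` methods possible: a later patch only needs the command index, not knowledge of what was enqueued before it.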
class HWCommandQueue:
"""
A base class for hardware command queues in the HCQ (Hardware Command Queue) API.
Both compute and copy queues should have the following commands implemented.
"""
def __init__(self): self.q, self.binded_device, self.cmds_offset, self.cmds_len, self.cmds_meta = [], None, [], [], []
def __len__(self): return len(self.cmds_offset)
def _patch(self, cmd_idx, offset, data): self.q[(st:=self.cmds_offset[cmd_idx]+offset):st+len(data)] = array.array('I', data)
def _cur_cmd_idx(self) -> int:
"""
Returns the index of the command currently being enqueued.
Should be called only within functions that enqueue commands and are decorated with `@hcq_command`.
"""
return len(self) - 1
@hcq_command
def signal(self, signal:HCQSignal, value:int):
"""
Enqueues a signal command which sets the signal to the given value, ensuring all previous operations are completed.
Args:
signal: The signal to set
value: The value to set the signal to
"""
self._signal(signal, value)
def _signal(self, signal:HCQSignal, value:int): raise NotImplementedError("backend should overload this function")
@hcq_command
def wait(self, signal:HCQSignal, value:int):
"""
Enqueues a wait command which halts execution until the signal is greater than or equal to a specific value.
Args:
signal: The signal to wait on
value: The value to wait for
"""
self._wait(signal, value)
def _wait(self, signal, value): raise NotImplementedError("backend should overload this function")
@hcq_command
def timestamp(self, signal:HCQSignal):
"""
Enqueues a timestamp command which records the current time in a signal after all previously enqueued commands are completed.
Args:
signal: The signal to store the timestamp
"""
self._timestamp(signal)
def _timestamp(self, signal): raise NotImplementedError("backend should overload this function")
def update_signal(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
"""
Updates a previously queued signal command.
Args:
cmd_idx: Index of the signal command to update
signal: New signal to set (if None, keeps the original)
value: New value to set (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
self._update_signal(cmd_idx, signal, value)
return self
def _update_signal(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
def update_wait(self, cmd_idx:int, signal:Optional[Any]=None, value:Optional[int]=None):
"""
Updates a previously queued wait command.
Args:
cmd_idx: Index of the wait command to update
signal: New signal to wait on (if None, keeps the original)
value: New value to wait for (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "wait": raise RuntimeError("called update_wait not on a wait command")
self._update_wait(cmd_idx, signal, value)
return self
def _update_wait(self, cmd_idx:int, signal:Optional[Any], value:Optional[int]): raise NotImplementedError("backend should overload this function")
def bind(self, device:HCQCompiled):
"""
Associates the queue with a specific device for optimized execution.
This optional method allows backend implementations to tailor the queue for efficient use on the given device. When implemented, it can eliminate
the need to copy queues into the device, thereby enhancing performance.
Args:
device: The target device for queue optimization.
Note:
Implementing this method is optional but recommended for performance gains.
"""
def submit(self, device:HCQCompiled):
"""
Submits the command queue to a specific device for execution.
Args:
device: The device to submit the queue to
"""
if self.q: self._submit(device)
return self
def _submit(self, device:HCQCompiled): raise NotImplementedError("backend should overload this function")
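How the recorded offsets let `update_signal` patch already-encoded words in place can be sketched with a hypothetical in-memory backend (`MockQueue`; the decorator is repeated so the snippet stays self-contained, and the packet layout is invented):

```python
import array

def hcq_command(func):
    def __wrapper(self, *args, **kwargs):
        self.cmds_offset.append(len(self.q))
        func(self, *args, **kwargs)
        self.cmds_len.append(len(self.q) - self.cmds_offset[-1])
        self.cmds_meta.append(func.__name__)
        return self
    return __wrapper

class MockQueue:
    def __init__(self): self.q, self.cmds_offset, self.cmds_len, self.cmds_meta = [], [], [], []
    def _patch(self, cmd_idx, offset, data):
        # overwrite words relative to the start of command cmd_idx, like HWCommandQueue._patch
        st = self.cmds_offset[cmd_idx] + offset
        self.q[st:st+len(data)] = array.array('I', data)
    @hcq_command
    def signal(self, sig_id, value): self.q += [0x51, sig_id, value]  # fake packet: opcode, signal, value
    def update_signal(self, cmd_idx, signal=None, value=None):
        if self.cmds_meta[cmd_idx] != "signal": raise RuntimeError("called update_signal not on a signal command")
        if signal is not None: self._patch(cmd_idx, 1, [signal])
        if value is not None: self._patch(cmd_idx, 2, [value])
        return self

q = MockQueue().signal(7, 100).signal(8, 200)
q.update_signal(1, value=999)
print(q.q)  # [81, 7, 100, 81, 8, 999]
```

This in-place patching is what makes queues reusable under the graph executor: a recorded queue can be retargeted to new signals and values without re-encoding.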
class HWComputeQueue(HWCommandQueue):
@hcq_command
def memory_barrier(self):
"""
Enqueues a memory barrier command to ensure memory coherence between agents.
"""
self._memory_barrier()
def _memory_barrier(self): pass
@hcq_command
def exec(self, prg:HCQProgram, args_state:HCQArgsState, global_size:Tuple[int,int,int], local_size:Tuple[int,int,int]):
"""
Enqueues an execution command for a kernel program.
Args:
prg: The program to execute
args_state: The args state to execute program with
global_size: The global work size
local_size: The local work size
"""
self._exec(prg, args_state, global_size, local_size)
def _exec(self, prg, args_state, global_size, local_size): raise NotImplementedError("backend should overload this function")
def update_exec(self, cmd_idx:int, global_size:Optional[Tuple[int,int,int]]=None, local_size:Optional[Tuple[int,int,int]]=None):
"""
Updates a previously queued execution command.
Args:
cmd_idx: Index of the execution command to update
global_size: New global work size (if None, keeps the original)
local_size: New local work size (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "exec": raise RuntimeError("called update_exec not on an exec command")
self._update_exec(cmd_idx, global_size, local_size)
return self
def _update_exec(self, cmd_idx, global_size, local_size): raise NotImplementedError("backend should overload this function")
class HWCopyQueue(HWCommandQueue):
@hcq_command
def copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int):
"""
Enqueues a copy command to transfer data.
Args:
dest: The destination of the copy
src: The source of the copy
copy_size: The size of data to copy
"""
self._copy(dest, src, copy_size)
def _copy(self, dest:HCQBuffer, src:HCQBuffer, copy_size:int): raise NotImplementedError("backend should overload this function")
def update_copy(self, cmd_idx:int, dest:Optional[HCQBuffer]=None, src:Optional[HCQBuffer]=None):
"""
Updates a previously queued copy command.
Args:
cmd_idx: Index of the copy command to update
dest: New destination of the copy (if None, keeps the original)
src: New source of the copy (if None, keeps the original)
"""
if self.cmds_meta[cmd_idx] != "copy": raise RuntimeError("called update_copy not on a copy command")
self._update_copy(cmd_idx, dest, src)
return self
def _update_copy(self, cmd_idx, dest, src): raise NotImplementedError("backend should overload this function")
class HCQSignal:
def __init__(self, value:int=0): self._set_value(value)
@property
def value(self) -> int: return self._get_value()
@value.setter
def value(self, new_value:int): self._set_value(new_value)
def _get_value(self) -> int: raise NotImplementedError("_get_value() method must be implemented")
def _set_value(self, new_value:int): raise NotImplementedError("_set_value() method must be implemented")
@property
def timestamp(self) -> decimal.Decimal:
"""
Get the timestamp field of the signal.
This property provides read-only access to the signal's timestamp.
Returns:
The timestamp in microseconds.
"""
return self._get_timestamp()
def _get_timestamp(self) -> decimal.Decimal: raise NotImplementedError("_get_timestamp() method must be implemented")
def wait(self, value:int, timeout:int=10000):
"""
Waits until the signal is greater than or equal to a specific value.
Args:
value: The value to wait for.
timeout: Maximum time to wait in milliseconds. Defaults to 10s.
"""
start_time = time.time() * 1000
while time.time() * 1000 - start_time < timeout:
if self.value >= value: return
raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")
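A host-memory subclass is the simplest way to satisfy this contract. The sketch below (hypothetical `CPUSignal`, not a tinygrad class) backs `_get_value`/`_set_value` with a plain attribute and reuses the same polling `wait`:

```python
import time

class CPUSignal:
    """Hypothetical host-memory signal implementing the HCQSignal contract."""
    def __init__(self, value: int = 0): self._set_value(value)
    def _get_value(self) -> int: return self._v
    def _set_value(self, new_value: int): self._v = new_value
    @property
    def value(self) -> int: return self._get_value()
    @value.setter
    def value(self, new_value: int): self._set_value(new_value)
    def wait(self, value: int, timeout: int = 10000):
        # busy-poll until the signal reaches `value` or `timeout` ms elapse
        start_time = time.time() * 1000
        while time.time() * 1000 - start_time < timeout:
            if self.value >= value: return
        raise RuntimeError(f"Wait timeout: {timeout} ms! (the signal is not set to {value}, but {self.value})")

s = CPUSignal(3)
s.wait(2)  # returns immediately since 3 >= 2
```

Real backends instead map `_get_value`/`_set_value` onto device-visible memory so the GPU can signal and the CPU can observe it.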
@contextlib.contextmanager
def hcq_profile(dev, enabled, desc, queue_type=None, queue=None):
st, en = (dev.signal_t(), dev.signal_t()) if enabled else (None, None)
if enabled and queue is not None: queue.timestamp(st)
elif enabled:
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(st).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
try: yield (st, en)
finally:
if enabled and queue is not None: queue.timestamp(en)
elif enabled:
queue_type().wait(dev.timeline_signal, dev.timeline_value - 1).timestamp(en).signal(dev.timeline_signal, dev.timeline_value).submit(dev)
dev.timeline_value += 1
if enabled and PROFILE: dev.sig_prof_records.append((st, en, desc, queue_type is dev.hw_copy_queue_t))
class HCQArgsState:
def __init__(self, ptr:int, prg:HCQProgram, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=()): self.ptr, self.prg = ptr, prg
def update_buffer(self, index:int, buf:HCQBuffer): raise NotImplementedError("need update_buffer")
def update_var(self, index:int, val:int): raise NotImplementedError("need update_var")
class HCQProgram:
def __init__(self, args_state_t:Type[HCQArgsState], device:HCQCompiled, name:str, kernargs_alloc_size:int):
self.args_state_t, self.device, self.name, self.kernargs_alloc_size = args_state_t, device, name, kernargs_alloc_size
def fill_kernargs(self, bufs:Tuple[HCQBuffer, ...], vals:Tuple[int, ...]=(), kernargs_ptr:Optional[int]=None) -> HCQArgsState:
"""
Fills arguments for the kernel, optionally allocating space from the device if `kernargs_ptr` is not provided.
Args:
bufs: Buffers to be written to kernel arguments.
vals: Values to be written to kernel arguments.
kernargs_ptr: Optional pointer to pre-allocated kernel arguments memory.
Returns:
Arguments state with the given buffers and values set for the program.
"""
return self.args_state_t(kernargs_ptr or self.device._alloc_kernargs(self.kernargs_alloc_size), self, bufs, vals=vals)
def __call__(self, *bufs:HCQBuffer, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1),
vals:Tuple[int, ...]=(), wait:bool=False) -> Optional[float]:
"""
Enqueues the program for execution with the given arguments and dimensions.
Args:
bufs: Buffer arguments to execute the kernel with.
global_size: Specifies the global work size for kernel execution (equivalent to CUDA's grid size).
local_size: Specifies the local work size for kernel execution (equivalent to CUDA's block size).
vals: Value arguments to execute the kernel with.
wait: If True, waits for the kernel to complete execution.
Returns:
Execution time of the kernel if 'wait' is True, otherwise None.
"""
kernargs = self.fill_kernargs(bufs, vals)
q = self.device.hw_compute_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1).memory_barrier()
with hcq_profile(self.device, queue=q, desc=self.name, enabled=wait or PROFILE) as (sig_st, sig_en):
q.exec(self, kernargs, global_size, local_size)
q.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.device.timeline_value += 1
if wait: self.device.timeline_signal.wait(self.device.timeline_value - 1)
return (float(sig_en.timestamp - sig_st.timestamp) / 1e6) if wait else None
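The timeline protocol used here — every submission waits on `timeline_value - 1`, signals `timeline_value`, then increments — can be simulated without hardware (hypothetical `ToyDevice`; device-side completion is collapsed into the submit call):

```python
class ToyDevice:
    def __init__(self): self.timeline_signal, self.timeline_value = 0, 1
    def submit_work(self):
        # queue would encode: wait(timeline_signal, timeline_value - 1) ... signal(timeline_signal, timeline_value)
        assert self.timeline_signal >= self.timeline_value - 1, "previous work not finished"
        self.timeline_signal = self.timeline_value  # simulated device-side completion
        self.timeline_value += 1

d = ToyDevice()
for _ in range(5): d.submit_work()
print(d.timeline_signal, d.timeline_value)  # 5 6
```

The invariant `timeline_signal == timeline_value - 1` meaning "all submitted work is done" is exactly what `synchronize` waits for.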
class ProfileLogger:
writers: int = 0
mjson: List[Dict] = []
actors: Dict[Union[str, Tuple[str, str]], int] = {}
def __init__(self): self.events, self.deps, ProfileLogger.writers = [], [], ProfileLogger.writers + 1
def add_event(self, ev_name, ev_start, ev_end, actor, subactor=None, args=None): self.events += [(ev_name, ev_start, ev_end, actor, subactor, args)]
def _ensure_actor(self, actor_name, subactor_name):
if actor_name not in self.actors:
self.actors[actor_name] = (pid:=len(self.actors))
self.mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor_name}})
if (subactor_key:=(actor_name,subactor_name)) not in self.actors:
self.actors[subactor_key] = (tid:=len(self.actors))
self.mjson.append({"name": "thread_name", "ph": "M", "pid": self.actors[actor_name], "tid":tid, "args": {"name": subactor_name}})
return self.actors[actor_name], self.actors.get(subactor_key, -1)
def __del__(self):
# perfetto json docs: https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
for name, st, et, actor_name, subactor_name, args in self.events:
pid, tid = self._ensure_actor(actor_name,subactor_name)
args = {k: (v if v.__class__ is str else v(et-st)) for k, v in args.items()} if args is not None else None
self.mjson.append({"name": name, "ph": "X", "pid": pid, "tid": tid, "ts": st, "dur": et-st, "args": args})
for en,st,dep_actor_name,dep_subactor_name,actor_name,subactor_name in self.deps:
dep_pid, dep_tid = self._ensure_actor(dep_actor_name,dep_subactor_name)
pid, tid = self._ensure_actor(actor_name,subactor_name)
self.mjson.append({"ph": "s", "pid": dep_pid, "tid": dep_tid, "id": len(self.mjson), "ts": en, "bp": "e"})
self.mjson.append({"ph": "f", "pid": pid, "tid": tid, "id": len(self.mjson)-1, "ts": st, "bp": "e"})
ProfileLogger.writers -= 1
if ProfileLogger.writers == 0 and len(self.mjson) > 0:
with open(PROFILEPATH.value, "w") as f: f.write(json.dumps({"traceEvents": self.mjson}))
print(f"Saved profile to {PROFILEPATH.value}. Use https://ui.perfetto.dev/ to open it.")
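For reference, a minimal emitter for the same Chrome/Perfetto trace-event shape that `ProfileLogger` targets can be sketched as follows (hypothetical `trace_json` helper; "M" metadata events name actors, "X" complete events carry `ts` and `dur` in microseconds):

```python
import json

def trace_json(events):
    # events: (name, start_us, end_us, actor) tuples
    mjson, actors = [], {}
    for name, st, en, actor in events:
        if actor not in actors:
            actors[actor] = (pid := len(actors))
            mjson.append({"name": "process_name", "ph": "M", "pid": pid, "args": {"name": actor}})
        mjson.append({"name": name, "ph": "X", "pid": actors[actor], "tid": 0, "ts": st, "dur": en - st})
    return json.dumps({"traceEvents": mjson})

out = trace_json([("kernel_0", 10.0, 25.5, "NV:0")])
```

The resulting JSON opens directly in https://ui.perfetto.dev/, which is how the real logger's output is meant to be viewed.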
class HCQCompiled(Compiled):
"""
A base class for devices compatible with the HCQ (Hardware Command Queue) API.
"""
devices: List[HCQCompiled] = []
gpu2cpu_copy_time_diff: decimal.Decimal = decimal.Decimal('nan')
gpu2cpu_compute_time_diff: decimal.Decimal = decimal.Decimal('nan')
def __init__(self, device:str, allocator:Allocator, renderer:Renderer, compiler:Compiler, runtime, signal_t:Type[HCQSignal],
comp_queue_t:Type[HWComputeQueue], copy_queue_t:Optional[Type[HWCopyQueue]], timeline_signals:Tuple[HCQSignal, HCQSignal]):
self.signal_t, self.hw_compute_queue_t, self.hw_copy_queue_t = signal_t, comp_queue_t, copy_queue_t
self.timeline_value:int = 1
self.timeline_signal, self._shadow_timeline_signal = timeline_signals
self.sig_prof_records:List[Tuple[HCQSignal, HCQSignal, str, bool]] = []
self.raw_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, str, bool, Optional[Dict]]] = []
self.dep_prof_records:List[Tuple[decimal.Decimal, decimal.Decimal, HCQCompiled, bool, decimal.Decimal, decimal.Decimal, HCQCompiled, bool]] = []
if PROFILE: self._prof_setup()
from tinygrad.runtime.graph.hcq import HCQGraph
super().__init__(device, allocator, renderer, compiler, runtime, HCQGraph)
self.kernargs_page:HCQBuffer = self.allocator.alloc(16 << 20, BufferOptions(cpu_access=True))
self.kernargs_ptr:int = self.kernargs_page.va_addr
self.devices.append(self)
def synchronize(self):
self.timeline_signal.wait(self.timeline_value - 1) if not hasattr(self, '_syncdev') else self._syncdev()
if self.timeline_value > (1 << 31): self._wrap_timeline_signal()
if PROFILE:
self.raw_prof_records += [(st.timestamp, en.timestamp, name, is_cp, None) for st, en, name, is_cp in self.sig_prof_records]
self.sig_prof_records = []
def _alloc_kernargs(self, alloc_size:int) -> int:
"""
Allocates space for arguments passed to the kernel.
"""
if self.kernargs_ptr >= (self.kernargs_page.va_addr + self.kernargs_page.size - alloc_size): self.kernargs_ptr = self.kernargs_page.va_addr
self.kernargs_ptr = (res:=self.kernargs_ptr) + alloc_size
return res
def _ensure_shared_time_base(self):
if not self.gpu2cpu_compute_time_diff.is_nan(): return
def _sync_cpu_queue(d, q_t):
q_t().timestamp(d.timeline_signal).signal(d.timeline_signal, d.timeline_value).submit(d)
d.timeline_value += 1
st = time.perf_counter_ns()
d.timeline_signal.wait(d.timeline_value - 1)
et = time.perf_counter_ns()
return (decimal.Decimal(et+st) / 2000) - d.timeline_signal.timestamp # average of st and et, in us
# randomly sample the timing from GPU to CPU
choices: List = [(d, d.hw_compute_queue_t, []) for d in self.devices]
choices += [(d, d.hw_copy_queue_t, []) for d in self.devices if d.hw_copy_queue_t is not None]
for _ in range(100*len(self.devices)):
d,q,l = random.choice(choices)
l.append(_sync_cpu_queue(d,q))
for d,q,l in choices:
if q == d.hw_compute_queue_t: d.gpu2cpu_compute_time_diff = statistics.median(l)
if q == d.hw_copy_queue_t: d.gpu2cpu_copy_time_diff = statistics.median(l)
def _sync_gpu_to_gpu_queue(d1, d2, q1_t, q2_t):
q1_t().signal(d1.timeline_signal, d1.timeline_value).wait(d2.timeline_signal, d2.timeline_value) \
.timestamp(d1.timeline_signal).signal(d1.timeline_signal, d1.timeline_value+1).submit(d1)
q2_t().signal(d2.timeline_signal, d2.timeline_value).wait(d1.timeline_signal, d1.timeline_value) \
.timestamp(d2.timeline_signal).signal(d2.timeline_signal, d2.timeline_value+1).submit(d2)
d1.timeline_value += 2
d2.timeline_value += 2
d1.timeline_signal.wait(d1.timeline_value - 1)
d2.timeline_signal.wait(d2.timeline_value - 1)
return d2.timeline_signal.timestamp - d1.timeline_signal.timestamp
# then sanity-check it by timing GPU-to-GPU round trips
jitter_matrix = [[float('nan')]*len(self.devices) for _ in range(len(self.devices))]
for i1, d1 in enumerate(self.devices):
for i2, d2 in enumerate(self.devices):
if d1 == d2: continue
d1_to_d2 = statistics.median(_sync_gpu_to_gpu_queue(d1, d2, d1.hw_compute_queue_t, d2.hw_compute_queue_t) - \
_sync_gpu_to_gpu_queue(d2, d1, d2.hw_compute_queue_t, d1.hw_compute_queue_t) for _ in range(20)) / 2
jitter_matrix[i1][i2] = d1_to_d2 - (d1.gpu2cpu_compute_time_diff - d2.gpu2cpu_compute_time_diff)
print("pairwise clock jitter matrix (us):\n" + '\n'.join([''.join([f'{float(item):8.3f}' for item in row]) for row in jitter_matrix]))
def _gpu2cpu_time(self, gpu_time:decimal.Decimal, is_copy:bool) -> float:
"""
Translates local gpu time (timestamp) into global cpu time.
"""
self._ensure_shared_time_base()
return float(gpu_time + (self.gpu2cpu_copy_time_diff if is_copy else self.gpu2cpu_compute_time_diff))
def _prof_setup(self):
if hasattr(self, 'profile_logger'): return
atexit.register(self._prof_finalize)
self.profile_logger = ProfileLogger()
def _prof_finalize(self):
qname = ["COMPUTE", "DMA"]
# Sync to be sure all events on the device are recorded.
self.synchronize()
for st, en, name, is_cp, args in self.raw_prof_records:
self.profile_logger.events += [(name, self._gpu2cpu_time(st, is_cp), self._gpu2cpu_time(en, is_cp), self.dname, qname[is_cp], args)]
for a_st, a_en, a_dev, a_is_copy, b_st, b_en, b_dev, b_is_copy in self.dep_prof_records:
# Perfetto attaches flow arrows by timestamp; using each event's midpoint guarantees the arrow lands inside a valid slice.
a_tm, b_tm = a_dev._gpu2cpu_time((a_st+a_en)/decimal.Decimal(2), a_is_copy), b_dev._gpu2cpu_time((b_st+b_en)/decimal.Decimal(2), b_is_copy)
self.profile_logger.deps += [(a_tm, b_tm, a_dev.dname, qname[a_is_copy], b_dev.dname, qname[b_is_copy])]
self.raw_prof_records, self.dep_prof_records = [], []
# Remove the logger, this flushes all data written by the device.
del self.profile_logger
def _wrap_timeline_signal(self):
self.timeline_signal, self._shadow_timeline_signal, self.timeline_value = self._shadow_timeline_signal, self.timeline_signal, 1
self.timeline_signal.value = 0
cast(HCQAllocator, self.allocator).b_timeline = [0] * len(cast(HCQAllocator, self.allocator).b)
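The `_alloc_kernargs` bump pointer above can be sketched standalone (hypothetical `KernargsArena`; the wraparound is only safe under the same assumption the real allocator makes, namely that kernels using the start of the arena have retired before it is reused):

```python
class KernargsArena:
    def __init__(self, base: int, size: int):
        self.base, self.size, self.ptr = base, size, base
    def alloc(self, n: int) -> int:
        # wrap to the base if the next allocation would run past the arena
        if self.ptr >= self.base + self.size - n: self.ptr = self.base
        res, self.ptr = self.ptr, self.ptr + n
        return res

a = KernargsArena(0x1000, 96)
print(a.alloc(32), a.alloc(32), a.alloc(32))  # 4096 4128 4096
```

Like the real method, this trades bookkeeping for speed: there is no free list, just a pointer that marches forward and wraps.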
# Protocol for HCQ-compatible allocators: an allocated buffer must expose its VA address and size.
class HCQBuffer(Protocol): va_addr:int; size:int # noqa: E702
class HCQAllocator(LRUAllocator): # pylint: disable=abstract-method
"""
A base allocator class compatible with the HCQ (Hardware Command Queue) API.
This class implements basic copy operations following the HCQ API, utilizing both `HWComputeQueue` and `HWCopyQueue`.
"""
def __init__(self, device:HCQCompiled, batch_size:int=(2 << 20), batch_cnt:int=32):
self.device:Any = device
self.b = [self._alloc(batch_size, BufferOptions(host=True)) for _ in range(batch_cnt)]
self.b_timeline, self.b_next = [0] * len(self.b), 0
super().__init__()
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer: raise NotImplementedError("need hcq compat alloc")
def copyin(self, dest:HCQBuffer, src:memoryview):
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"CPU -> {self.device.dname}", enabled=PROFILE):
for i in range(0, src.nbytes, self.b[0].size):
self.b_next = (self.b_next + 1) % len(self.b)
self.device.timeline_signal.wait(self.b_timeline[self.b_next])
ctypes.memmove(self.b[self.b_next].va_addr, from_mv(src[i:]), lsize:=min(self.b[self.b_next].size, src.nbytes-i))
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(dest.va_addr+i, self.b[self.b_next].va_addr, lsize) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.b_timeline[self.b_next] = self.device.timeline_value
self.device.timeline_value += 1
def copy_from_disk(self, dest:HCQBuffer, src, size):
def _get_temp_buf():
# Check if the next buffer is safe to be used (its signal has passed) and reserve it.
if self.b_timeline[(self.b_next + 1) % len(self.b)] <= self.device.timeline_signal.value:
self.b_timeline[(self.b_next + 1) % len(self.b)], self.b_next = (1 << 64), (self.b_next + 1) % len(self.b)
return (self.b[self.b_next].va_addr, self.b_next)
return None
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"DISK -> {self.device.dname}", enabled=PROFILE):
for (batch_info, dst_off, src_off, copy_size) in src.device.allocator._copyout_sharded(src, size, _get_temp_buf, seg_len=self.b[0].size):
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(dest.va_addr + dst_off, batch_info[0] + src_off, copy_size) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.b_timeline[batch_info[1]] = self.device.timeline_value
self.device.timeline_value += 1
def copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()
with hcq_profile(self.device, queue_type=self.device.hw_copy_queue_t, desc=f"{self.device.dname} -> CPU", enabled=PROFILE):
for i in range(0, dest.nbytes, self.b[0].size):
self.device.hw_copy_queue_t().wait(self.device.timeline_signal, self.device.timeline_value - 1) \
.copy(self.b[0].va_addr, src.va_addr+i, lsize:=min(self.b[0].size, dest.nbytes-i)) \
.signal(self.device.timeline_signal, self.device.timeline_value).submit(self.device)
self.device.timeline_signal.wait(self.device.timeline_value)
self.device.timeline_value += 1
ctypes.memmove(from_mv(dest[i:]), self.b[0].va_addr, lsize)
def transfer(self, dest:HCQBuffer, src:HCQBuffer, sz:int, src_dev, dest_dev):
src_dev.allocator.map(dest)
with hcq_profile(src_dev, queue_type=src_dev.hw_copy_queue_t, desc=f"{src_dev.dname} -> {dest_dev.dname}", enabled=PROFILE):
src_dev.hw_copy_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.copy(dest.va_addr, src.va_addr, sz) \
.signal(src_dev.timeline_signal, src_dev.timeline_value).submit(src_dev)
src_dev.timeline_value += 1
if src_dev != dest_dev:
dest_dev.hw_compute_queue_t().wait(src_dev.timeline_signal, src_dev.timeline_value - 1) \
.wait(dest_dev.timeline_signal, dest_dev.timeline_value - 1) \
.signal(dest_dev.timeline_signal, dest_dev.timeline_value).submit(dest_dev)
dest_dev.timeline_value += 1
def map(self, buf:HCQBuffer): pass
def offset(self, buf, size:int, offset:int) -> HCQBuffer:
return type(buf)(va_addr=buf.va_addr + offset, size=size, **{k:v for k,v in buf.__dict__.items() if k not in ['va_addr', 'size']},
**{x[0]:getattr(buf, x[0]) for x in getattr(buf, '_fields_', []) if x[0] not in ['va_addr', 'size']}, _base=buf)
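The `copyin`/`copyout` paths above stream data through a ring of small host-visible staging buffers. A pure-CPU sketch of just the chunking (hypothetical `staged_copy`; the signal fencing and the staging-buffer ring are elided):

```python
def staged_copy(dest: bytearray, src: bytes, bounce_size: int = 4) -> bytearray:
    bounce = bytearray(bounce_size)                   # one host-visible staging buffer
    for i in range(0, len(src), bounce_size):
        chunk = src[i:i + bounce_size]
        bounce[:len(chunk)] = chunk                   # CPU -> staging (ctypes.memmove in the real path)
        dest[i:i + len(chunk)] = bounce[:len(chunk)]  # staging -> device copy, fenced by signals in the real path
    return dest

out = staged_copy(bytearray(10), b"hello worl")
```

In the real allocator, multiple staging buffers (`batch_cnt` of them) let the CPU fill buffer N+1 while the DMA engine drains buffer N, with `b_timeline` tracking when each is safe to reuse.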