mirror of https://github.com/commaai/tinygrad.git
CommandQueue is the future (#3950)
* start of command queue * cq work * runs * cleanup * outs set * read is gone * future buffer work * command queue is better * command queue works * loadops * delete unneeded * command queue works * upd * fix tests * use CommandQueue in compile * delay sync
This commit is contained in:
parent
0a34d6016b
commit
7425a0c646
|
@ -15,7 +15,8 @@ from extra.onnx import get_run_onnx
|
|||
from tinygrad import Tensor, Device, GlobalCounters, dtypes
|
||||
from tinygrad.dtype import ImageDType
|
||||
from tinygrad.helpers import partition, Context, fetch, getenv, DEBUG
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule_item
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.commandqueue import CommandQueue
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.ops import LoadOps, ScheduleItem
|
||||
Device.DEFAULT = "GPU"
|
||||
|
@ -88,9 +89,10 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
|
|||
|
||||
# run code (all buffers have been allocated)
|
||||
GlobalCounters.reset()
|
||||
for si in schedule: lower_schedule_item(si)(si.outputs+si.inputs, {})
|
||||
output = schedule[-1].outputs[0]
|
||||
CommandQueue(schedule)()
|
||||
|
||||
new_tinygrad_out = np.frombuffer(schedule[-1].outputs[0].as_buffer(), dtype=schedule[-1].outputs[0].dtype.np)
|
||||
new_tinygrad_out = np.frombuffer(output.as_buffer(), dtype=output.dtype.np)
|
||||
np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
print("semi-thneed self-test passed!")
|
||||
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad import Tensor, dtypes, Device
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule_item
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
|
||||
class TestFusionOp(unittest.TestCase):
|
||||
def test_contiguous_add(self):
|
||||
|
@ -27,7 +27,7 @@ class TestFusionOp(unittest.TestCase):
|
|||
a = Tensor([1,2,3,4])
|
||||
for _ in range(24): a = a + a
|
||||
sched = create_schedule([a.lazydata], None)
|
||||
ji = lower_schedule_item(sched[-1])
|
||||
ji = Device[Device.DEFAULT].get_runner(*sched[-1].ast)
|
||||
self.assertLess(time.perf_counter()-st, 1.0)
|
||||
assert len(ji.prg.splitlines()) < 250
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, Device
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
|
||||
# TODO: can copy this in here when we remove it
|
||||
#from tinygrad.ops import get_lazyop_info
|
||||
|
@ -13,7 +12,7 @@ from tinygrad.engine.realize import lower_schedule_item
|
|||
|
||||
def get_stats(x:Tensor):
|
||||
si = create_schedule([x.lazydata])[-1]
|
||||
runner = lower_schedule_item(si)
|
||||
runner = Device[Device.DEFAULT].get_runner(*si.ast)
|
||||
return runner.op_estimate, runner.mem_estimate
|
||||
|
||||
class TestUOpsStats(unittest.TestCase):
|
||||
|
|
|
@ -0,0 +1,113 @@
|
|||
# NOTE: this will replace jit.py, realize.py, and a lot of the boilerplate in each graph executor
|
||||
from __future__ import annotations
|
||||
from typing import List, Dict, Union, DefaultDict
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from tinygrad.helpers import colored, cpu_time_execution, DEBUG
|
||||
from tinygrad.ops import ScheduleItem, LoadOps, BufferOps
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.device import Buffer, JITRunner, Device, BufferXfer, BufferCopy, update_stats
|
||||
|
||||
class CustomOp(JITRunner):
|
||||
def __init__(self, fxn):
|
||||
self.fxn = fxn
|
||||
super().__init__()
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
|
||||
|
||||
# NOTE: two syncitems aren't the same if they are in different places in the queue
|
||||
@dataclass(eq=False)
|
||||
class SyncItem:
|
||||
device: str
|
||||
waiters: int = 0
|
||||
def __repr__(self): return f"SyncItem({self.device}, waiters={self.waiters}, {id(self)})"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WaitItem:
|
||||
sync: SyncItem
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CopyItem:
|
||||
output: Buffer
|
||||
input: Buffer
|
||||
|
||||
# this will interface with HWCommandQueue to replace Graph
|
||||
class CommandQueue:
|
||||
def __init__(self, schedule:List[ScheduleItem]):
|
||||
self.q: DefaultDict[str, List[Union[ScheduleItem, CopyItem, SyncItem, WaitItem]]] = defaultdict(list)
|
||||
|
||||
def add_sync_item(device:str):
|
||||
if not len(self.q[device]) or not isinstance(sync_item:=self.q[device][-1], SyncItem):
|
||||
sync_item = SyncItem(device)
|
||||
self.q[device].append(sync_item)
|
||||
return sync_item
|
||||
|
||||
def add_wait_item(device:str, syncitem:SyncItem):
|
||||
# if you are adding this right after a first sync, delete this one
|
||||
if len(self.q[device]) and isinstance(wi:=self.q[device][-1], WaitItem) and wi.sync.device == syncitem.device:
|
||||
self.q[device] = self.q[device][:-1]
|
||||
wi.sync.waiters -= 1
|
||||
if wi.sync.waiters == 0: self.q[wi.sync.device].remove(wi.sync)
|
||||
if (wi:=WaitItem(syncitem)) not in self.q[device]:
|
||||
syncitem.waiters += 1
|
||||
self.q[device].append(wi)
|
||||
|
||||
while len(schedule):
|
||||
si = schedule.pop(0)
|
||||
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or (si.ast[0].op is LoadOps.COPY and len(si.outputs) == 1)
|
||||
queue = self.q[si.outputs[0].device]
|
||||
|
||||
if si.ast[0].op is LoadOps.COPY:
|
||||
# TODO: add back copy device
|
||||
copy_device = si.outputs[0].device #+"-copy"
|
||||
add_wait_item(copy_device, add_sync_item(si.inputs[0].device))
|
||||
self.q[copy_device].append(CopyItem(si.outputs[0], si.inputs[0]))
|
||||
#add_wait_item(si.outputs[0].device, add_sync_item(copy_device))
|
||||
continue
|
||||
|
||||
# NOTE: LoadOps.EMPTY and LoadOps.CUSTOM are making it here
|
||||
queue.append(si)
|
||||
|
||||
def __call__(self):
|
||||
active_queues = list(self.q.keys())
|
||||
waiting_queues: DefaultDict[SyncItem, List[str]] = defaultdict(list)
|
||||
seen_sids = set()
|
||||
while len(active_queues):
|
||||
device = active_queues.pop(0)
|
||||
if not len(self.q[device]): continue
|
||||
si = self.q[device].pop(0)
|
||||
#print(device, si, active_queues, seen_sids)
|
||||
if isinstance(si, SyncItem):
|
||||
# don't sync if there's other options
|
||||
if all(isinstance(self.q[x][0], SyncItem) for x in active_queues if len(self.q[x])):
|
||||
et = cpu_time_execution(Device[device].synchronize, enable=DEBUG>=2)
|
||||
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=device)
|
||||
if si in waiting_queues:
|
||||
active_queues += waiting_queues[si]
|
||||
waiting_queues[si].clear()
|
||||
seen_sids.add(si)
|
||||
else:
|
||||
# put it back
|
||||
self.q[device] = [si] + self.q[device]
|
||||
elif isinstance(si, WaitItem):
|
||||
if si.sync not in seen_sids:
|
||||
waiting_queues[si.sync].append(device)
|
||||
continue
|
||||
elif isinstance(si, CopyItem):
|
||||
si.output.allocate()
|
||||
fxn = BufferXfer() if hasattr(Device[si.output.device].allocator, 'transfer') and \
|
||||
si.output.device.split(":")[0] == si.input.device.split(":")[0] else BufferCopy()
|
||||
fxn.exec([si.output, si.input])
|
||||
elif isinstance(si, ScheduleItem):
|
||||
for out in si.outputs:
|
||||
if not hasattr(out, "_buf") and not (out.device.startswith("DISK") and si.ast[0].op is BufferOps.STORE): out.allocate()
|
||||
if si.ast[0].op is not LoadOps.EMPTY:
|
||||
if si.ast[0].op is LoadOps.CUSTOM:
|
||||
runner:JITRunner = CustomOp(si.ast[0].arg)
|
||||
elif si.ast[0].op is BufferOps.STORE:
|
||||
runner = Device[si.outputs[0].device].get_runner(*si.ast)
|
||||
else: raise RuntimeError(f"unknown type {si}")
|
||||
runner.exec(list(si.outputs+si.inputs), si.var_vals)
|
||||
else:
|
||||
update_stats(colored(f"empty {si.outputs[0].size:10d} {si.outputs[0].dtype}", "yellow"), 0, 0, {}, None, 1, device=si.outputs[0].device)
|
||||
else: raise RuntimeError(f"unknown type {si}")
|
||||
active_queues.append(device)
|
|
@ -1,51 +1,5 @@
|
|||
from typing import List, Dict, Optional
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, JITRunner, update_stats
|
||||
from tinygrad.helpers import colored, getenv, cpu_time_execution, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from typing import List
|
||||
from tinygrad.ops import ScheduleItem
|
||||
from tinygrad.engine.commandqueue import CommandQueue
|
||||
|
||||
class CustomOp(JITRunner):
|
||||
def __init__(self, fxn):
|
||||
self.fxn = fxn
|
||||
super().__init__()
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False): self.fxn(*rawbufs)
|
||||
|
||||
class SyncOp(JITRunner):
|
||||
def __init__(self, device):
|
||||
self.device, self.dname = Device[device], device
|
||||
super().__init__()
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False, jit=False):
|
||||
et = cpu_time_execution(self.device.synchronize, enable=wait or DEBUG >= 1)
|
||||
update_stats(colored("synchronize", "RED"), 0, 0, {}, et, 1, device=self.dname)
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> Optional[JITRunner]:
|
||||
assert len(set(x.device for x in si.outputs+si.inputs)) == 1 or si.ast[0].op is LoadOps.COPY
|
||||
if si.ast[0].op is BufferOps.STORE: return Device[si.outputs[0].device].get_runner(*si.ast)
|
||||
assert len(si.ast) == 1 and len(si.outputs) == 1, "only ASTRunner supports multioutput"
|
||||
out, ast = si.outputs[0], si.ast[0]
|
||||
if ast.op is LoadOps.COPY:
|
||||
if hasattr(Device[out.device].allocator, 'transfer') and out.device.split(":")[0] == si.inputs[0].device.split(":")[0]: return BufferXfer()
|
||||
return BufferCopy()
|
||||
if ast.op is LoadOps.CUSTOM: return CustomOp(ast.arg)
|
||||
if ast.op is LoadOps.SYNC: return SyncOp(out.device)
|
||||
return None
|
||||
|
||||
logops = open(getenv("LOGOPS", ""), "a") if getenv("LOGOPS", "") else None
|
||||
def run_schedule(schedule:List[ScheduleItem]):
|
||||
while len(schedule):
|
||||
si = schedule.pop(0)
|
||||
if logops and si.ast[0].op not in LoadOps and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
|
||||
|
||||
# get the program
|
||||
prg = lower_schedule_item(si)
|
||||
dont_allocate = getattr(prg, "skip_allocation", False)
|
||||
|
||||
for out in si.outputs:
|
||||
# we don't have an output buffer, we have to create it, and create to max size if it has symbolic shape
|
||||
if out.size > 0 and not dont_allocate and not hasattr(out, "_buf"): out.allocate()
|
||||
|
||||
# run the function (put it in JIT)
|
||||
real_buffers = [x for x in si.outputs+si.inputs if x.size != 0]
|
||||
assert dont_allocate or all(hasattr(x, "_buf") for x in real_buffers), f"can't run, some inputs aren't realized {real_buffers}"
|
||||
if prg: prg.exec(real_buffers, si.var_vals)
|
||||
elif (out:=si.outputs[0]).size > 0: update_stats(colored(f"empty {out.size:10d} {out.dtype}", "yellow"), 0, 0, {}, None, 1, device=out.device)
|
||||
def run_schedule(schedule:List[ScheduleItem]): CommandQueue(schedule)()
|
||||
|
|
|
@ -75,7 +75,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Va
|
|||
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
|
||||
inputs: List[LazyBuffer] = []
|
||||
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
|
||||
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
|
||||
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
|
||||
else:
|
||||
output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
|
||||
|
@ -218,7 +218,8 @@ def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None)
|
|||
if GRAPH:
|
||||
kernel_number += 1
|
||||
for out in ps.outputs: realized_lazybuffer(out, kernel_number)
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs), tuple(x.buffer for x in ps.inputs), ps.var_vals))
|
||||
schedule.append(ScheduleItem(ps.ast, tuple(x.buffer for x in ps.outputs if x.size != 0),
|
||||
tuple(x.buffer for x in ps.inputs if x.size != 0), ps.var_vals))
|
||||
for x in graph[buf]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(x)
|
||||
|
|
|
@ -12,7 +12,7 @@ from weakref import ref, ReferenceType, WeakValueDictionary
|
|||
lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary()
|
||||
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
|
||||
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
|
||||
if st.size == 0 and op is not LoadOps.SYNC: op, arg, srcs, base = LoadOps.CONST, 0, (), None
|
||||
if st.size == 0: op, arg, srcs, base = LoadOps.CONST, 0, (), None
|
||||
if op is LoadOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype), True
|
||||
|
||||
cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base))
|
||||
|
@ -94,12 +94,7 @@ class LazyBuffer:
|
|||
def is_unrealized_unpadded_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
|
||||
|
||||
def _copy(self, device:str) -> LazyBuffer:
|
||||
if (dstart:=self.device.split(":")[0]) in {"EXT", "DISK"} or (dstart in {"HSA", "CUDA"} and device.split(":")[0] == dstart):
|
||||
# DISK/EXT don't sync
|
||||
# copies in HSA/CUDA to other HSA/CUDA don't sync either
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
|
||||
sync = LazyBuffer.loadop(LoadOps.SYNC, (0,), dtypes.uint32, self.device, src=(self,), enable_cache=True)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self, sync), enable_cache=False)
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, None, (self,), enable_cache=False)
|
||||
|
||||
def copy_to_device(self, device:str, force: bool = False) -> LazyBuffer:
|
||||
# no COPY
|
||||
|
|
|
@ -19,7 +19,7 @@ class BinaryOps(Enum):
|
|||
class TernaryOps(Enum): WHERE = auto() # noqa: E702
|
||||
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
|
||||
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); SYNC = auto(); ASSIGN = auto() # noqa: E702
|
||||
class LoadOps(Enum): EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto() # noqa: E702
|
||||
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, LoadOps, TernaryOps, BufferOps]
|
||||
OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[LoadOps], Type[TernaryOps], Type[BufferOps]]
|
||||
|
|
|
@ -45,7 +45,6 @@ class DiskAllocator(Allocator):
|
|||
dest[:] = src._buf()
|
||||
|
||||
class DiskRunner(JITRunner):
|
||||
skip_allocation = True
|
||||
def __init__(self, ast:LazyOp):
|
||||
# two ASTs are allowed here.
|
||||
assert ast.op is BufferOps.STORE, "output of AST must be store"
|
||||
|
|
Loading…
Reference in New Issue