mirror of https://github.com/commaai/tinygrad.git

move BatchExecutor (#2297)

* move BatchExecutor
* refactor to get_optimized_program
* that changed

This commit is contained in:
parent 0cbf6c1811
commit 8916028ddd
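
At a glance, this commit moves JitItem and BatchExecutor out of tinygrad/jit.py into tinygrad/ops.py and folds the get_program closure in Compiled.exec_ast into a Compiled.get_optimized_program method. A hedged sketch of what that means for code that imported these classes (the old path is the one removed from the Metal backend further down):

# before this commit (per the line removed in tinygrad/runtime/ops_metal.py below):
# from tinygrad.jit import BatchExecutor, JitItem
# after this commit:
from tinygrad.ops import BatchExecutor, JitItem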

tinygrad/jit.py

@@ -1,56 +1,15 @@
 from __future__ import annotations
 from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
 import functools, itertools
-from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
-from tinygrad.ops import RawBuffer, Device, ASTRunner
+from tinygrad.helpers import DEBUG, DType, merge_dicts
+from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
 from tinygrad.tensor import Tensor
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
-from dataclasses import dataclass
+from tinygrad.shape.symbolic import Variable
 from weakref import ref, WeakKeyDictionary
 
 JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
 
-@dataclass(frozen=True)
-class JitItem:
-  prg: ASTRunner
-  rawbufs: List[Optional[RawBuffer]]
-
-class BatchExecutor:
-  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
-    self.jit_cache: List[JitItem] = jit_cache
-    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
-    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
-    for j,ji in enumerate(jit_cache):
-      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
-        self.op_estimate += ji.prg.op_estimate
-        self.mem_estimate += ji.prg.mem_estimate
-      for i,a in enumerate(ji.rawbufs):
-        if a in [v for v in input_rawbuffers.values()]:
-          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
-    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
-    self.clear_jit_inputs()
-
-  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
-    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
-    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
-    self.clear_jit_inputs()
-
-  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
-    # TODO: this is mostly copied from ASTRunner
-    op_estimate = sym_infer(self.op_estimate, var_vals)
-    mem_estimate = sym_infer(self.mem_estimate, var_vals)
-    if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
-            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
-    GlobalCounters.kernel_count += len(self.jit_cache)
-    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
-    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
-    if et is not None: GlobalCounters.time_sum_s += et
-
-  def clear_jit_inputs(self):
-    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
-
 class TinyJit:
   def __init__(self, fxn:Callable):
     self.fxn: Callable = fxn

@@ -98,8 +57,7 @@ class TinyJit:
       assert len(jit_cache) != 0, "didn't JIT anything!"
       if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
 
-      alt_batch_exec = Device[Device.DEFAULT].batch_executor
-      self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
+      self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
     elif self.cnt == 0:
       self.ret = self.fxn(*args, **kwargs)
 
@@ -129,7 +87,7 @@ class _CacheCollector:
   def add(self, prg, rawbufs, var_vals):
     if self.cache is None: return
     for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
-    self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
+    self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
     self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
 
   def finish(self) -> List[JitItem]:
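
With the executor classes gone from jit.py, TinyJit no longer decides which executor to use; it just instantiates whatever the device advertises. A minimal sketch of the capture step as it reads after the hunk above (jit_cache, input_rawbuffers and var_vals are the values TinyJit has already collected at that point):

# simplified from the TinyJit hunk above: the device always supplies a batch executor now
self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
# subsequent jitted calls replay the captured kernels with fresh input buffers
self.jit_fxn(input_rawbuffers, var_vals)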

tinygrad/ops.py (124 changed lines)

@@ -1,10 +1,10 @@
 from __future__ import annotations
 import importlib, inspect, functools, pathlib, re
 from enum import Enum, auto
-from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
+from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
 from tinygrad.runtime.lib import RawBuffer
-from tinygrad.shape.symbolic import Variable, sym_infer
+from tinygrad.shape.symbolic import Variable, sym_infer, NumNode
 from dataclasses import dataclass
 
 # these are the llops your accelerator must implement, along with toCpu

@@ -106,13 +106,55 @@ class _Device:
     return "CPU"
 Device = _Device()
 
+# **************** batch executor ****************
+
+@dataclass(frozen=True)
+class JitItem:
+  prg: ASTRunner
+  rawbufs: List[Optional[RawBuffer]]
+
+class BatchExecutor:
+  def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
+    self.jit_cache: List[JitItem] = jit_cache
+    self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
+    self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
+    for j,ji in enumerate(jit_cache):
+      if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
+        self.op_estimate += ji.prg.op_estimate
+        self.mem_estimate += ji.prg.mem_estimate
+      for i,a in enumerate(ji.rawbufs):
+        if a in [v for v in input_rawbuffers.values()]:
+          self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
+    assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
+    self.clear_jit_inputs()
+
+  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
+    for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
+    for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
+    self.clear_jit_inputs()
+
+  def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
+    # TODO: this is mostly copied from ASTRunner
+    op_estimate = sym_infer(self.op_estimate, var_vals)
+    mem_estimate = sym_infer(self.mem_estimate, var_vals)
+    if DEBUG >= 2:
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+            (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
+    GlobalCounters.kernel_count += len(self.jit_cache)
+    GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
+    GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
+    if et is not None: GlobalCounters.time_sum_s += et
+
+  def clear_jit_inputs(self):
+    for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
+
 # **************** for Interpreted Buffers ****************
 
 class Interpreted:
   def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
     self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
     self.synchronize = lambda: None
-    self.batch_executor = None
+    self.batch_executor = BatchExecutor
     self.codegen = None
     self.method_cache: Dict[LazyOp, Callable] = {}
 
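
For reference, a short, hypothetical driver showing how the relocated BatchExecutor is meant to be used; captured_jit_cache, inputs and var_vals are placeholder names, not identifiers from this diff:

# hypothetical usage sketch; captured_jit_cache is a List[JitItem] recorded during tracing
executor = BatchExecutor(captured_jit_cache, inputs, var_vals)
executor(inputs, var_vals)                # patch the input buffers into the cached kernels and run them
executor.update_stats(var_vals, et=None)  # optionally bump GlobalCounters / print the DEBUG>=2 line
# input slots are cleared again after every call (clear_jit_inputs), so stale buffers are never kept alive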

@@ -233,17 +275,50 @@ class ASTRunner:
     return et
 
 class Compiled:
-  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
-    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
+  def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
+    self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
+    self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
     self.method_cache: Dict[LazyOp, ASTRunner] = {}
 
-  def to_program(self, k):
+  def to_program(self, k) -> ASTRunner:
     k.linearize()
     src, runtime_args = self.renderer(k.function_name, k.uops)
     return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                      op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
                      display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
 
+  def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
+    if DEBUG >= 3:
+      from tinygrad.graph import print_tree
+      print_tree(ast)
+    from tinygrad.codegen.linearizer import Linearizer
+    from tinygrad.lazy import vars_from_ast
+    k = Linearizer(ast, self.linearizer_opts)
+    assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
+    if not NOOPT:
+      if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
+      if BEAM >= 1 and not vars_from_ast(ast):
+        lins = [(("tc" if used_tensor_cores else "hc"), k)]
+        # allocate a scratch buffer if output buffer is also input
+        test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
+        kb = Linearizer(ast, self.linearizer_opts)
+        kb.required_optimizations()
+        from tinygrad.features.search import beam_search, time_linearizer
+        lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
+        if used_tensor_cores:
+          lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
+          lins[-1][1].hand_coded_optimizations()
+        timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
+        if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
+        k = timed[0][1]
+    else:
+      k.required_optimizations()
+    prg = self.to_program(k)
+    # extract real vars used in ast
+    prg.vars = vars_from_ast(ast)
+    assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
+    return prg
+
   def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # check if we can reuse the output buffer
     # if it's aliased, don't use it
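
Because the optimization pipeline is now a method rather than a closure over exec_ast's locals, it can be exercised on its own. A hedged sketch, assuming the default device is a Compiled backend and that ast and rawbuffers have already been prepared the way exec_ast prepares them (output buffer first):

# hypothetical ahead-of-time use of the new method; ast is a LazyOp, rawbuffers[0] is the output buffer
backend = Device[Device.DEFAULT]                      # assumed to be a Compiled instance here
prg = backend.get_optimized_program(ast, rawbuffers)  # runs the NOOPT/BEAM/tensor-core selection, returns an ASTRunner
prg(rawbuffers, {})                                   # ASTRunner is callable with (rawbufs, var_vals), as in BatchExecutor.__call__ above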

@@ -264,44 +339,11 @@ class Compiled:
     # all the rawbuffers
     rawbuffers = [output.realized] + [x.realized for x in inputs]
 
-    # compilation time
-    def get_program():
-      if DEBUG >= 3:
-        from tinygrad.graph import print_tree
-        print_tree(ast)
-      from tinygrad.codegen.linearizer import Linearizer
-      from tinygrad.lazy import vars_from_ast
-      k = Linearizer(ast, self.linearizer_opts)
-      assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
-      if not NOOPT:
-        if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
-        if BEAM >= 1 and not vars_from_ast(ast):
-          lins = [(("tc" if used_tensor_cores else "hc"), k)]
-          # allocate a scratch buffer if output buffer is also input
-          test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
-          kb = Linearizer(ast, self.linearizer_opts)
-          kb.required_optimizations()
-          from tinygrad.features.search import beam_search, time_linearizer
-          lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
-          if used_tensor_cores:
-            lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
-            lins[-1][1].hand_coded_optimizations()
-          timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
-          if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
-          k = timed[0][1]
-      else:
-        k.required_optimizations()
-      prg = self.to_program(k)
-      # extract real vars used in ast
-      prg.vars = vars_from_ast(ast)
-      assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
-      return prg
-
     if getenv("ENABLE_METHOD_CACHE", 1):
-      if ast not in self.method_cache: self.method_cache[ast] = get_program()
+      if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
       prg = self.method_cache[ast]
     else:
-      prg = get_program()
+      prg = self.get_optimized_program(ast, rawbuffers)
 
     if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
 
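
Note that the JIT=2 escape hatch did not disappear; it moved from TinyJit into the Compiled constructor, so Device[Device.DEFAULT].batch_executor already resolves to the right class by the time the JIT asks for it. A sketch of the resolution order after this commit (paraphrased from the hunks above, not verbatim source):

# 1. backend registration names its executor:  batch_executor=MetalBatchExecutor (default: BatchExecutor)
# 2. Compiled.__init__ applies the override:   BatchExecutor if getenv("JIT") == 2 else batch_executor
# 3. TinyJit just asks the device:             Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)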

tinygrad/runtime/ops_metal.py

@@ -4,9 +4,10 @@ import Metal, Cocoa, libdispatch
 from typing import List, Any, Tuple, Dict, Union, Set
 from tinygrad.codegen.kernel import LinearizerOptions
 from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
-from tinygrad.ops import Compiled
+from tinygrad.ops import Compiled, BatchExecutor, JitItem
 from tinygrad.renderer.metal import MetalRenderer
 from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
+from tinygrad.shape.symbolic import Variable, Node
 
 class MetalAllocator(LRUAllocator):
   def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)

@@ -80,8 +81,6 @@ class MetalProgram:
       return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
     METAL.mtl_buffers_in_flight.append(command_buffer)
 
-from tinygrad.jit import BatchExecutor, JitItem
-from tinygrad.shape.symbolic import Variable, Node
 class MetalBatchExecutor(BatchExecutor):
   def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
     super().__init__(jit_cache, input_rawbuffers, var_vals)

@@ -151,4 +150,4 @@ class MetalBatchExecutor(BatchExecutor):
     super().update_stats(var_vals, et)
     return et
 
-MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
+MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else BatchExecutor)
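
With the default now being BatchExecutor instead of None, a backend with nothing special to do simply omits the argument, while Metal keeps its specialized executor outside CI. A hedged sketch of what opting in looks like for a hypothetical backend; every name here except Compiled, BatchExecutor and LinearizerOptions is made up for illustration:

from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.ops import Compiled, BatchExecutor

class MyBatchExecutor(BatchExecutor):
  pass  # hypothetical: would batch the captured kernels into one submission, as MetalBatchExecutor does

# hypothetical registration; RawMyBuffer, MyRenderer, my_compile, MyProgram, my_sync are stand-ins
MyBuffer = Compiled(RawMyBuffer, LinearizerOptions(device="MYDEV"), MyRenderer, my_compile, MyProgram,
                    my_sync, batch_executor=MyBatchExecutor)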