move BatchExecutor (#2297)

* move BatchExecutor

* refactor to get_optimized_program

* that changed
George Hotz 2023-11-14 08:08:51 -08:00 committed by GitHub
parent 0cbf6c1811
commit 8916028ddd
3 changed files with 91 additions and 92 deletions
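
For orientation: BatchExecutor and JitItem now live in tinygrad.ops instead of tinygrad.jit, and they back the replay path of TinyJit. A minimal usage sketch, assuming a tinygrad checkout at this commit; the jitted function and shapes are made up for illustration:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()

for _ in range(3):
  out = step(Tensor.randn(4, 4).realize())  # first call runs eagerly, the second is captured, later calls replay
# the replay is driven by Device[Device.DEFAULT].batch_executor, which after this
# commit defaults to the generic BatchExecutor defined in tinygrad.ops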


@@ -1,56 +1,15 @@
from __future__ import annotations
from typing import Callable, List, Tuple, Any, Dict, cast, Union, Optional
import functools, itertools
from tinygrad.helpers import DEBUG, DType, merge_dicts, GlobalCounters, getenv, colored
from tinygrad.ops import RawBuffer, Device, ASTRunner
from tinygrad.helpers import DEBUG, DType, merge_dicts
from tinygrad.ops import RawBuffer, Device, ASTRunner, BatchExecutor, JitItem
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, sym_infer
from dataclasses import dataclass
from tinygrad.shape.symbolic import Variable
from weakref import ref, WeakKeyDictionary
JIT_SUPPORTED_DEVICE = ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU", "LLVM"]
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn: Callable = fxn
@@ -98,8 +57,7 @@ class TinyJit:
assert len(jit_cache) != 0, "didn't JIT anything!"
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_rawbuffers)} inputs")
alt_batch_exec = Device[Device.DEFAULT].batch_executor
self.jit_fxn = (BatchExecutor if alt_batch_exec is None or getenv("JIT") == 2 else alt_batch_exec)(jit_cache, input_rawbuffers, var_vals)
self.jit_fxn = Device[Device.DEFAULT].batch_executor(jit_cache, input_rawbuffers, var_vals)
elif self.cnt == 0:
self.ret = self.fxn(*args, **kwargs)
@@ -129,7 +87,7 @@ class _CacheCollector:
def add(self, prg, rawbufs, var_vals):
if self.cache is None: return
for k,v in var_vals.items(): assert k in self.var_vals and self.var_vals[k] == v, f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0])
self.placeholders[rawbufs[0]] = PlaceHolder(rawbufs[0]) # NOTE: this is making an assumption that 0 is special
self.cache.append((prg, [self.placeholders.get(x, x) if isinstance(x, RawBuffer) else x for x in rawbufs]))
def finish(self) -> List[JitItem]:
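
In the file above, TinyJit no longer special-cases alt_batch_exec or JIT=2; it simply instantiates Device[Device.DEFAULT].batch_executor, and the JIT=2 override moves into Compiled.__init__ in the next file. A standalone sketch of that selection rule, using plain stand-in classes and os.getenv rather than the real tinygrad executors and helpers:

import os

class BatchExecutor: pass                      # stands in for the generic executor in tinygrad.ops
class MetalBatchExecutor(BatchExecutor): pass  # stands in for a backend-specific executor

def pick_batch_executor(backend_executor=BatchExecutor):
  # JIT=2 forces the generic executor even when the backend ships its own
  return BatchExecutor if int(os.getenv("JIT", "0")) == 2 else backend_executor

assert pick_batch_executor(MetalBatchExecutor) is MetalBatchExecutor
os.environ["JIT"] = "2"
assert pick_batch_executor(MetalBatchExecutor) is BatchExecutor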


@@ -1,10 +1,10 @@
from __future__ import annotations
import importlib, inspect, functools, pathlib, re
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.shape.symbolic import Variable, sym_infer, NumNode
from dataclasses import dataclass
# these are the llops your accelerator must implement, along with toCpu
@@ -106,13 +106,55 @@ class _Device:
return "CPU"
Device = _Device()
# **************** batch executor ****************
@dataclass(frozen=True)
class JitItem:
prg: ASTRunner
rawbufs: List[Optional[RawBuffer]]
class BatchExecutor:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
self.jit_cache: List[JitItem] = jit_cache
self.input_replace: Dict[Tuple[int, int], Union[int, str]] = {}
self.op_estimate, self.mem_estimate = NumNode(0), NumNode(0)
for j,ji in enumerate(jit_cache):
if isinstance(ji.prg, ASTRunner): # TODO: this is just for world and needs to be refactored
self.op_estimate += ji.prg.op_estimate
self.mem_estimate += ji.prg.mem_estimate
for i,a in enumerate(ji.rawbufs):
if a in [v for v in input_rawbuffers.values()]:
self.input_replace[(j,i)] = [k for k,v in input_rawbuffers.items() if v == a][0]
assert set(self.input_replace.values()) == set(input_rawbuffers.keys()), "some input tensors not found"
self.clear_jit_inputs()
def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
for (j,i),input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
for ji in self.jit_cache: ji.prg(cast(List[RawBuffer], ji.rawbufs), {v:var_vals[v] for v in getattr(ji.prg,"vars",[])}, jit=True)
self.clear_jit_inputs()
def update_stats(self, var_vals: Dict[Variable, int], et: Optional[float]):
# TODO: this is mostly copied from ASTRunner
op_estimate = sym_infer(self.op_estimate, var_vals)
mem_estimate = sym_infer(self.mem_estimate, var_vals)
if DEBUG >= 2:
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'CYAN')} kernels:{len(self.jit_cache):4d} inputs:{len(self.input_replace):3d} {' '.join([f'{k.expr}={v}' for k,v in var_vals.items()])[:50]:50s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"))
GlobalCounters.kernel_count += len(self.jit_cache)
GlobalCounters.global_ops += sym_infer(self.op_estimate, var_vals)
GlobalCounters.global_mem += sym_infer(self.mem_estimate, var_vals)
if et is not None: GlobalCounters.time_sum_s += et
def clear_jit_inputs(self):
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
# **************** for Interpreted Buffers ****************
class Interpreted:
def __init__(self, buffer, fxn_for_op: Dict[Op, Callable], from_underlying=None):
self.buffer, self.fxn_for_op, self.from_underlying = buffer, fxn_for_op, from_underlying
self.synchronize = lambda: None
self.batch_executor = None
self.batch_executor = BatchExecutor
self.codegen = None
self.method_cache: Dict[LazyOp, Callable] = {}
@@ -233,17 +275,50 @@ class ASTRunner:
return et
class Compiled:
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=None):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.batch_executor = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, batch_executor
def __init__(self, buffer: Type[RawBuffer], linearizer_opts, renderer, compiler, runtime, synchronize=lambda: None, batch_executor=BatchExecutor):
self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize = buffer, linearizer_opts, renderer, compiler, runtime, synchronize
self.batch_executor = BatchExecutor if getenv("JIT") == 2 else batch_executor
self.method_cache: Dict[LazyOp, ASTRunner] = {}
def to_program(self, k):
def to_program(self, k) -> ASTRunner:
k.linearize()
src, runtime_args = self.renderer(k.function_name, k.uops)
return ASTRunner(k.function_name, src, k.global_size, k.local_size,
op_estimate=k.info.flops, mem_estimate=k.info.mem_estimate,
display_name=k.display_name, runtime_args=runtime_args).build(self.compiler, self.runtime)
def get_optimized_program(self, ast:LazyOp, rawbuffers:List[RawBuffer]) -> ASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == rawbuffers[0].dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {rawbuffers[0].dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1 and not vars_from_ast(ast):
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg
def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
# check if we can reuse the output buffer
# if it's aliased, don't use it
@@ -264,44 +339,11 @@ class Compiled:
# all the rawbuffers
rawbuffers = [output.realized] + [x.realized for x in inputs]
# compilation time
def get_program():
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.lazy import vars_from_ast
k = Linearizer(ast, self.linearizer_opts)
assert k.info.dtype == output.dtype, f"linearizer must match dtype. linearizer wants {k.info.dtype} but buffer is {output.dtype}"
if not NOOPT:
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
if BEAM >= 1 and not vars_from_ast(ast):
lins = [(("tc" if used_tensor_cores else "hc"), k)]
# allocate a scratch buffer if output buffer is also input
test_rawbuffers = [type(rawbuffers[0])(rawbuffers[0].size, rawbuffers[0].dtype), *rawbuffers[1:]] if rawbuffers[0] in rawbuffers[1:] else rawbuffers
kb = Linearizer(ast, self.linearizer_opts)
kb.required_optimizations()
from tinygrad.features.search import beam_search, time_linearizer
lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))))
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, disable_cache=True, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2])
if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed))
k = timed[0][1]
else:
k.required_optimizations()
prg = self.to_program(k)
# extract real vars used in ast
prg.vars = vars_from_ast(ast)
assert all(v._val is None for v in prg.vars), f"ast contains bound Variable {prg.vars}"
return prg
if getenv("ENABLE_METHOD_CACHE", 1):
if ast not in self.method_cache: self.method_cache[ast] = get_program()
if ast not in self.method_cache: self.method_cache[ast] = self.get_optimized_program(ast, rawbuffers)
prg = self.method_cache[ast]
else:
prg = get_program()
prg = self.get_optimized_program(ast, rawbuffers)
if prg.name == getenv("PRINT_PRG", ''): print(prg.prg)
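
The get_program closure in exec_ast is now the Compiled.get_optimized_program method above, and its result is memoized per AST in method_cache so compilation runs once per distinct kernel. A standalone sketch of that caching shape; Ast and the returned string are stand-ins, not tinygrad types:

from typing import Dict

class Ast:  # hashable stand-in for a LazyOp AST
  def __init__(self, key: str): self.key = key
  def __hash__(self): return hash(self.key)
  def __eq__(self, other): return isinstance(other, Ast) and self.key == other.key

method_cache: Dict[Ast, str] = {}

def get_optimized_program(ast: Ast) -> str:
  print(f"compiling {ast.key}")  # the expensive linearize/optimize/compile step happens here
  return f"program<{ast.key}>"

def exec_ast(ast: Ast) -> str:
  if ast not in method_cache: method_cache[ast] = get_optimized_program(ast)
  return method_cache[ast]

exec_ast(Ast("matmul")); exec_ast(Ast("matmul"))  # compiles once; the second call hits the cache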


@@ -4,9 +4,10 @@ import Metal, Cocoa, libdispatch
from typing import List, Any, Tuple, Dict, Union, Set
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup, CI
from tinygrad.ops import Compiled
from tinygrad.ops import Compiled, BatchExecutor, JitItem
from tinygrad.renderer.metal import MetalRenderer
from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator
from tinygrad.shape.symbolic import Variable, Node
class MetalAllocator(LRUAllocator):
def _do_alloc(self, size, dtype, device, **kwargs): return METAL.device.newBufferWithLength_options_(size*dtype.itemsize, Metal.MTLResourceStorageModeShared)
@@ -80,8 +81,6 @@ class MetalProgram:
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
METAL.mtl_buffers_in_flight.append(command_buffer)
from tinygrad.jit import BatchExecutor, JitItem
from tinygrad.shape.symbolic import Variable, Node
class MetalBatchExecutor(BatchExecutor):
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int]):
super().__init__(jit_cache, input_rawbuffers, var_vals)
@@ -151,4 +150,4 @@ class MetalBatchExecutor(BatchExecutor):
super().update_stats(var_vals, et)
return et
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else None)
MetalBuffer = Compiled(RawMetalBuffer, LinearizerOptions(device="METAL"), MetalRenderer, compile_metal, MetalProgram, METAL.synchronize, batch_executor=MetalBatchExecutor if not CI else BatchExecutor)
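
A backend-specific executor like MetalBatchExecutor is a subclass that keeps the base constructor and overrides __call__. A skeletal, hypothetical example of the pattern; LoggingBatchExecutor is not part of this diff, and the real Metal version encodes whole command buffers instead of delegating to the base class:

from typing import Dict, Union
from tinygrad.ops import BatchExecutor
from tinygrad.runtime.lib import RawBuffer
from tinygrad.shape.symbolic import Variable

class LoggingBatchExecutor(BatchExecutor):
  def __call__(self, input_rawbuffers: Dict[Union[int, str], RawBuffer], var_vals: Dict[Variable, int], wait=False):
    print(f"replaying {len(self.jit_cache)} captured kernels")  # extra behavior before the replay
    return super().__call__(input_rawbuffers, var_vals, wait=wait)

Such a class would be handed to Compiled as batch_executor=LoggingBatchExecutor, mirroring how MetalBuffer passes MetalBatchExecutor above.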