move JIT graphing into CapturedJit (#5852)

* move JIT graphing into CapturedJit

* better

* _jit_cache

* clear inputs cleanup

* test_pickle_jit with graph + cleanup

* 0 is fine to start

* support None in bufs

* alloc real buffers

* cleaner
This commit is contained in:
George Hotz 2024-07-31 20:48:17 -07:00 committed by GitHub
parent 0ec732b494
commit 9d05dfb6f4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 65 additions and 48 deletions

View File

@ -45,7 +45,7 @@ class TestPickle(unittest.TestCase):
def test_pickle_jit(self):
@TinyJit
def add(a, b): return a+b+1
def add(a, b): return a.sum()+b+1
for _ in range(3): add(Tensor.rand(10, 10), Tensor.rand(10, 10))
st = pickle.dumps(add)
del add
@ -55,7 +55,7 @@ class TestPickle(unittest.TestCase):
y = Tensor.ones(10, 10).contiguous().realize()
print("post jit")
out = add_fxn(x, y)
np.testing.assert_equal(out.numpy(), 3)
np.testing.assert_equal(out.numpy(), 102)
def test_pickle_schedule(self):
a = Tensor([1,2])

View File

@ -33,10 +33,10 @@ def apply_graph_to_jit(jit_cache: List[ExecItem], input_rawbuffers: List[Buffer]
for (j,i) in graph_runner.input_replace.keys(): graph_runner.jit_cache[j].bufs[i] = None
graphed_jit_cache.append(ExecItem(graph_runner, cast(List[Optional[Buffer]], input_rawbuffers)))
max_batch_size *= 2
if DEBUG >= 2: print(f"\tJIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
if DEBUG >= 2: print(f"JIT GRAPHing batch with {len(current_batch)} kernels on device {current_device}")
except GraphException as e:
graphed_jit_cache.extend(current_batch)
if DEBUG >= 2: print(f"\tJIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
if DEBUG >= 2: print(f"JIT GRAPHing failed batch with {len(current_batch)} kernels on device {current_device}: {e}")
current_batch = []
current_device = None
@ -128,6 +128,47 @@ class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
for rawbuf in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency
return list({id(x):x for x in wait_nodes}.values())
ReturnType = TypeVar('ReturnType')
@dataclass
class CapturedJit(Generic[ReturnType]):
  """Picklable record of a captured JIT trace; calling it replays the kernels.

  Only the six dataclass fields are pickled (see __reduce__). The derived
  graphing state (_jit_cache, _input_replace, _graphed) is rebuilt by
  __post_init__, so graph capture happens lazily on the first __call__ after
  construction or unpickling.
  """
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: List[ExecItem]
  input_replace: Dict[Tuple[int, int], int]
  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
  expected_names: List[Union[int, str]]
  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]

  def __reduce__(self):
    # Pickle only the plain fields; __post_init__ recreates the working state
    # (including un-graphed _jit_cache) on load.
    return self.__class__, (self.ret, self.jit_cache, self.input_replace, self.extra_view_inputs,
                            self.expected_names, self.expected_st_vars_dtype_device)

  def __post_init__(self):
    # Working copies: __call__ may swap these for the graphed versions while
    # the original fields stay untouched for pickling.
    self._jit_cache: List[ExecItem] = self.jit_cache
    self._input_replace: Dict[Tuple[int, int], int] = self.input_replace
    self._graphed = False
    self._clear_inputs()

  def _clear_inputs(self):
    # Null out the input buffer slots so captured ExecItems don't pin the
    # caller's buffers between runs.
    for (j,i) in self._input_replace.keys(): self._jit_cache[j].bufs[i] = None

  # jit exec
  def __call__(self, input_buffers:List[Buffer], var_vals:Dict[Variable, int]) -> ReturnType:
    # assign inputs
    # NOTE: this appends view buffers to the caller's input_buffers list in place.
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self._input_replace.items(): self._jit_cache[j].bufs[i] = input_buffers[input_idx]

    # Condense the items into a graph executor.
    # Done lazily here (once, guarded by _graphed) rather than at capture time,
    # so a pickled CapturedJit can re-graph on its first run after load.
    if JIT < 2 and not self._graphed:
      self._jit_cache = apply_graph_to_jit(self._jit_cache, input_buffers, var_vals)
      self._input_replace = get_input_replace(self._jit_cache, input_buffers)
      self._graphed = True

    if DEBUG >= 1 and len(self._jit_cache) >= 10: print(f"jit execs {len(self._jit_cache)} kernels")
    for ei in self._jit_cache: ei.run(var_vals, jit=True)
    self._clear_inputs()
    return self.ret
def _prepare_jit_inputs(args, kwargs):
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
[(cast(Union[int, str], name),t) for name,t in itertools.chain(enumerate(args), sorted(kwargs.items())) if t.__class__ is Tensor]
@ -142,42 +183,12 @@ def _prepare_jit_inputs(args, kwargs):
st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
return input_buffers, var_vals, names, st_vars_dtype_device
ReturnType = TypeVar('ReturnType')
@dataclass(frozen=True)
class CapturedJit(Generic[ReturnType]):
  """Frozen record of a captured JIT trace; calling it validates args and
  replays the captured kernels.

  Frozen so instances pickle as plain field tuples; all mutation happens
  inside the contained ExecItems' bufs lists, not on the dataclass itself.
  """
  ret: Any  # includes the Tensors or any other returned object
  jit_cache: List[ExecItem]
  input_replace: Dict[Tuple[int, int], int]
  extra_view_inputs: List[Tuple[int, int, str, int, DType]]
  expected_names: List[Union[int, str]]
  expected_st_vars_dtype_device: List[Tuple[ShapeTracker, Tuple[Variable, ...], DType, str]]

  def __post_init__(self): self.clear_jit_inputs()

  def clear_jit_inputs(self):
    # Null out the input buffer slots so captured ExecItems don't pin the
    # caller's buffers between runs.
    for (j,i) in self.input_replace.keys(): self.jit_cache[j].bufs[i] = None

  def __call__(self, *args, **kwargs) -> ReturnType:
    input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
    # Calls must match the capture run in argument names and in each tensor's
    # (shapetracker, vars, dtype, device) signature.
    assert self.expected_names == names, f"args mismatch in JIT: {self.expected_names=} != {names}"
    assert self.expected_st_vars_dtype_device == st_vars_dtype_device, \
      f"args mismatch in JIT: {self.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
    # jit exec
    # NOTE: appends view buffers to input_buffers in place before binding.
    for idx, offset, device, size, dtype in self.extra_view_inputs:
      input_buffers.append(Buffer(device, size, dtype, base=input_buffers[idx], offset=offset).ensure_allocated())
    for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].bufs[i] = input_buffers[input_idx]
    if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
    for ei in self.jit_cache: ei.run(var_vals, jit=True)
    # cleanup
    self.clear_jit_inputs()
    return self.ret
class TinyJit(Generic[ReturnType]):
def __init__(self, fxn:Callable[..., ReturnType]):
def __init__(self, fxn:Optional[Callable[..., ReturnType]], captured:Optional[CapturedJit]=None):
assert fxn or captured, "need either a function or a CapturedJit"
self.fxn = fxn
self.reset()
self.captured: Optional[CapturedJit] = captured
self.cnt: int = 2 if self.fxn is None else 0
def add_buffer(self, b:Buffer) -> Buffer:
if found:=self._buffer_replace.get(b, None): return found
@ -192,32 +203,33 @@ class TinyJit(Generic[ReturnType]):
self._jit_cache.append(ExecItem(ei.prg, [self.add_buffer(buf) for buf in ei.bufs if buf is not None]))
def reset(self):
self.cnt: int = 0
self.captured: Optional[CapturedJit] = None
assert self.fxn is not None, "can't reset without function"
self.cnt = 0
self.captured = None
def __reduce__(self):
assert self.captured is not None, "can't pickle an uncaptured JIT"
return CapturedJit, tuple(self.captured.__dict__.values())
return self.__class__, (None, self.captured)
# keep legacy code working
@property
def jit_cache(self) -> List[ExecItem]: return self.captured.jit_cache if self.captured is not None else []
def jit_cache(self) -> List[ExecItem]: return self.captured._jit_cache if self.captured is not None else []
@property
def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured.input_replace if self.captured is not None else {}
def input_replace(self) -> Dict[Tuple[int, int], int]: return self.captured._input_replace if self.captured is not None else {}
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) # add support for instance methods
def __call__(self, *args, **kwargs) -> ReturnType:
if self.captured is not None: return self.captured(*args, **kwargs)
input_buffers, var_vals, names, st_vars_dtype_device = _prepare_jit_inputs(args, kwargs)
if not JIT or self.cnt == 0:
# jit ignore
assert self.fxn is not None
with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
ret = self.fxn(*args, **kwargs)
if len(params:=get_parameters(ret)): Tensor.realize(params[0], *params[1:])
elif self.cnt == 1:
# jit capture
assert self.fxn is not None
if capturing: raise RuntimeError(f"having TinyJit inside another TinyJit is not supported {len(capturing)=} {capturing=}")
self._jit_cache: List[ExecItem] = []
self._buffer_replace: WeakKeyDictionary[Buffer, Buffer] = WeakKeyDictionary()
@ -234,6 +246,7 @@ class TinyJit(Generic[ReturnType]):
if DEBUG >= 1: print(f"JIT captured {len(jit_cache)} kernels with {len(input_buffers)} inputs")
# track inputs that are views of buffers
# TODO: eventually expected_buffers should live in ExecItem
extra_view_inputs: List[Tuple[int, int, str, int, DType]] = []
for item in jit_cache:
for b in item.bufs:
@ -247,14 +260,18 @@ class TinyJit(Generic[ReturnType]):
assigned = _internal_memory_planner([cast(List[Buffer], item.bufs) for item in jit_cache], noopt_buffers, debug_prefix="JIT ")
jit_cache = [ExecItem(item.prg, [assigned.get(b,b).ensure_allocated() for b in item.bufs if b is not None]) for item in jit_cache]
# Condense the items into a graph executor.
if JIT < 2: jit_cache = apply_graph_to_jit(jit_cache, input_buffers, var_vals)
input_replace = get_input_replace(jit_cache, input_buffers)
if DEBUG >= 1 and len(set(input_replace.values())) != len(input_buffers): print("WARNING: some input tensors not found")
# set this for next run
self.captured = CapturedJit(ret, jit_cache, input_replace, extra_view_inputs, names, st_vars_dtype_device)
elif self.cnt >= 2:
# jit exec
assert self.captured is not None
assert self.captured.expected_names == names, f"args mismatch in JIT: {self.captured.expected_names=} != {names}"
assert self.captured.expected_st_vars_dtype_device == st_vars_dtype_device, \
f"args mismatch in JIT: {self.captured.expected_st_vars_dtype_device=} != {st_vars_dtype_device=}"
ret = self.captured(input_buffers, var_vals)
self.cnt += 1
return ret