update docs, remove corealize (#4264)

* update docs, remove corealize

* handle 0 line count

* tensor schedule
George Hotz 2024-04-23 12:05:29 +04:00 committed by GitHub
parent 9b7efa72ea
commit 967638f0d5
11 changed files with 61 additions and 32 deletions


@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](functi
 ::: tinygrad.lazy.LazyBuffer
     options:
       show_source: false
+
+## Lowering
+
+The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+
+::: tinygrad.ops.ScheduleItem
+
+The code in [realize](/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
+
+::: tinygrad.engine.realize.lower_schedule
+
+## Execution
+
+Creating `ExecItem`, which has a run method
+
+::: tinygrad.engine.realize.ExecItem
+    options:
+      members: true
+
+Lists of `ExecItem` can be condensed into a single ExecItem with the Graph API (rename to Queue?)
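
Taken together, the flow the new docs describe is: build lazy Tensors, ask the scheduler for ScheduleItems, then lower and run them. A rough sketch, using only names referenced above and in this diff (exact signatures may differ from this commit):

    from tinygrad import Tensor
    from tinygrad.engine.realize import run_schedule

    t = (Tensor.ones(4, 4) + 1).sum()            # builds a graph of LazyBuffers; nothing has run yet
    schedule, var_vals = t.schedule_with_vars()  # scheduler: graph of LazyBuffers -> List[ScheduleItem]
    for si in schedule: print(si.ast)            # the compute each ScheduleItem will run
    run_schedule(schedule, var_vals)             # lowers each ScheduleItem to an ExecItem and runs it
    print(t.numpy())                             # 32.0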


@@ -12,7 +12,8 @@
 ## tinygrad ops
 
-::: tinygrad.Tensor.corealize
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
 ::: tinygrad.Tensor.realize
 ::: tinygrad.Tensor.replace
 ::: tinygrad.Tensor.assign

serve_docs.sh (new executable file, 2 lines)

@@ -0,0 +1,2 @@
+#!/bin/bash
+mkdocs serve -w tinygrad/


@@ -55,7 +55,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
   def _test_layer(self, name, layer, cin, xy):
     optim = SGD(get_parameters(layer), bs / 128 * 1.0) # need sgd for some params but not consequential for benchmarking
-    with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
+    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])
     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -67,8 +67,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
       y = x.sequential(layer).contiguous().contiguous_backward()
      y.sum().backward()
-      if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
-      else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
+      if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
+      else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
       return y.detach()
     CNT = getenv("CNT", 5)


@@ -144,7 +144,7 @@ class TestOptReduceLoop(unittest.TestCase):
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
-  @unittest.skip("this no longer happens, use corealize")
+  @unittest.skip("this no longer happens, use realize")
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)


@@ -171,7 +171,7 @@ class TestAssign(unittest.TestCase):
     b = Tensor.full((4,), 3).contiguous().realize()
     a += b
     b += a
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 5)
     np.testing.assert_allclose(b.numpy(), 8)
@@ -183,7 +183,7 @@ class TestAssign(unittest.TestCase):
     c = a+9
     a += b
     b += c
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 2+3)
     np.testing.assert_allclose(b.numpy(), 3+2+9)


@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in create_schedule([pre.lazydata], seen.copy()):
+      for s in pre.schedule(seen=seen.copy()):
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)


@@ -96,7 +96,7 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
       [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
-    Tensor.corealize([x[1] for x in input_tensors])
+    if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
     lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
     expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,12 +105,10 @@ class TinyJit(Generic[ReturnType]):
       [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
     expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
-    if self.cnt >= 2:
-      # jit exec
-      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
+    if self.cnt == 0:
+      # jit ignore
+      self.ret = self.fxn(*args, **kwargs)
+      if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
       self.expected_names: List[Union[int, str]] = expected_names
@@ -118,7 +116,7 @@ class TinyJit(Generic[ReturnType]):
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
-        Tensor.corealize(get_parameters(self.ret))
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
         capturing.clear()
       del self.buffer_replace
       assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@ class TinyJit(Generic[ReturnType]):
       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
-    elif self.cnt == 0:
-      # jit ignore
-      self.ret = self.fxn(*args, **kwargs)
-      Tensor.corealize(get_parameters(self.ret))
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)
 
     # clear jit inputs
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
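
The three call-count branches now read in the order they fire. A hedged usage sketch of what that means for a jitted function (assuming TinyJit is importable from the package root):

    from tinygrad import Tensor, TinyJit  # assumes TinyJit is exported at the top level

    w = Tensor.randn(4, 4)

    @TinyJit
    def step(x: Tensor) -> Tensor:
      return (x @ w).relu().realize()

    for _ in range(4): out = step(Tensor.randn(1, 4))
    # call 1 (cnt == 0): "jit ignore"  - run the function normally, realize its outputs
    # call 2 (cnt == 1): "jit capture" - run it again while recording the kernels
    # calls 3+ (cnt >= 2): "jit exec"  - replay the recorded kernels against the new input buffers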


@@ -18,7 +18,7 @@ class Optimizer:
   def zero_grad(self):
     for param in self.params: param.grad = None
 
-  def step(self): Tensor.corealize(self.schedule_step())
+  def step(self): Tensor.realize(*self.schedule_step())
   def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
 
   def _step(self) -> List[Tensor]: raise NotImplementedError
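
step is now just schedule_step plus one variadic realize, which lets callers fold the optimizer update into a larger realization (as the resnet benchmark above does). A minimal sketch, assuming SGD from tinygrad.nn.optim:

    from tinygrad import Tensor
    from tinygrad.nn.optim import SGD

    Tensor.training = True
    w = Tensor.randn(8, 8, requires_grad=True)
    opt = SGD([w], lr=0.01)

    loss = (w * w).sum()
    loss.backward()
    opt.step()  # schedules the weight update and realizes it immediately
    # or stay lazy and realize the loss and the update together in one schedule:
    # Tensor.realize(loss, *opt.schedule_step())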


@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np
 
@@ -11,10 +11,10 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ScheduleItem
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars
@@ -146,17 +146,23 @@ class Tensor:
   # ***** data handlers ****
 
-  @staticmethod
-  def corealize(lst:Iterable[Tensor]):
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Create the schedule needed to realize these Tensor(s), with Variables."""
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
-      fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(memory_planner(schedule), var_vals)
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals
 
-  def realize(self) -> Tensor:
-    """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
-    Tensor.corealize([self])
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Create the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor) -> Tensor:
+    """Trigger the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst))
     return self
 
   def replace(self, x:Tensor) -> Tensor:
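
With the corealize staticmethod gone, realizing several Tensors at once goes through the variadic realize, and schedule/schedule_with_vars expose the ScheduleItems without running them. A short usage sketch of the new surface:

    from tinygrad import Tensor

    a = Tensor.full((4,), 2).contiguous()
    b = Tensor.full((4,), 3).contiguous()
    c, d = a + b, a * b

    Tensor.realize(c, d)        # was: Tensor.corealize([c, d]); both outputs land in one schedule
    sched = (a + 1).schedule()  # List[ScheduleItem] only; asserts there are no unbound Variables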