update docs, remove corealize (#4264)

* update docs, remove corealize

* handle 0 line count

* tensor schedule
George Hotz 2024-04-23 12:05:29 +04:00 committed by GitHub
parent 9b7efa72ea
commit 967638f0d5
11 changed files with 61 additions and 32 deletions


@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](function.py)
::: tinygrad.lazy.LazyBuffer
options:
show_source: false
+## Lowering
+The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. `ast` specifies what compute to run, and `bufs` specifies which buffers to run it on.
+::: tinygrad.ops.ScheduleItem
+The code in [realize](/tinygrad/engine/realize.py) lowers each `ScheduleItem` to an `ExecItem` with
+::: tinygrad.engine.realize.lower_schedule
+## Execution
+Creating an `ExecItem`, which has a `run` method
+::: tinygrad.engine.realize.ExecItem
+    options:
+      members: true
+Lists of `ExecItem` can be condensed into a single `ExecItem` with the Graph API (rename to Queue?)
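
As a concrete illustration of the schedule/lower/execute flow described in the doc text above, here is a minimal sketch. It is not part of the docs change itself; `lower_schedule` and `ExecItem.run` are internal APIs, and the assumption here is that `lower_schedule` yields one `ExecItem` per `ScheduleItem` and that `run` accepts the variable bindings returned by `schedule_with_vars`.

```python
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule, lower_schedule

# build a small lazy graph; nothing has run yet
a, b = Tensor.rand(4, 4), Tensor.rand(4, 4)
out = (a @ b).relu()

# scheduling: graph of LazyBuffers -> list of ScheduleItem (plus Variable bindings, if any)
schedule, var_vals = out.schedule_with_vars()
for si in schedule: print(si.ast, si.bufs)   # the compute to run, and the buffers it runs on

# lowering turns each ScheduleItem into an ExecItem; running them realizes `out`
for ei in lower_schedule(schedule): ei.run(var_vals)
# equivalently: run_schedule(schedule, var_vals)
```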


@@ -12,7 +12,8 @@
## tinygrad ops
-::: tinygrad.Tensor.corealize
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
::: tinygrad.Tensor.realize
::: tinygrad.Tensor.replace
::: tinygrad.Tensor.assign

serve_docs.sh (new executable file, 2 lines)

@@ -0,0 +1,2 @@
+#!/bin/bash
+mkdocs serve -w tinygrad/


@@ -55,7 +55,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
def _test_layer(self, name, layer, cin, xy):
optim = SGD(get_parameters(layer), bs / 128 * 1.0) # need sgd for some params but not consequential for benchmarking
-with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
+with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])
JITCNT = getenv("JITCNT", 1)
Tensor.training = True
@@ -67,8 +67,8 @@
y = x.sequential(layer).contiguous().contiguous_backward()
y.sum().backward()
if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
return y.detach()
CNT = getenv("CNT", 5)


@@ -144,7 +144,7 @@ class TestOptReduceLoop(unittest.TestCase):
@unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
class TestOptWChild(unittest.TestCase):
@unittest.skip("this no longer happens, use corealize")
@unittest.skip("this no longer happens, use realize")
def test_unrealized_child(self):
a = Tensor.randn(16, 16)
b = Tensor.randn(16, 16)


@@ -171,7 +171,7 @@ class TestAssign(unittest.TestCase):
b = Tensor.full((4,), 3).contiguous().realize()
a += b
b += a
-Tensor.corealize([a,b])
+Tensor.realize(a,b)
np.testing.assert_allclose(a.numpy(), 5)
np.testing.assert_allclose(b.numpy(), 8)
@@ -183,7 +183,7 @@
c = a+9
a += b
b += c
-Tensor.corealize([a,b])
+Tensor.realize(a,b)
np.testing.assert_allclose(a.numpy(), 2+3)
np.testing.assert_allclose(b.numpy(), 3+2+9)


@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
seen = set()
if to_prerealize:
for pre in to_prerealize:
-for s in create_schedule([pre.lazydata], seen.copy()):
+for s in pre.schedule(seen=seen.copy()):
for i,out in enumerate(s.outputs):
if GRAPH: realized_lazybuffer(out, 0)
seen.add(out)


@@ -96,7 +96,7 @@ class TinyJit(Generic[ReturnType]):
def __call__(self, *args, **kwargs) -> ReturnType:
input_tensors: List[Tuple[Union[int, str], Tensor]] = \
[(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
-Tensor.corealize([x[1] for x in input_tensors])
+if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,12 +105,10 @@
[dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
-if self.cnt >= 2:
-# jit exec
-assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-for ei in self.jit_cache: ei.run(var_vals, jit=True)
+if self.cnt == 0:
+# jit ignore
+self.ret = self.fxn(*args, **kwargs)
+if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
elif self.cnt == 1:
# jit capture
self.expected_names: List[Union[int, str]] = expected_names
@@ -118,7 +116,7 @@
with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
capturing.append(self)
self.ret = self.fxn(*args, **kwargs)
-Tensor.corealize(get_parameters(self.ret))
+if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
capturing.clear()
del self.buffer_replace
assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@
self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
-elif self.cnt == 0:
-# jit ignore
-self.ret = self.fxn(*args, **kwargs)
-Tensor.corealize(get_parameters(self.ret))
+elif self.cnt >= 2:
+# jit exec
+assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
+for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+for ei in self.jit_cache: ei.run(var_vals, jit=True)
# clear jit inputs
for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
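
The `__call__` flow above follows the counter: the first call (cnt 0) just runs the function and realizes its outputs, the second call (cnt 1) captures the kernels while `capturing` is set, and every later call (cnt >= 2) replays the cached `ExecItem`s with the new input buffers swapped in. A minimal usage sketch of that lifecycle, assuming the top-level `tinygrad` exports of `Tensor` and `TinyJit`:

```python
from tinygrad import Tensor, TinyJit

@TinyJit
def step(x: Tensor) -> Tensor:
  # the body is traced on the capture call; realize so the kernels land in the jit cache
  return (x @ x).relu().realize()

for _ in range(4):
  out = step(Tensor.rand(8, 8))
  # call 1 runs eagerly ("jit ignore"), call 2 captures ("jit capture"), calls 3+ replay ("jit exec")
```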


@@ -18,7 +18,7 @@ class Optimizer:
def zero_grad(self):
for param in self.params: param.grad = None
-def step(self): Tensor.corealize(self.schedule_step())
+def step(self): Tensor.realize(*self.schedule_step())
def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
def _step(self) -> List[Tensor]: raise NotImplementedError
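
The split between `step` and `schedule_step` is what lets callers fold the optimizer update into a larger `Tensor.realize` call, as the benchmark change above does. A small sketch of that pattern, assuming `Linear`, `SGD`, and `get_parameters` from the `tinygrad.nn` modules (the toy model and sizes here are illustrative):

```python
from tinygrad import Tensor
from tinygrad.nn import Linear
from tinygrad.nn.optim import SGD
from tinygrad.nn.state import get_parameters

model = Linear(16, 4)
optim = SGD(get_parameters(model), lr=0.01)
Tensor.training = True   # the benchmark above sets this before stepping

x = Tensor.rand(8, 16)
loss = model(x).sum()
loss.backward()

# step() realizes the update on its own; schedule_step() lets you batch it with
# other tensors in one schedule, as in the benchmark change above
Tensor.realize(loss, *optim.schedule_step())
```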


@@ -2,7 +2,7 @@
from __future__ import annotations
import time, math, itertools, functools
from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
from collections import defaultdict
import numpy as np
@@ -11,10 +11,10 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
from tinygrad.helpers import getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.buffer import Buffer, BufferOptions
from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import create_schedule_with_vars
@@ -146,17 +146,23 @@ class Tensor:
# ***** data handlers ****
-@staticmethod
-def corealize(lst:Iterable[Tensor]):
+def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+"""Create the schedule needed to realize these Tensor(s), with Variables."""
if getenv("FUZZ_SCHEDULE"):
from test.external.fuzz_schedule import fuzz_schedule
-fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
-run_schedule(memory_planner(schedule), var_vals)
+fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+return memory_planner(schedule), var_vals
-def realize(self) -> Tensor:
-"""Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
-Tensor.corealize([self])
+def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+"""Create the schedule needed to realize these Tensor(s)."""
+schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+assert len(var_vals) == 0
+return schedule
+def realize(self, *lst:Tensor) -> Tensor:
+"""Trigger the computation needed to create these Tensor(s)."""
+run_schedule(*self.schedule_with_vars(*lst))
return self
def replace(self, x:Tensor) -> Tensor:
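
Taken together, the tensor.py change replaces the static `corealize` with an instance-level `realize` that accepts varargs, plus `schedule`/`schedule_with_vars` for building the work list without running it. A minimal migration sketch based on the TestAssign case above (the initial value 2 for `a` is inferred from the asserted results and is an assumption):

```python
from tinygrad import Tensor

a = Tensor.full((4,), 2).contiguous().realize()
b = Tensor.full((4,), 3).contiguous().realize()
a += b
b += a

# before this commit: Tensor.corealize([a, b])
Tensor.realize(a, b)          # one schedule covering both assigns
print(a.numpy(), b.numpy())   # [5. 5. 5. 5.] [8. 8. 8. 8.]

# the schedule can also be inspected without running it
c, d = Tensor.rand(4), Tensor.rand(4)
print(len((c + d).schedule()), "ScheduleItem(s) to realize c + d")
```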