mirror of https://github.com/commaai/tinygrad.git

update docs, remove corealize (#4264)

* update docs, remove corealize
* handle 0 line count
* tensor schedule

parent 9b7efa72ea
commit 967638f0d5
@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](functi
 ::: tinygrad.lazy.LazyBuffer
     options:
         show_source: false
+
+## Lowering
+
+The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+
+::: tinygrad.ops.ScheduleItem
+
+The code in [realize](/tinygrad/engine/realize.py) lowers a `ScheduleItem` to an `ExecItem` with:
+
+::: tinygrad.engine.realize.lower_schedule
+
+## Execution
+
+Creating an `ExecItem`, which has a run method:
+
+::: tinygrad.engine.realize.ExecItem
+    options:
+        members: true
+
+Lists of `ExecItem` can be condensed into a single `ExecItem` with the Graph API (rename to Queue?).
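Taken together, the new Lowering and Execution sections describe a three-stage pipeline: schedule the LazyBuffer graph into `ScheduleItem`s, lower each one to an `ExecItem`, then run it. A minimal sketch of driving that pipeline by hand, using only the APIs referenced above (the shapes and ops are illustrative):

```python
# Sketch of the schedule -> lower -> execute pipeline described above.
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
out = (a @ b).relu().sum()

# each ScheduleItem carries the ast to run and the bufs to run it on
schedule, var_vals = out.schedule_with_vars()
print(f"{len(schedule)} schedule item(s)")

# run_schedule lowers every ScheduleItem to an ExecItem and calls its run method
run_schedule(schedule, var_vals)
print(out.numpy())
```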
@@ -12,7 +12,8 @@
 
 ## tinygrad ops
 
-::: tinygrad.Tensor.corealize
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
 ::: tinygrad.Tensor.realize
 ::: tinygrad.Tensor.replace
 ::: tinygrad.Tensor.assign
@@ -0,0 +1,2 @@
+#!/bin/bash
+mkdocs serve -w tinygrad/
@@ -55,7 +55,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
   def _test_layer(self, name, layer, cin, xy):
     optim = SGD(get_parameters(layer), bs / 128 * 1.0)  # need sgd for some params but not consequential for benchmarking
-    with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
+    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])
 
     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -67,8 +67,8 @@ class BenchmarkResnetTrain(unittest.TestCase):
       y = x.sequential(layer).contiguous().contiguous_backward()
       y.sum().backward()
-      if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
-      else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
+      if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
+      else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
       return y.detach()
 
     CNT = getenv("CNT", 5)
@@ -144,7 +144,7 @@ class TestOptReduceLoop(unittest.TestCase):
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
-  @unittest.skip("this no longer happens, use corealize")
+  @unittest.skip("this no longer happens, use realize")
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
@@ -171,7 +171,7 @@ class TestAssign(unittest.TestCase):
     b = Tensor.full((4,), 3).contiguous().realize()
     a += b
     b += a
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 5)
     np.testing.assert_allclose(b.numpy(), 8)
@@ -183,7 +183,7 @@ class TestAssign(unittest.TestCase):
     c = a+9
     a += b
     b += c
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 2+3)
     np.testing.assert_allclose(b.numpy(), 3+2+9)
 
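These two tests pin down the migration exercised throughout this diff: `Tensor.corealize` took a list of tensors, while the `Tensor.realize` that replaces it takes them as positional arguments. A minimal sketch mirroring the first test above (the setup of `a` is assumed symmetric to the `b` line shown):

```python
# Sketch of the corealize -> realize migration: varargs instead of a list.
from tinygrad import Tensor

a = Tensor.full((4,), 2).contiguous().realize()  # assumed; only b's setup is visible above
b = Tensor.full((4,), 3).contiguous().realize()
a += b   # a becomes 5
b += a   # b becomes 3 + 5 = 8
Tensor.realize(a, b)          # previously: Tensor.corealize([a, b])
print(a.numpy(), b.numpy())   # expect all 5s and all 8s
```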
@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in create_schedule([pre.lazydata], seen.copy()):
+      for s in pre.schedule(seen=seen.copy()):
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
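`check_schedule` now builds the pre-realization schedule through the new `Tensor.schedule` method instead of calling `create_schedule` on raw lazydata. A small sketch of the same idea outside the test harness (the tensor, ops, and printed count are illustrative, not asserted by the source):

```python
# Sketch: inspecting the ScheduleItems a computation would run, via Tensor.schedule.
from tinygrad import Tensor

x = Tensor.rand(8, 8).realize()   # realize the input so only the compute below is scheduled
out = (x + 1).relu() * 2

sched = out.schedule()            # List[ScheduleItem]; asserts no unbound Variables remain
print(len(sched), "schedule item(s)")  # an elementwise chain typically fuses into one kernel
for si in sched:
  print(si.ast)                   # the compute each item will run
```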
@@ -96,7 +96,7 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
       [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
-    Tensor.corealize([x[1] for x in input_tensors])
+    if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
     lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
     expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,12 +105,10 @@ class TinyJit(Generic[ReturnType]):
       [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
 
     expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
-    if self.cnt >= 2:
-      # jit exec
-      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
+    if self.cnt == 0:
+      # jit ignore
+      self.ret = self.fxn(*args, **kwargs)
+      if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
       self.expected_names: List[Union[int, str]] = expected_names
@@ -118,7 +116,7 @@ class TinyJit(Generic[ReturnType]):
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
-        Tensor.corealize(get_parameters(self.ret))
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
         capturing.clear()
       del self.buffer_replace
       assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@ class TinyJit(Generic[ReturnType]):
 
       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
-    elif self.cnt == 0:
-      # jit ignore
-      self.ret = self.fxn(*args, **kwargs)
-      Tensor.corealize(get_parameters(self.ret))
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)
 
     # clear jit inputs
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
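Read together, the reordered branches spell out TinyJit's call-count contract: call 0 runs eagerly ("jit ignore"), call 1 runs while capturing kernels ("jit capture"), and calls 2 and later replay the cached `ExecItem`s against the new input buffers ("jit exec"). A usage sketch under those assumptions (the jitted function, shapes, and the top-level `TinyJit` import are illustrative):

```python
# Sketch of TinyJit's call-count behavior: eager on call 0, capture on call 1,
# replay from call 2 on. Inputs must keep the same shapes/dtypes or the
# "args mismatch in JIT" assert above fires.
from tinygrad import Tensor, TinyJit  # assuming the top-level TinyJit export

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x * 2 + 1).realize()

for _ in range(4):
  out = step(Tensor.rand(4, 4))  # calls 2 and 3 reuse the captured kernels
print(out.numpy().sum())
```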
@@ -18,7 +18,7 @@ class Optimizer:
   def zero_grad(self):
     for param in self.params: param.grad = None
 
-  def step(self): Tensor.corealize(self.schedule_step())
+  def step(self): Tensor.realize(*self.schedule_step())
   def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
   def _step(self) -> List[Tensor]: raise NotImplementedError
 
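Since `schedule_step` returns the tensors an update needs without realizing them, a training loop can batch the optimizer update with other outputs in a single realize, which is exactly what the benchmark change earlier in this diff does. A small sketch, assuming a trivial single-weight model (the model, shapes, and learning rate are illustrative):

```python
# Sketch: realize the loss and the optimizer update together instead of calling
# opt.step() separately (which is now just Tensor.realize(*opt.schedule_step())).
from tinygrad import Tensor
from tinygrad.nn.optim import SGD

Tensor.training = True
w = Tensor.rand(4, 4, requires_grad=True)
opt = SGD([w], lr=0.01)

x = Tensor.rand(1, 4)
loss = (x @ w).sum()
opt.zero_grad()
loss.backward()
Tensor.realize(loss, *opt.schedule_step())  # one schedule for the loss and the weight update
```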
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np
@@ -11,10 +11,10 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ScheduleItem
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars
@@ -146,17 +146,23 @@ class Tensor:
 
   # ***** data handlers ****
 
-  @staticmethod
-  def corealize(lst:Iterable[Tensor]):
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Create the schedule needed to realize these Tensor(s), with Variables."""
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
-      fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(memory_planner(schedule), var_vals)
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals
 
-  def realize(self) -> Tensor:
-    """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
-    Tensor.corealize([self])
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Create the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor) -> Tensor:
+    """Trigger the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst))
     return self
 
   def replace(self, x:Tensor) -> Tensor:
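The three methods above compose: `realize` is `run_schedule` over `schedule_with_vars`, and `schedule` is the variable-free special case that also feeds the optimizer and JIT changes earlier in the diff. A short sketch of realizing several tensors through one schedule (shapes are illustrative):

```python
# Sketch of the new instance-method API: c.realize(d) builds one schedule covering
# both outputs and runs it, roughly run_schedule(*c.schedule_with_vars(d)).
from tinygrad import Tensor

a, b = Tensor.rand(4), Tensor.rand(4)
c, d = a + b, a * b

print(len(c.schedule(d)), "schedule item(s)")  # same as Tensor.schedule(c, d)
c.realize(d)                                   # realizes c and d together; returns c for chaining
print(c.numpy(), d.numpy())
```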