mirror of https://github.com/commaai/tinygrad.git
update docs, remove corealize (#4264)
* update docs, remove corealize
* handle 0 line count
* tensor schedule
This commit is contained in:
parent 9b7efa72ea
commit 967638f0d5
@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](functi
 ::: tinygrad.lazy.LazyBuffer
     options:
         show_source: false
+
+## Lowering
+
+The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+
+::: tinygrad.ops.ScheduleItem
+
+The code in [realize](/tinygrad/engine/realize.py) lowers `ScheduleItem` to `ExecItem` with
+
+::: tinygrad.engine.realize.lower_schedule
+
+## Execution
+
+Creating `ExecItem`, which has a run method
+
+::: tinygrad.engine.realize.ExecItem
+    options:
+        members: true
+
+Lists of `ExecItem` can be condensed into a single ExecItem with the Graph API (rename to Queue?)
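To make the pipeline in the added docs concrete, here is a minimal sketch (not part of the diff) that drives it by hand with the Tensor API this commit introduces; `run_schedule` lowers each `ScheduleItem` to an `ExecItem` internally and runs it. Shapes and values are illustrative.

```python
# Illustrative sketch, assuming the API shape shown in this commit.
from tinygrad import Tensor
from tinygrad.engine.realize import run_schedule

a, b = Tensor.rand(16, 16), Tensor.rand(16, 16)
out = (a @ b).relu()                           # builds a graph of LazyBuffers, nothing runs yet

schedule, var_vals = out.schedule_with_vars()  # LazyBuffer graph -> List[ScheduleItem]
run_schedule(schedule, var_vals)               # lowers each ScheduleItem to an ExecItem and runs it
```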
@@ -12,7 +12,8 @@

 ## tinygrad ops

-::: tinygrad.Tensor.corealize
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
 ::: tinygrad.Tensor.realize
 ::: tinygrad.Tensor.replace
 ::: tinygrad.Tensor.assign
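The old `Tensor.corealize(lst)` entry point is gone; multiple tensors are now realized (or scheduled) through the variadic methods documented above. A small migration sketch, borrowing values from the test changes later in this commit:

```python
# Migration sketch (illustrative, not part of the diff): corealize took a list,
# realize/schedule now take tensors as varargs.
from tinygrad import Tensor

a = Tensor.full((4,), 2).contiguous()
b = Tensor.full((4,), 3).contiguous()

# before: Tensor.corealize([a, b])
Tensor.realize(a, b)              # realize both in one schedule
print(len((a + b).schedule()))    # or just build the schedule without running it
```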
@@ -0,0 +1,2 @@
+#!/bin/bash
+mkdocs serve -w tinygrad/
@@ -55,7 +55,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
   def _test_layer(self, name, layer, cin, xy):
     optim = SGD(get_parameters(layer), bs / 128 * 1.0) # need sgd for some params but not consequential for benchmarking
-    with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
+    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])

     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -67,8 +67,8 @@ class BenchmarkResnetTrain(unittest.TestCase):

       y = x.sequential(layer).contiguous().contiguous_backward()
       y.sum().backward()
-      if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
-      else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
+      if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
+      else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
       return y.detach()

     CNT = getenv("CNT", 5)
@@ -144,7 +144,7 @@ class TestOptReduceLoop(unittest.TestCase):

 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
-  @unittest.skip("this no longer happens, use corealize")
+  @unittest.skip("this no longer happens, use realize")
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
@@ -171,7 +171,7 @@ class TestAssign(unittest.TestCase):
     b = Tensor.full((4,), 3).contiguous().realize()
     a += b
     b += a
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 5)
     np.testing.assert_allclose(b.numpy(), 8)

@@ -183,7 +183,7 @@ class TestAssign(unittest.TestCase):
     c = a+9
     a += b
     b += c
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 2+3)
     np.testing.assert_allclose(b.numpy(), 3+2+9)

@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in create_schedule([pre.lazydata], seen.copy()):
+      for s in pre.schedule(seen=seen.copy()):
        for i,out in enumerate(s.outputs):
          if GRAPH: realized_lazybuffer(out, 0)
          seen.add(out)
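The new `seen=` keyword takes over the `seen` set that `create_schedule` accepted directly. A hedged sketch of the same pattern `check_schedule` uses above; the skip-already-seen behavior is inferred from how the test relies on it, not documented in this diff:

```python
# Hedged sketch: pre-scheduling a tensor the way check_schedule's to_prerealize does.
# Assumes buffers recorded in `seen` are skipped by later schedule() calls.
from tinygrad import Tensor

seen = set()
pre = Tensor.rand(4) + 1
for si in pre.schedule(seen=seen.copy()):   # schedule without mutating the shared set
  for out in si.outputs: seen.add(out)      # then mark the outputs as already handled

rest = (pre * 2).schedule(seen=seen)        # pre's kernel should not appear again here
```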
@@ -96,7 +96,7 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
       [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
-    Tensor.corealize([x[1] for x in input_tensors])
+    if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
     lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
     expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,12 +105,10 @@ class TinyJit(Generic[ReturnType]):
       [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])

     expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
-    if self.cnt >= 2:
-      # jit exec
-      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
+    if self.cnt == 0:
+      # jit ignore
+      self.ret = self.fxn(*args, **kwargs)
+      if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
       self.expected_names: List[Union[int, str]] = expected_names
@@ -118,7 +116,7 @@ class TinyJit(Generic[ReturnType]):
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
-        Tensor.corealize(get_parameters(self.ret))
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
         capturing.clear()
       del self.buffer_replace
       assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@ class TinyJit(Generic[ReturnType]):

       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
-    elif self.cnt == 0:
-      # jit ignore
-      self.ret = self.fxn(*args, **kwargs)
-      Tensor.corealize(get_parameters(self.ret))
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)

     # clear jit inputs
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
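The JIT change above mostly reorders the three call-count branches so they read in execution order (ignore, capture, exec); per-call behavior is unchanged apart from the corealize-to-realize swap. A small sketch of that call-count contract (illustrative; the decorated function, shapes, and the top-level `TinyJit` import are assumptions):

```python
# Illustrative sketch of TinyJit's call-count behavior touched above:
# call 0 runs eagerly ("jit ignore"), call 1 captures kernels ("jit capture"),
# calls 2 and later replay the cached ExecItems ("jit exec").
from tinygrad import Tensor, TinyJit  # top-level TinyJit export assumed

@TinyJit
def step(x: Tensor) -> Tensor:
  return (x @ x).relu().realize()

for i in range(4):
  out = step(Tensor.rand(8, 8))   # from the third call on, this hits the cnt >= 2 path
```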
@@ -18,7 +18,7 @@ class Optimizer:
   def zero_grad(self):
     for param in self.params: param.grad = None

-  def step(self): Tensor.corealize(self.schedule_step())
+  def step(self): Tensor.realize(*self.schedule_step())
   def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
   def _step(self) -> List[Tensor]: raise NotImplementedError

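A quick sketch of the distinction this one-liner preserves: `step()` realizes everything immediately, while `schedule_step()` hands the tensors back so a caller (like the resnet benchmark above) can batch them into a single `Tensor.realize` with other outputs. Values and shapes below are illustrative, not from the diff.

```python
# Illustrative sketch: step() vs schedule_step() after this change.
from tinygrad import Tensor
from tinygrad.nn.optim import SGD

Tensor.training = True
w = Tensor.rand(4, 4, requires_grad=True)
opt = SGD([w], lr=0.01)

loss = (w * w).sum()
loss.backward()
opt.step()   # equivalent to Tensor.realize(*opt.schedule_step())

# or batch the optimizer tensors with other outputs, as the benchmark does:
# Tensor.realize(loss, *opt.schedule_step())
```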
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 import numpy as np

@@ -11,10 +11,10 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ScheduleItem
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars

@@ -146,17 +146,23 @@ class Tensor:

   # ***** data handlers ****

-  @staticmethod
-  def corealize(lst:Iterable[Tensor]):
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Create the schedule needed to realize these Tensor(s), with Variables."""
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
-      fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(memory_planner(schedule), var_vals)
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals

-  def realize(self) -> Tensor:
-    """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
-    Tensor.corealize([self])
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Create the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor) -> Tensor:
+    """Trigger the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst))
     return self

   def replace(self, x:Tensor) -> Tensor:
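For completeness, a sketch of the call shapes the new data handlers accept (illustrative values; the empty `var_vals` is only guaranteed here because no symbolic Variables are involved):

```python
# Illustrative sketch of the new Tensor data handlers added above.
from tinygrad import Tensor

t = Tensor.rand(2, 2) + 1
t.realize()                        # single-tensor form, unchanged by this commit

a, b = Tensor.rand(2), Tensor.rand(2)
Tensor.realize(a, b)               # variadic form that replaces Tensor.corealize([a, b])

sched, var_vals = (a + b).schedule_with_vars()
assert len(var_vals) == 0          # no bound Variables here, so schedule() would also work
```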