From 967638f0d578dea3b23e6b30eeac93892495b706 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Tue, 23 Apr 2024 12:05:29 +0400
Subject: [PATCH] update docs, remove corealize (#4264)

* update docs, remove corealize

* handle 0 line count

* tensor schedule
---
 docs/developer.md                          | 20 ++++++++++++++++
 docs/tensor.md                             |  3 ++-
 serve_docs.sh                              |  2 ++
 test/external/external_benchmark_resnet.py |  6 ++---
 test/external/external_test_opt.py         |  2 +-
 test/test_assign.py                        |  4 ++--
 test/test_schedule.py                      |  2 +-
 tinygrad/engine/__init__.py                |  0
 tinygrad/engine/jit.py                     | 24 +++++++++----------
 tinygrad/nn/optim.py                       |  2 +-
 tinygrad/tensor.py                         | 28 +++++++++++++---------
 11 files changed, 61 insertions(+), 32 deletions(-)
 create mode 100755 serve_docs.sh
 create mode 100644 tinygrad/engine/__init__.py

diff --git a/docs/developer.md b/docs/developer.md
index 762d251d..856e0691 100644
--- a/docs/developer.md
+++ b/docs/developer.md
@@ -5,3 +5,23 @@ Everything in [Tensor](tensor.md) is syntactic sugar around [function.py](functi
 ::: tinygrad.lazy.LazyBuffer
     options:
         show_source: false
+
+## Lowering
+
+The [scheduler](/tinygrad/engine/schedule.py) converts the graph of LazyBuffers into a list of `ScheduleItem`. `ast` specifies what compute to run, and `bufs` specifies what buffers to run it on.
+
+::: tinygrad.ops.ScheduleItem
+
+The code in [realize](/tinygrad/engine/realize.py) lowers each `ScheduleItem` to an `ExecItem` with
+
+::: tinygrad.engine.realize.lower_schedule
+
+## Execution
+
+Lowering produces an `ExecItem`, which has a `run` method.
+
+::: tinygrad.engine.realize.ExecItem
+    options:
+        members: true
+
+Lists of `ExecItem` can be condensed into a single `ExecItem` with the Graph API (rename to Queue?)
\ No newline at end of file
diff --git a/docs/tensor.md b/docs/tensor.md
index bcf6f35f..ff90298a 100644
--- a/docs/tensor.md
+++ b/docs/tensor.md
@@ -12,7 +12,8 @@
 
 ## tinygrad ops
 
-::: tinygrad.Tensor.corealize
+::: tinygrad.Tensor.schedule_with_vars
+::: tinygrad.Tensor.schedule
 ::: tinygrad.Tensor.realize
 ::: tinygrad.Tensor.replace
 ::: tinygrad.Tensor.assign
diff --git a/serve_docs.sh b/serve_docs.sh
new file mode 100755
index 00000000..e35dd82f
--- /dev/null
+++ b/serve_docs.sh
@@ -0,0 +1,2 @@
+#!/bin/bash
+mkdocs serve -w tinygrad/
diff --git a/test/external/external_benchmark_resnet.py b/test/external/external_benchmark_resnet.py
index 3a51b2e6..426f2a4f 100644
--- a/test/external/external_benchmark_resnet.py
+++ b/test/external/external_benchmark_resnet.py
@@ -55,7 +55,7 @@ class BenchmarkResnetTrain(unittest.TestCase):
     return f"{name} x{(bs, cin, xy, xy)}", [layer], cin, xy
   def _test_layer(self, name, layer, cin, xy):
     optim = SGD(get_parameters(layer), bs / 128 * 1.0)  # need sgd for some params but not consequential for benchmarking
-    with Context(SAVE_SCHEDULE=0): Tensor.corealize([t.assign(t) for t in get_parameters(layer)])
+    with Context(SAVE_SCHEDULE=0): Tensor.realize(*[t.assign(t) for t in get_parameters(layer)])
 
     JITCNT = getenv("JITCNT", 1)
     Tensor.training = True
@@ -67,8 +67,8 @@
 
       y = x.sequential(layer).contiguous().contiguous_backward()
       y.sum().backward()
-      if getenv("ASSIGN", 1): Tensor.corealize([y, x.grad] + optim.schedule_step())
-      else: Tensor.corealize([y, x.grad] + [t.grad for t in optim.params])
+      if getenv("ASSIGN", 1): Tensor.realize(y, x.grad, *optim.schedule_step())
+      else: Tensor.realize(y, x.grad, *[t.grad for t in optim.params])
       return y.detach()
 
     CNT = getenv("CNT", 5)
diff --git a/test/external/external_test_opt.py b/test/external/external_test_opt.py
index b26569ae..22690e26 100644
--- a/test/external/external_test_opt.py
+++ b/test/external/external_test_opt.py
@@ -144,7 +144,7 @@ class TestOptReduceLoop(unittest.TestCase):
 
 @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
 class TestOptWChild(unittest.TestCase):
-  @unittest.skip("this no longer happens, use corealize")
+  @unittest.skip("this no longer happens, use realize")
   def test_unrealized_child(self):
     a = Tensor.randn(16, 16)
     b = Tensor.randn(16, 16)
diff --git a/test/test_assign.py b/test/test_assign.py
index 7901d8d3..e9e561b1 100644
--- a/test/test_assign.py
+++ b/test/test_assign.py
@@ -171,7 +171,7 @@ class TestAssign(unittest.TestCase):
     b = Tensor.full((4,), 3).contiguous().realize()
     a += b
     b += a
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 5)
     np.testing.assert_allclose(b.numpy(), 8)
 
@@ -183,7 +183,7 @@ class TestAssign(unittest.TestCase):
     c = a+9
     a += b
     b += c
-    Tensor.corealize([a,b])
+    Tensor.realize(a,b)
     np.testing.assert_allclose(a.numpy(), 2+3)
     np.testing.assert_allclose(b.numpy(), 3+2+9)
 
diff --git a/test/test_schedule.py b/test/test_schedule.py
index aa84ab73..5cde92c5 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -18,7 +18,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in create_schedule([pre.lazydata], seen.copy()):
+      for s in pre.schedule(seen=seen.copy()):
         for i,out in enumerate(s.outputs):
           if GRAPH: realized_lazybuffer(out, 0)
           seen.add(out)
diff --git a/tinygrad/engine/__init__.py b/tinygrad/engine/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py
index b50301ee..82f77ab2 100644
--- a/tinygrad/engine/jit.py
+++ b/tinygrad/engine/jit.py
@@ -96,7 +96,7 @@ class TinyJit(Generic[ReturnType]):
   def __call__(self, *args, **kwargs) -> ReturnType:
     input_tensors: List[Tuple[Union[int, str], Tensor]] = \
       [(cast(Union[int, str], k),v) for k,v in itertools.chain(enumerate(args), sorted(kwargs.items())) if v.__class__ is Tensor]
-    Tensor.corealize([x[1] for x in input_tensors])
+    if len(input_tensors): Tensor.realize(*[x[1] for x in input_tensors])
     lbs: List[LazyBuffer] = flatten([v.lazydata.lbs for _,v in input_tensors])
     expected_sts_var_dtype_device = [(*x.st.unbind(), x.dtype, x.device) for x in lbs]
     input_rawbuffers: List[Buffer] = [v.base.realized for v in lbs if v.base.realized is not None]
@@ -105,12 +105,10 @@
       [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))])
     expected_names, expected_lbs = [x[0] for x in input_tensors], [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in expected_sts_var_dtype_device]
 
-    if self.cnt >= 2:
-      # jit exec
-      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
-      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
-      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
-      for ei in self.jit_cache: ei.run(var_vals, jit=True)
+    if self.cnt == 0:
+      # jit ignore
+      self.ret = self.fxn(*args, **kwargs)
+      if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
     elif self.cnt == 1:
       # jit capture
       self.expected_names: List[Union[int, str]] = expected_names
@@ -118,7 +116,7 @@
       with Context(GRAPH=getenv("JITGRAPH", GRAPH.value), BEAM=getenv("JITBEAM", BEAM.value)):
         capturing.append(self)
         self.ret = self.fxn(*args, **kwargs)
-        Tensor.corealize(get_parameters(self.ret))
+        if len(params:=get_parameters(self.ret)): Tensor.realize(params[0], *params[1:])
         capturing.clear()
       del self.buffer_replace
       assert len(self.jit_cache), "didn't JIT anything!"
@@ -133,10 +131,12 @@
 
       self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
       if DEBUG >= 1 and len(set(self.input_replace.values())) != len(input_rawbuffers): print("WARNING: some input tensors not found")
-    elif self.cnt == 0:
-      # jit ignore
-      self.ret = self.fxn(*args, **kwargs)
-      Tensor.corealize(get_parameters(self.ret))
+    elif self.cnt >= 2:
+      # jit exec
+      assert self.expected_names == expected_names and self.expected_lbs == expected_lbs, "args mismatch in JIT"
+      for (j,i),input_idx in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
+      if DEBUG >= 1 and len(self.jit_cache) >= 10: print(f"jit execs {len(self.jit_cache)} kernels")
+      for ei in self.jit_cache: ei.run(var_vals, jit=True)
 
     # clear jit inputs
     for (j,i) in self.input_replace.keys(): self.jit_cache[j].rawbufs[i] = None
diff --git a/tinygrad/nn/optim.py b/tinygrad/nn/optim.py
index 63d52230..55c0c825 100644
--- a/tinygrad/nn/optim.py
+++ b/tinygrad/nn/optim.py
@@ -18,7 +18,7 @@ class Optimizer:
 
   def zero_grad(self):
     for param in self.params: param.grad = None
 
-  def step(self): Tensor.corealize(self.schedule_step())
+  def step(self): Tensor.realize(*self.schedule_step())
   def schedule_step(self) -> List[Tensor]: return self._step()+self.params+self.buffers
   def _step(self) -> List[Tensor]: raise NotImplementedError
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 66321902..062fea3d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 import time, math, itertools, functools
 from contextlib import ContextDecorator
-from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Iterable, Dict, DefaultDict, cast, get_args
+from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, Dict, DefaultDict, cast, get_args, Set
 from collections import defaultdict
 
 import numpy as np
@@ -11,10 +11,10 @@ from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up
 from tinygrad.helpers import getenv
 from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
-from tinygrad.ops import LoadOps
+from tinygrad.ops import LoadOps, ScheduleItem
 from tinygrad.buffer import Buffer, BufferOptions
 from tinygrad.device import Device
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.symbolic import sint, Variable
 from tinygrad.engine.realize import run_schedule, memory_planner
 from tinygrad.engine.schedule import create_schedule_with_vars
 
@@ -146,17 +146,23 @@ class Tensor:
 
   # ***** data handlers ****
 
-  @staticmethod
-  def corealize(lst:Iterable[Tensor]):
+  def schedule_with_vars(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> Tuple[List[ScheduleItem], Dict[Variable, int]]:
+    """Create the schedule needed to realize these Tensor(s), with Variables."""
     if getenv("FUZZ_SCHEDULE"):
       from test.external.fuzz_schedule import fuzz_schedule
-      fuzz_schedule(flatten([x.lazydata.lbs for x in lst]))
-    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in lst]))
-    run_schedule(memory_planner(schedule), var_vals)
+      fuzz_schedule(flatten([x.lazydata.lbs for x in (self,)+lst]))
+    schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]), seen)
+    return memory_planner(schedule), var_vals
 
-  def realize(self) -> Tensor:
-    """Trigger the computation needed to create this Tensor. This is a light wrapper around corealize."""
-    Tensor.corealize([self])
+  def schedule(self, *lst:Tensor, seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+    """Create the schedule needed to realize these Tensor(s)."""
+    schedule, var_vals = self.schedule_with_vars(*lst, seen=seen)
+    assert len(var_vals) == 0
+    return schedule
+
+  def realize(self, *lst:Tensor) -> Tensor:
+    """Trigger the computation needed to create these Tensor(s)."""
+    run_schedule(*self.schedule_with_vars(*lst))
     return self
 
   def replace(self, x:Tensor) -> Tensor:
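
A quick usage sketch of the API change above (not part of the diff; the tensor names and shapes are made up for illustration, and it assumes a tinygrad checkout that includes this commit). The old staticmethod Tensor.corealize([...]) is replaced by the variadic instance method Tensor.realize(*others), and the new Tensor.schedule() / Tensor.schedule_with_vars() expose the list of ScheduleItem that realization would run:

from tinygrad import Tensor

# illustrative tensors; any tensors work here
a = Tensor.rand(16, 16)
b = Tensor.rand(16, 16)
c, d = a @ b, a + b

# before this patch: Tensor.corealize([c, d])
# after: realize the calling tensor together with any extra tensors passed in
c.realize(d)

# the schedule can now be inspected without running it; per the new developer
# docs, each ScheduleItem carries the compute ast and the buffers it runs on
sched = (a @ b).sum().schedule()
print(f"{len(sched)} kernel(s) scheduled")

Optimizer.step keeps its behavior: it now calls Tensor.realize(*self.schedule_step()) instead of Tensor.corealize(self.schedule_step()), so user code that only calls opt.step() is unaffected.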