mirror of https://github.com/commaai/tinygrad.git
perf: lazyop as dataclass (#1603)

* perf: lazyop as dataclass

fix: linter

fix: restore eq

* use builtin methods, buffers to property to allow freezing
* fix: reduce diff
* fix: can't freeze due to KOPT tests, mypy
* fix: explicit hash
* can freeze if tests are fixed
* fix: typo

---------

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
parent 0ca0e9ee5e
commit 36ab04ae35
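The heart of the diff, shown in the final hunk below: LazyOp in tinygrad.ops drops its hand-rolled __slots__ class with an eager __init__ and becomes a @dataclass(frozen=True), while the duplicated test-only derandomize helpers move into one shared module. A minimal standalone sketch of what the frozen-dataclass pattern buys, using a hypothetical Node class rather than tinygrad's own code:

from dataclasses import FrozenInstanceError, dataclass
from typing import Any, Tuple

# Hypothetical stand-in for LazyOp (not tinygrad code): a frozen dataclass
# generates __init__, __repr__, structural __eq__ and __hash__, so equal op
# trees compare equal and can key caches, while mutation raises at runtime.
@dataclass(frozen=True)
class Node:
  op: str
  src: Tuple["Node", ...] = ()
  arg: Any = None

a = Node("ADD", (Node("CONST", arg=1),))
b = Node("ADD", (Node("CONST", arg=1),))
assert a == b and hash(a) == hash(b)  # value equality and hashing for free

try: a.op = "MUL"                     # frozen: in-place edits are rejected
except FrozenInstanceError: pass

Immutability is also why every rewrite in the hunks below has to build new objects instead of patching old ones.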
@@ -2,11 +2,12 @@
 import unittest, gc
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.nn.state import get_parameters, get_state_dict
-from tinygrad.ops import GlobalCounters, LazyOp, LoadOps
+from tinygrad.nn.state import get_state_dict
+from tinygrad.ops import GlobalCounters
 from tinygrad.runtime.lib import RawBuffer, LRUAllocator
 from tinygrad.helpers import dtypes, prod
 from tinygrad.ops import Device
+from test.helpers import derandomize_model
 
 from examples.llama import Transformer
 
@@ -86,20 +87,6 @@ def check_gc():
   from extra.introspection import print_objects
   assert print_objects() == 0
 
-# for speed
-def derandomize(x):
-  if isinstance(x, LazyOp):
-    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
-    x.src = tuple([derandomize(s) for s in x.src])
-  else:
-    x.op = derandomize(x.op)
-  return x
-
-def derandomize_model(model):
-  for p in get_parameters(model):
-    p.lazydata = derandomize(p.lazydata)
-    p.realize()
-
 class TestAllocators(unittest.TestCase):
   @unittest.skipUnless(Device.DEFAULT == "GPU", "Not Implemented")
   def test_lru_allocator_tiny_llama(self):
@@ -2,28 +2,13 @@
 import unittest
 import numpy as np
 from tinygrad.tensor import Tensor
-from tinygrad.state import get_parameters
-from tinygrad.ops import LazyOp, LoadOps
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
 from tinygrad.helpers import dtypes, CI
-from tinygrad.lazy import Device
+from tinygrad.ops import Device
+from test.helpers import derandomize_model
 
 from examples.llama import Transformer
 
-# for speed
-def derandomize(x):
-  if isinstance(x, LazyOp):
-    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
-    x.src = tuple([derandomize(s) for s in x.src])
-  else:
-    x.op = derandomize(x.op)
-  return x
-
-def derandomize_model(model):
-  for p in get_parameters(model):
-    p.lazydata = derandomize(p.lazydata)
-    p.realize()
-
 def helper_test_jitted_correctness(gen, train, train_jit):
   nojit = train(*gen()).numpy()
   for _ in range(5): jit = train_jit(*gen()).numpy()
@@ -0,0 +1,15 @@
+from tinygrad.ops import LazyOp, LoadOps
+from tinygrad.nn.state import get_parameters
+
+# for speed
+def derandomize(x):
+  if isinstance(x, LazyOp):
+    new_op = LoadOps.EMPTY if x.op == LoadOps.RAND else x.op
+    return LazyOp(new_op, tuple([derandomize(s) for s in x.src]), x.arg)
+  x.op = derandomize(x.op)
+  return x
+
+def derandomize_model(model):
+  for p in get_parameters(model):
+    p.lazydata = derandomize(p.lazydata)
+    p.realize()
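This new shared helper (presumably test/helpers.py, given the `from test.helpers import derandomize_model` imports added above) changes shape deliberately: the old per-test copies mutated x.op and x.src in place, which a frozen LazyOp rejects, so derandomize now returns a freshly constructed LazyOp. For a single-field edit on a frozen dataclass, dataclasses.replace is the stock idiom; a small sketch, not tinygrad code:

from dataclasses import dataclass, replace

# Toy frozen record to show the copy-and-edit idiom.
@dataclass(frozen=True)
class Pair:
  op: str
  arg: int = 0

p = Pair("RAND", 7)
q = replace(p, op="EMPTY")  # edited copy; p itself is untouched
assert (p.op, q.op, q.arg) == ("RAND", "EMPTY", 7)

The helper rebuilds by hand instead of using replace, since every node in src has to be rewritten recursively as well.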
@@ -4,8 +4,9 @@ from tinygrad.tensor import Tensor
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.jit import TinyJit, JIT_SUPPORTED_DEVICE
-from tinygrad.ops import Device, GlobalCounters, LazyOp, LoadOps
+from tinygrad.ops import Device, GlobalCounters
 from tinygrad.helpers import CI, dtypes, getenv, prod
+from test.helpers import derandomize_model
 
 from examples.gpt2 import Transformer as GPT2Transformer, MODEL_PARAMS as GPT2_MODEL_PARAMS
 from examples.hlb_cifar10 import SpeedyResNet
@@ -30,20 +31,6 @@ def helper_test(nm, gen, train, max_memory_allowed, max_kernels_allowed, all_jitted
   if all_jitted:
     assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count, f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted"
 
-# for speed
-def derandomize(x):
-  if isinstance(x, LazyOp):
-    if x.op == LoadOps.RAND: x.op = LoadOps.EMPTY
-    x.src = tuple([derandomize(s) for s in x.src])
-  elif hasattr(x, "op"):
-    x.op = derandomize(x.op)
-  return x
-
-def derandomize_model(model):
-  for p in get_parameters(model):
-    p.lazydata = derandomize(p.lazydata)
-    p.realize()
-
 class TestRealWorld(unittest.TestCase):
   def setUp(self):
     self.old_type = Tensor.default_type
@@ -47,21 +47,19 @@ class ScheduleItem:
   inputs: Tuple[LazyBuffer, ...]
   var_vals: Dict[Variable, int]
 
+@dataclass(frozen=True)
 class LazyOp:
-  __slots__ = "op", "src", "arg", "buffers", "__weakref__"
   op: Op
   src: Tuple[Union[LazyOp, LazyBuffer], ...]
-  arg: Any
-  buffers: Tuple[LazyBuffer, ...]
-  def __init__(self, op: Op, src: Tuple[Union[LazyOp, LazyBuffer], ...], arg: Any = None):
-    self.op, self.src, self.arg, self.buffers = op, src, arg, ()
+  arg: Any = None
+  @property
+  def buffers(self):
+    buffers: Tuple[Union[LazyOp, LazyBuffer], ...] = ()
     try: # NOTE: the linearizer's key function maps the buffers to ints, and LOCAL_BUFFER is used. we don't care about buffers in these cases
-      for x in src: self.buffers += x.buffers
-    except AttributeError: self.buffers = ()
+      for x in self.src: buffers += x.buffers
+    except AttributeError: buffers = ()
+    return buffers
 
-  def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
-  def __eq__(self, __value: object) -> bool: return isinstance(__value, LazyOp) and self.op is __value.op and self.src == __value.src and self.arg == __value.arg
-  def __hash__(self) -> int: return hash((self.op, self.src, self.arg))
   @property
   def key(self): return (self.op, tuple(map(lambda x: getattr(x, "key", x), self.src)), getattr(self.arg, "key", self.arg))
 
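Two things worth spelling out in this last hunk. First, buffers moves from a tuple built once in __init__ to a property recomputed on every access, with the except AttributeError still covering the linearizer case where src holds ints or LOCAL_BUFFER instead of buffer-carrying nodes. Second, the explicit __repr__/__eq__/__hash__ drop out in favor of the dataclass-generated ones, the "use builtin methods" of the commit message. A toy illustration of the recomputed property, with hypothetical classes:

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyBuf:
  name: str
  @property
  def buffers(self) -> Tuple["ToyBuf", ...]: return (self,)

@dataclass(frozen=True)
class ToyOp:
  src: Tuple[ToyBuf, ...]
  @property
  def buffers(self) -> Tuple[ToyBuf, ...]:
    out: Tuple[ToyBuf, ...] = ()
    for x in self.src: out += x.buffers  # walked anew on every access
    return out

assert [b.name for b in ToyOp((ToyBuf("a"), ToyBuf("b"))).buffers] == ["a", "b"]

Callers that touch .buffers repeatedly now pay for that walk each time; the commit trades the precomputed tuple for ops that can be frozen and hashed.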