mirror of https://github.com/commaai/tinygrad.git
assign buffer reuse (#547)
* assign buffer reuse works
* fix assign for torch and cpu
* allow assign from numpy
* fix llvm output_buffer
* add some assign tests
* fix assignment test
* test should fail without lazy
* env var to disable assign
This commit is contained in:
parent
473bbd3e35
commit
5de850f6d5
|
@ -1,15 +0,0 @@
|
|||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.graph import nm
|
||||
|
||||
# Debug script: perform one in-place add, print the result, then dump the three
# buffers recorded in the first entry of GlobalCounters.cache.
if __name__ == "__main__":
  GlobalCounters.cache = []  # presumably an empty list switches recording on — TODO confirm
  lhs, rhs = Tensor.ones(4,4), Tensor.ones(4,4)
  lhs += rhs
  print(lhs.numpy())
  _, args = GlobalCounters.cache[0]
  # unpacking into three names keeps the "exactly three buffers" check
  first, second, third = args
  for buf in (first, second, third):
    print(nm(buf), buf)
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest
|
||||
import numpy as np
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.lazy import LAZY
|
||||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.graph import nm
|
||||
|
||||
N = 200 # has to be bigger than the cache to fail
|
||||
|
||||
class TestAssign(unittest.TestCase):
  """Value and buffer-identity checks for in-place assignment on realized tensors."""

  @staticmethod
  def _realized_pair():
    """Build two realized N x N ``arange`` tensors.

    Returns ``(a, b, buf_a, buf_b)`` where ``buf_a`` / ``buf_b`` are the
    ``lazydata.realized`` buffers of ``a`` / ``b`` captured right after realization.
    """
    a = Tensor.arange(N*N).reshape(N,N)
    b = Tensor.arange(N*N).reshape(N,N)
    a.realize()
    b.realize()
    return a, b, a.lazydata.realized, b.lazydata.realized

  def test_simple_assignment(self):
    """``a += b`` yields 2*arange and, when LAZY is set, keeps ``a``'s original buffer."""
    a, b, ba1, bb1 = self._realized_pair()
    a += b
    a.realize()
    ba2 = a.lazydata.realized
    if LAZY:
      # unittest assertions (unlike bare assert) survive `python -O` and report both operands
      self.assertEqual(ba1, ba2)
      self.assertNotEqual(ba1, bb1)
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))

  def test_permuted_assignment(self):
    """Adding into a permuted ``a`` must end up in a buffer different from the original one."""
    a, b, ba1, bb1 = self._realized_pair()
    a = a.permute(1,0)
    a += b
    a.realize()
    ba2 = a.lazydata.realized
    self.assertNotEqual(ba1, ba2)
    self.assertNotEqual(ba1, bb1)
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  def test_post_permuted_assignment(self):
    """Assigning ``a.permute(1,0) + b`` back into ``a`` must still give the correct values."""
    a, b, _, _ = self._realized_pair()
    # this should not work! (presumably: in-place reuse would overwrite values that
    # are still to be read through the permuted view — confirm against the kernel code)
    a.assign(a.permute(1,0) + b)
    a.realize()
    # NOTE: don't test that it's assigned; only the resulting values are checked here.
    # Debugging aid: set GlobalCounters.cache = [] before the assign, then for the
    # (runner, args) pair in GlobalCounters.cache[0] print nm(buf) and id(buf.cl) for
    # each of the three args to see which device buffers the kernel received.
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))
|
||||
|
||||
# TODO: is there a way to sneak in a permute such that it returns the wrong answer?
|
||||
|
||||
# Run the test suite when this file is executed directly.
if __name__ == "__main__": unittest.main()
|
|
@ -31,7 +31,7 @@ class Token:
|
|||
|
||||
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
|
||||
class ASTKernel:
|
||||
def __init__(self, ast:LazyOp):
|
||||
def __init__(self, ast:LazyOp, output_buffer=None):
|
||||
# key for lookup in cache (can change, str might not be right)
|
||||
self.input_ast = ast
|
||||
self.key = str(ast)
|
||||
|
@ -47,8 +47,16 @@ class ASTKernel:
|
|||
self.bufs = dedup(get_buffers(ast))
|
||||
self.ast = ast
|
||||
|
||||
# check if the output buffer is allowed to be used
|
||||
# if it's aliased, don't use it
|
||||
if output_buffer is not None:
|
||||
for a in self.bufs:
|
||||
if a._buf == output_buffer._buf and not a.st.contiguous:
|
||||
output_buffer = None
|
||||
break
|
||||
|
||||
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
|
||||
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
|
||||
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
|
||||
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
|
||||
|
||||
# TODO: should be optional if it's hitting a function cache
|
||||
|
|
|
@ -14,6 +14,7 @@ sys.setrecursionlimit(10000)
|
|||
OPT = getenv("OPT", 2)
|
||||
NOCONV = getenv("NOCONV", 0)
|
||||
IMAGE = getenv("IMAGE", 0)
|
||||
LAZY = getenv("LAZY", 1)
|
||||
|
||||
# late import of Device
|
||||
from tinygrad.device import Device
|
||||
|
@ -96,13 +97,15 @@ class LazyBuffer:
|
|||
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
|
||||
self.shape, self.optype, self.op = self.st.shape, optype, op
|
||||
self.realized : Optional[DeviceBuffer] = None
|
||||
self.output_buffer : Optional[DeviceBuffer] = None
|
||||
self.device, self.dbuffer = device, Device._buffers[device]
|
||||
self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
|
||||
# NOTE: op should be read only after construction of LazyBuffer
|
||||
for x in get_buffers(op):
|
||||
x.children.add(self)
|
||||
if not getenv("LAZY", 1):
|
||||
if not LAZY:
|
||||
self.realize()
|
||||
if DEBUG >= 4: print(f"create {self}")
|
||||
|
||||
def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"
|
||||
|
||||
|
@ -141,7 +144,7 @@ class LazyBuffer:
|
|||
# run the ast if we still have to, and log the op
|
||||
if self.realized is None:
|
||||
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
|
||||
self.realized = self.dbuffer.exec_ast(ast)
|
||||
self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer)
|
||||
log_op(self.realized, ast)
|
||||
|
||||
assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}"
|
||||
|
|
|
@ -360,8 +360,8 @@ class GPUBuffer(ExplicitExecAST):
|
|||
return data
|
||||
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp):
|
||||
k = CLASTKernel(ast)
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None):
|
||||
k = CLASTKernel(ast, output_buffer)
|
||||
if KOPT:
|
||||
from extra.kernel_search import apply_optimization
|
||||
apply_optimization(k, ast, max_interventions=KOPT)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
import math
|
||||
from typing import Tuple, Union, Dict, Any, List, ClassVar
|
||||
from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
|
||||
from tinygrad.helpers import prod
|
||||
from tinygrad.shape import ShapeTracker, ZeroView
|
||||
from tinygrad.ops import LazyOp
|
||||
|
@ -104,8 +104,8 @@ class LLVMBuffer(ExplicitExecAST):
|
|||
|
||||
func_cache : Dict[str, Any] = {}
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp) -> LLVMBuffer:
|
||||
k = ASTKernel(ast)
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[LLVMBuffer]=None) -> LLVMBuffer:
|
||||
k = ASTKernel(ast, output_buffer)
|
||||
|
||||
# cached kernel
|
||||
if k.key in LLVMBuffer.func_cache:
|
||||
|
|
|
@ -68,8 +68,8 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
|||
def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
|
||||
def processing_op(self, op, w, C): return type(self)(self.fxn_for_op[op](self.buf, w.buf, C))
|
||||
@classmethod
|
||||
def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
|
||||
srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
|
||||
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x):
|
||||
srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
|
||||
if ast.op in UnaryOps:
|
||||
ret = srcs[0].unary_op(ast.op)
|
||||
elif ast.op in BinaryOps:
|
||||
|
@ -84,8 +84,13 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
|
|||
ret = srcs[0].processing_op(ast.op, srcs[1], ast.arg)
|
||||
else:
|
||||
raise TypeError("unknown op")
|
||||
return ret
|
||||
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
|
||||
if output_buffer is not None:
|
||||
assert output_buffer.shape == ret.shape
|
||||
output_buffer.buf = ret.buf
|
||||
return output_buffer
|
||||
else:
|
||||
return ret
|
||||
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, preprocess=lambda x: GenericExecAST(GenericShape(x.shape))).buf
|
||||
|
||||
class GlobalCounters:
|
||||
global_ops : ClassVar[int] = 0
|
||||
|
|
|
@ -2,9 +2,10 @@
|
|||
from __future__ import annotations
|
||||
import functools, itertools
|
||||
import numpy as np
|
||||
from tinygrad.helpers import prod, argfix, make_pair
|
||||
from tinygrad.helpers import prod, argfix, make_pair, getenv
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
|
||||
from tinygrad.lazy import Device, LazyBuffer
|
||||
from tinygrad.ops import DEBUG
|
||||
|
||||
# An instantiation of the Function is the Context
|
||||
class Function:
|
||||
|
@ -81,12 +82,14 @@ class Tensor:
|
|||
self.lazydata.realize()
|
||||
return self
|
||||
|
||||
def assign(self, x):
|
||||
if not isinstance(x, Tensor):
|
||||
x = Tensor(x)
|
||||
def assign(self, x) -> Tensor:
|
||||
if not isinstance(x, Tensor): x = Tensor(x)
|
||||
assert self.shape == x.shape
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
|
||||
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
|
||||
self.lazydata = x.lazydata
|
||||
return x
|
||||
return self
|
||||
|
||||
def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
|
||||
def numpy(self) -> np.ndarray: return np.array(self.lazydata.toCPU())
|
||||
|
@ -214,7 +217,7 @@ class Tensor:
|
|||
slc = [[(0, s) for s in self.shape] for _ in catargs]
|
||||
for s,k in zip(slc, shape_cumsum):
|
||||
s[dim] = (-k, shape_cumsum[-1]-k)
|
||||
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
|
||||
return functools.reduce(Tensor.__add__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
|
||||
|
||||
# TODO: make this nicer with syntactic sugar in slice
|
||||
def chunk(self, num, dim):
|
||||
|
|
Loading…
Reference in New Issue