assign buffer reuse (#547)

* assign buffer reuse works

* fix assign for torch and cpu

* allow assign from numpy

* fix llvm output_buffer

* add some assign tests

* fix assignment test

* test should fail without lazy

* env var to disable assign
This commit is contained in:
George Hotz 2023-02-09 11:53:02 -06:00 committed by GitHub
parent 473bbd3e35
commit 5de850f6d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 34 deletions

View File

@ -1,15 +0,0 @@
from tinygrad.tensor import Tensor
from tinygrad.ops import GlobalCounters
from tinygrad.graph import nm
if __name__ == "__main__":
  # Scratch script: run one in-place add and dump the first recorded cache entry.
  GlobalCounters.cache = []
  lhs = Tensor.ones(4,4)
  rhs = Tensor.ones(4,4)
  lhs += rhs
  print(lhs.numpy())
  # the first cache entry must unpack into a runner plus exactly three args
  _runner, (b0, b1, b2) = GlobalCounters.cache[0]
  for buf in (b0, b1, b2):
    print(nm(buf), buf)

67
test/test_assign.py Normal file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env python
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.lazy import LAZY
from tinygrad.ops import GlobalCounters
from tinygrad.graph import nm
N = 200 # has to be bigger than the cache to fail
class TestAssign(unittest.TestCase):
  """Tests for which realized buffer backs a tensor after an assignment.

  Each test realizes two (N, N) tensors, records their realized buffers,
  performs an assignment into ``a`` and then checks the numerical result and,
  where meaningful, whether ``a`` still uses its original buffer.
  """

  @staticmethod
  def _realized_pair():
    """Return two realized (N, N) tensors, each filled with arange(N*N)."""
    a = Tensor.arange(N*N).reshape(N,N)
    b = Tensor.arange(N*N).reshape(N,N)
    a.realize()
    b.realize()
    return a, b

  def test_simple_assignment(self):
    """After ``a += b`` the result is 2*arange; under LAZY, a's buffer is unchanged."""
    a, b = self._realized_pair()
    ba1 = a.lazydata.realized
    bb1 = b.lazydata.realized
    a += b
    a.realize()
    ba2 = a.lazydata.realized
    # buffer identity is only checked when lazy evaluation is enabled
    if LAZY:
      # unittest assertions (same ==/!= comparisons as before) survive `python -O`
      self.assertEqual(ba1, ba2)
      self.assertNotEqual(ba1, bb1)
    np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N)))

  def test_permuted_assignment(self):
    """Adding into a permuted ``a`` must end up in a buffer different from the original."""
    a, b = self._realized_pair()
    ba1 = a.lazydata.realized
    bb1 = b.lazydata.realized
    a = a.permute(1,0)
    a += b
    a.realize()
    ba2 = a.lazydata.realized
    self.assertNotEqual(ba1, ba2)
    self.assertNotEqual(ba1, bb1)
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  def test_post_permuted_assignment(self):
    """Assigning ``a.permute(1,0) + b`` back into ``a`` must still give the right values."""
    a, b = self._realized_pair()
    # Original note: "this should not work!"
    # NOTE(review): presumably this means writing in place is unsafe here because the
    # source reads a permuted view of the destination; confirm against the kernel code.
    a.assign(a.permute(1,0) + b)
    a.realize()
    # NOTE: don't test that it's assigned; only the numerical result is checked
    np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0))

  # TODO: is there a way to sneak in a permute such that it returns the wrong answer?
# Run the test cases in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()

View File

@ -31,7 +31,7 @@ class Token:
# ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops
class ASTKernel:
def __init__(self, ast:LazyOp):
def __init__(self, ast:LazyOp, output_buffer=None):
# key for lookup in cache (can change, str might not be right)
self.input_ast = ast
self.key = str(ast)
@ -47,8 +47,16 @@ class ASTKernel:
self.bufs = dedup(get_buffers(ast))
self.ast = ast
# check if the output buffer is allowed to be used
# if it's aliased, don't use it
if output_buffer is not None:
for a in self.bufs:
if a._buf == output_buffer._buf and not a.st.contiguous:
output_buffer = None
break
# create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer
self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True)
self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs
# TODO: should be optional if it's hitting a function cache

View File

@ -14,6 +14,7 @@ sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
NOCONV = getenv("NOCONV", 0)
IMAGE = getenv("IMAGE", 0)
LAZY = getenv("LAZY", 1)
# late import of Device
from tinygrad.device import Device
@ -96,13 +97,15 @@ class LazyBuffer:
self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape))
self.shape, self.optype, self.op = self.st.shape, optype, op
self.realized : Optional[DeviceBuffer] = None
self.output_buffer : Optional[DeviceBuffer] = None
self.device, self.dbuffer = device, Device._buffers[device]
self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
# NOTE: op should be read only after construction of LazyBuffer
for x in get_buffers(op):
x.children.add(self)
if not getenv("LAZY", 1):
if not LAZY:
self.realize()
if DEBUG >= 4: print(f"create {self}")
def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"
@ -141,7 +144,7 @@ class LazyBuffer:
# run the ast if we still have to, and log the op
if self.realized is None:
ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast)
self.realized = self.dbuffer.exec_ast(ast)
self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer)
log_op(self.realized, ast)
assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}"

View File

@ -360,8 +360,8 @@ class GPUBuffer(ExplicitExecAST):
return data
@classmethod
def exec_ast(cls, ast:LazyOp):
k = CLASTKernel(ast)
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None):
k = CLASTKernel(ast, output_buffer)
if KOPT:
from extra.kernel_search import apply_optimization
apply_optimization(k, ast, max_interventions=KOPT)

View File

@ -1,6 +1,6 @@
from __future__ import annotations
import math
from typing import Tuple, Union, Dict, Any, List, ClassVar
from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
from tinygrad.helpers import prod
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.ops import LazyOp
@ -104,8 +104,8 @@ class LLVMBuffer(ExplicitExecAST):
func_cache : Dict[str, Any] = {}
@classmethod
def exec_ast(cls, ast:LazyOp) -> LLVMBuffer:
k = ASTKernel(ast)
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[LLVMBuffer]=None) -> LLVMBuffer:
k = ASTKernel(ast, output_buffer)
# cached kernel
if k.key in LLVMBuffer.func_cache:

View File

@ -68,8 +68,8 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg))
def processing_op(self, op, w, C): return type(self)(self.fxn_for_op[op](self.buf, w.buf, C))
@classmethod
def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x):
srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src]
if ast.op in UnaryOps:
ret = srcs[0].unary_op(ast.op)
elif ast.op in BinaryOps:
@ -84,8 +84,13 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method
ret = srcs[0].processing_op(ast.op, srcs[1], ast.arg)
else:
raise TypeError("unknown op")
return ret
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf
if output_buffer is not None:
assert output_buffer.shape == ret.shape
output_buffer.buf = ret.buf
return output_buffer
else:
return ret
def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, preprocess=lambda x: GenericExecAST(GenericShape(x.shape))).buf
class GlobalCounters:
global_ops : ClassVar[int] = 0

View File

@ -2,9 +2,10 @@
from __future__ import annotations
import functools, itertools
import numpy as np
from tinygrad.helpers import prod, argfix, make_pair
from tinygrad.helpers import prod, argfix, make_pair, getenv
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union
from tinygrad.lazy import Device, LazyBuffer
from tinygrad.ops import DEBUG
# An instantiation of the Function is the Context
class Function:
@ -81,12 +82,14 @@ class Tensor:
self.lazydata.realize()
return self
def assign(self, x):
if not isinstance(x, Tensor):
x = Tensor(x)
def assign(self, x) -> Tensor:
if not isinstance(x, Tensor): x = Tensor(x)
assert self.shape == x.shape
assert not x.requires_grad # self requires_grad is okay?
if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}")
if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized
self.lazydata = x.lazydata
return x
return self
def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False)
def numpy(self) -> np.ndarray: return np.array(self.lazydata.toCPU())
@ -214,7 +217,7 @@ class Tensor:
slc = [[(0, s) for s in self.shape] for _ in catargs]
for s,k in zip(slc, shape_cumsum):
s[dim] = (-k, shape_cumsum[-1]-k)
return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
return functools.reduce(Tensor.__add__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)])
# TODO: make this nicer with syntactic sugar in slice
def chunk(self, num, dim):