diff --git a/test/external_test_assign.py b/test/external_test_assign.py deleted file mode 100644 index f5937a5d..00000000 --- a/test/external_test_assign.py +++ /dev/null @@ -1,15 +0,0 @@ -from tinygrad.tensor import Tensor -from tinygrad.ops import GlobalCounters -from tinygrad.graph import nm - -if __name__ == "__main__": - GlobalCounters.cache = [] - a = Tensor.ones(4,4) - b = Tensor.ones(4,4) - a += b - print(a.numpy()) - runner, args = GlobalCounters.cache[0] - b0, b1, b2 = args - print(nm(b0), b0) - print(nm(b1), b1) - print(nm(b2), b2) diff --git a/test/test_assign.py b/test/test_assign.py new file mode 100644 index 00000000..e1b1ec5f --- /dev/null +++ b/test/test_assign.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python +import unittest +import numpy as np +from tinygrad.tensor import Tensor +from tinygrad.lazy import LAZY +from tinygrad.ops import GlobalCounters +from tinygrad.graph import nm + +N = 200 # has to be bigger than the cache to fail + +class TestAssign(unittest.TestCase): + def test_simple_assignment(self): + a = Tensor.arange(N*N).reshape(N,N) + b = Tensor.arange(N*N).reshape(N,N) + a.realize() + b.realize() + ba1 = a.lazydata.realized + bb1 = b.lazydata.realized + a += b + a.realize() + ba2 = a.lazydata.realized + if LAZY: assert ba1 == ba2 and ba1 != bb1 + np.testing.assert_allclose(a.numpy(), (np.arange(N*N)*2).reshape((N,N))) + + def test_permuted_assignment(self): + a = Tensor.arange(N*N).reshape(N,N) + b = Tensor.arange(N*N).reshape(N,N) + a.realize() + b.realize() + ba1 = a.lazydata.realized + bb1 = b.lazydata.realized + a = a.permute(1,0) + a += b + a.realize() + ba2 = a.lazydata.realized + assert ba1 != ba2 and ba1 != bb1 + np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) + + def test_post_permuted_assignment(self): + a = Tensor.arange(N*N).reshape(N,N) + b = Tensor.arange(N*N).reshape(N,N) + a.realize() + b.realize() + #GlobalCounters.cache = [] + ba1 = a.lazydata.realized + bb1 = b.lazydata.realized + a.assign(a.permute(1,0) + b) # this should not work! + a.realize() + ba2 = a.lazydata.realized + # NOTE: don't test that it's assigned + #assert ba1 == ba2 and ba1 != bb1 + + """ + if len(GlobalCounters.cache): + runner, args = GlobalCounters.cache[0] + b0, b1, b2 = args + print(nm(b0), id(b0.cl)) + print(nm(b1), id(b1.cl)) + print(nm(b2), id(b2.cl)) + """ + + np.testing.assert_allclose(a.numpy(), np.arange(N*N).reshape((N,N)) + np.arange(N*N).reshape((N,N)).transpose(1,0)) + + # TODO: is there a way to sneak in a permute such that it returns the wrong answer? + +if __name__ == "__main__": + unittest.main() diff --git a/tinygrad/ast.py b/tinygrad/ast.py index fa27a2d4..6658a4a3 100644 --- a/tinygrad/ast.py +++ b/tinygrad/ast.py @@ -31,7 +31,7 @@ class Token: # ast kernel can contain one ReduceOp with arbitrary Binary/Unary ops class ASTKernel: - def __init__(self, ast:LazyOp): + def __init__(self, ast:LazyOp, output_buffer=None): # key for lookup in cache (can change, str might not be right) self.input_ast = ast self.key = str(ast) @@ -47,8 +47,16 @@ class ASTKernel: self.bufs = dedup(get_buffers(ast)) self.ast = ast + # check if the output buffer is allowed to be used + # if it's aliased, don't use it + if output_buffer is not None: + for a in self.bufs: + if a._buf == output_buffer._buf and not a.st.contiguous: + output_buffer = None + break + # create the buffer we are returning (as the same type as the input buffers) and add it as the first buffer - self.ret = type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True) + self.ret = output_buffer if output_buffer else type(self.bufs[0])(output_shape if output_shape else self.info.shape, force_create=True) self.bufs = ([type(self.ret)(self.info.shape, hostbuf=self.ret)] if output_shape else [self.ret]) + self.bufs # TODO: should be optional if it's hitting a function cache diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py index d5da2d9d..d81b3b2b 100644 --- a/tinygrad/lazy.py +++ b/tinygrad/lazy.py @@ -14,6 +14,7 @@ sys.setrecursionlimit(10000) OPT = getenv("OPT", 2) NOCONV = getenv("NOCONV", 0) IMAGE = getenv("IMAGE", 0) +LAZY = getenv("LAZY", 1) # late import of Device from tinygrad.device import Device @@ -96,13 +97,15 @@ class LazyBuffer: self.st = shape if isinstance(shape, ShapeTracker) else ShapeTracker(tuple(shape)) self.shape, self.optype, self.op = self.st.shape, optype, op self.realized : Optional[DeviceBuffer] = None + self.output_buffer : Optional[DeviceBuffer] = None self.device, self.dbuffer = device, Device._buffers[device] self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet() # NOTE: op should be read only after construction of LazyBuffer for x in get_buffers(op): x.children.add(self) - if not getenv("LAZY", 1): + if not LAZY: self.realize() + if DEBUG >= 4: print(f"create {self}") def __repr__(self): return f"" @@ -141,7 +144,7 @@ class LazyBuffer: # run the ast if we still have to, and log the op if self.realized is None: ast = map_buffers({x:x.realize(self.device) for x in get_buffers(ast)}, ast) - self.realized = self.dbuffer.exec_ast(ast) + self.realized = self.dbuffer.exec_ast(ast, output_buffer=self.output_buffer) log_op(self.realized, ast) assert self.realized.shape == self.shape, f"shape mismatch on realize {self.realized.shape} vs {self.shape}" diff --git a/tinygrad/llops/ops_gpu.py b/tinygrad/llops/ops_gpu.py index 87f2a4c0..bce22f25 100644 --- a/tinygrad/llops/ops_gpu.py +++ b/tinygrad/llops/ops_gpu.py @@ -360,8 +360,8 @@ class GPUBuffer(ExplicitExecAST): return data @classmethod - def exec_ast(cls, ast:LazyOp): - k = CLASTKernel(ast) + def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GPUBuffer]=None): + k = CLASTKernel(ast, output_buffer) if KOPT: from extra.kernel_search import apply_optimization apply_optimization(k, ast, max_interventions=KOPT) diff --git a/tinygrad/llops/ops_llvm.py b/tinygrad/llops/ops_llvm.py index ced1ba70..da50a3d2 100644 --- a/tinygrad/llops/ops_llvm.py +++ b/tinygrad/llops/ops_llvm.py @@ -1,6 +1,6 @@ from __future__ import annotations import math -from typing import Tuple, Union, Dict, Any, List, ClassVar +from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional from tinygrad.helpers import prod from tinygrad.shape import ShapeTracker, ZeroView from tinygrad.ops import LazyOp @@ -104,8 +104,8 @@ class LLVMBuffer(ExplicitExecAST): func_cache : Dict[str, Any] = {} @classmethod - def exec_ast(cls, ast:LazyOp) -> LLVMBuffer: - k = ASTKernel(ast) + def exec_ast(cls, ast:LazyOp, output_buffer:Optional[LLVMBuffer]=None) -> LLVMBuffer: + k = ASTKernel(ast, output_buffer) # cached kernel if k.key in LLVMBuffer.func_cache: diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 1d03859c..32e8fdf7 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -68,8 +68,8 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method def movement_op(self, op, arg=None): return type(self)(self.fxn_for_op[op](self.buf, arg)) if op in self.fxn_for_op else type(self)(getattr(self.buf, op.name.lower())(arg)) def processing_op(self, op, w, C): return type(self)(self.fxn_for_op[op](self.buf, w.buf, C)) @classmethod - def exec_ast(cls, ast:LazyOp, preprocess=lambda x: x): - srcs = [cls.exec_ast(x, preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src] + def exec_ast(cls, ast:LazyOp, output_buffer:Optional[GenericExecAST]=None, preprocess=lambda x: x): + srcs = [cls.exec_ast(x, preprocess=preprocess) if isinstance(x, LazyOp) else preprocess(x) for x in ast.src] if ast.op in UnaryOps: ret = srcs[0].unary_op(ast.op) elif ast.op in BinaryOps: @@ -84,8 +84,13 @@ class GenericExecAST(DeviceBuffer): # pylint: disable=abstract-method ret = srcs[0].processing_op(ast.op, srcs[1], ast.arg) else: raise TypeError("unknown op") - return ret -def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, lambda x: GenericExecAST(GenericShape(x.shape))).buf + if output_buffer is not None: + assert output_buffer.shape == ret.shape + output_buffer.buf = ret.buf + return output_buffer + else: + return ret +def get_lazyop_info(ast:LazyOp): return GenericExecAST.exec_ast(ast, preprocess=lambda x: GenericExecAST(GenericShape(x.shape))).buf class GlobalCounters: global_ops : ClassVar[int] = 0 diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d1924ca7..c0713925 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -2,9 +2,10 @@ from __future__ import annotations import functools, itertools import numpy as np -from tinygrad.helpers import prod, argfix, make_pair +from tinygrad.helpers import prod, argfix, make_pair, getenv from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union from tinygrad.lazy import Device, LazyBuffer +from tinygrad.ops import DEBUG # An instantiation of the Function is the Context class Function: @@ -81,12 +82,14 @@ class Tensor: self.lazydata.realize() return self - def assign(self, x): - if not isinstance(x, Tensor): - x = Tensor(x) + def assign(self, x) -> Tensor: + if not isinstance(x, Tensor): x = Tensor(x) assert self.shape == x.shape + assert not x.requires_grad # self requires_grad is okay? + if DEBUG >= 4: print(f"assign {self.lazydata} <- {x.lazydata}") + if self.lazydata.realized is not None and not getenv("DISALLOW_ASSIGN"): x.lazydata.output_buffer = self.lazydata.realized self.lazydata = x.lazydata - return x + return self def detach(self): return Tensor(self.lazydata, device=self.device, requires_grad=False) def numpy(self) -> np.ndarray: return np.array(self.lazydata.toCPU()) @@ -214,7 +217,7 @@ class Tensor: slc = [[(0, s) for s in self.shape] for _ in catargs] for s,k in zip(slc, shape_cumsum): s[dim] = (-k, shape_cumsum[-1]-k) - return functools.reduce(Tensor.__iadd__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)]) + return functools.reduce(Tensor.__add__, [arg.slice(arg=s) for arg,s in zip(catargs, slc)]) # TODO: make this nicer with syntactic sugar in slice def chunk(self, num, dim):