diff --git a/docs/abstractions.py b/docs/abstractions.py
index 117436cf..000f5553 100644
--- a/docs/abstractions.py
+++ b/docs/abstractions.py
@@ -2,7 +2,7 @@
 Welcome to the tinygrad documentation
 =================
 
-this file will take you on a whirlwind journey from a Tensor to a Byte
+this file will take you on a whirlwind journey from a Tensor all the way down
 tinygrad has been aggressively refactored in the 2.5 years it's been worked on.
 what you see here is a refined library (with more refining to go still!)
 
@@ -11,23 +11,21 @@ this documentation will help with entry points and understanding the abstraction
 """
 
 # %%
-# == Boilerplate imports (typing mostly) ==
+# == Boilerplate imports for typing ==
 from __future__ import annotations
 from typing import Optional, Tuple, Union, Any, Dict, Callable, Type, List
 from enum import Enum, auto
 from abc import ABC
-import numpy as np
-import torch
 
 # %%
 # == Example: Tensor 2+3 ==
-# Let's trace an addition down through the layers of abstraction.
-# We will be using the clang backend
+# let's trace an addition down through the layers of abstraction.
+# we will be using the clang backend
 from tinygrad.lazy import Device
 Device.DEFAULT = "CLANG"
 
-# first, 2+3 as a Tensor
+# first, 2+3 as a Tensor, the highest level
 from tinygrad.tensor import Tensor
 a = Tensor([2])
 b = Tensor([3])
@@ -37,7 +35,7 @@ assert result.numpy()[0] == 5.
 
 # %%
 # == Tensor (in tinygrad/tensor.py, code 8/10) ==
-# it's worth just reading tinygrad/tensor.py. it's pretty beautiful
+# it's worth reading tinygrad/tensor.py. it's pretty beautiful
 import tinygrad.mlops as mlops
 
 # this is the good old familiar Tensor class
@@ -52,7 +50,7 @@ class Tensor:
   # this is where the data (and other tensor properties) actually live
   lazydata: LazyBuffer
 
-  # high level ops (hlops) are defined here. example: relu
+  # high level ops (hlops) are defined on this class. example: relu
   def relu(self): return self.maximum(0)
 
   # log is an mlop, this is the wrapper function in Tensor
@@ -88,8 +86,7 @@ class LazyBuffer:
   realized: Optional[DeviceBuffer]
 
 # LazyOp (in tinygrad/ops.py, code 4/10)
-# they form an Abstract Syntax Tree for a single GPU kernel
-# LazyOp is an AST node that defines:
+# in a tree they form an Abstract Syntax Tree for a single GPU kernel
 class LazyOp:
   op: Op                                       # the type of the compute
   src: Tuple[Union[LazyOp, LazyBuffer], ...]   # the sources
@@ -164,6 +161,8 @@ class DeviceBuffer(ABC):
 # InterpretedBuffers are a lot simpler than CompiledBuffers
 # they are used to implement the CPU(numpy) and TORCH(torch) backends
 # it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
+import numpy as np
+import torch
 class InterpretedBuffer(DeviceBuffer):
   # this is where the data actually lives
   # finally some classes you recognize!
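For orientation, here is a minimal sketch of the 2+3 example the docs/abstractions.py walkthrough above traces, using only the API shown in that file (CLANG backend, Tensor addition, .numpy() to realize); it is illustrative only and not part of this diff:

# minimal sketch, assuming only the API shown in docs/abstractions.py above
from tinygrad.lazy import Device
Device.DEFAULT = "CLANG"                 # use the clang backend, as in the walkthrough

from tinygrad.tensor import Tensor
a, b = Tensor([2]), Tensor([3])
result = a + b                           # lazy: this only builds a LazyBuffer graph
assert result.numpy()[0] == 5.           # realization compiles and runs the clang kernel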
diff --git a/test/test_ops.py b/test/test_ops.py
index 04715dca..821ef8f8 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -218,6 +218,10 @@ class TestOps(unittest.TestCase):
   def test_scalar_rsub(self):
     helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
 
+  def test_flip_eye_crash(self):
+    helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
+                       lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
+
   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
diff --git a/test/unit/test_example.py b/test/unit/test_example.py
index c64c9f6a..c64ebf93 100644
--- a/test/unit/test_example.py
+++ b/test/unit/test_example.py
@@ -1,6 +1,8 @@
 import unittest
+import numpy as np
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import dtypes
 
 def multidevice_test(fxn):
   def ret(self):
diff --git a/test/unit/test_symbolic.py b/test/unit/test_symbolic.py
index 41ec6bf4..a0bc06a4 100644
--- a/test/unit/test_symbolic.py
+++ b/test/unit/test_symbolic.py
@@ -23,7 +23,7 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)")
     self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
     self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
-  
+
   def test_div_becomes_num(self):
     assert isinstance(Variable("a", 2, 3)//2, NumNode)
@@ -101,7 +101,7 @@ class TestSymbolic(unittest.TestCase):
   def test_sum_div_no_factor(self):
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
-  
+
   def test_mod_factor(self):
     # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
@@ -126,7 +126,7 @@ class TestSymbolic(unittest.TestCase):
   def test_mod_mul_sum(self):
     self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
-  
+
   def test_sum_0(self):
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")
@@ -164,7 +164,7 @@ class TestSymbolic(unittest.TestCase):
   def test_div_factor(self):
     self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
-  
+
   def test_mul_div(self):
     self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@@ -177,6 +177,9 @@ class TestSymbolic(unittest.TestCase):
   def test_div_remove(self):
     self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
 
+  def test_div_numerator_negative(self):
+    self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
+
 class TestSymbolicNumeric(unittest.TestCase):
   def helper_test_numeric(self, f):
     # TODO: why are the negative tests broken? (even if we did support negative variables)
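The test_div_numerator_negative case added to test/unit/test_symbolic.py above expects (idx*-10)//11 to render as ((((idx*-10)+99)//11)+-9). A quick sanity check of that identity with plain Python ints, illustrative only and not part of the test suite:

# the rewrite adds 99 (a multiple of 11) to make the numerator non-negative, then subtracts 9
for idx in range(10):
  n = idx * -10                        # ranges over -90..0, like Variable("idx", 0, 9)*-10
  assert n // 11 == (n + 99) // 11 - 9
  assert -9 <= n // 11 <= 0            # matches the expected (min, max) of (-9, 0)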
diff --git a/tinygrad/codegen/gpu.py b/tinygrad/codegen/gpu.py
index e7cbf6a8..3e2620c5 100644
--- a/tinygrad/codegen/gpu.py
+++ b/tinygrad/codegen/gpu.py
@@ -268,7 +268,7 @@ class GPUCodegen(ASTKernel):
     self.prekernel: Set[str] = set()
     self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
 
-    if self.lang.half_prekernel: self.prekernel.add(self.lang.half_prekernel+"\n")
+    if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs): self.prekernel.add(self.lang.half_prekernel+"\n")
 
     if len(self.lang.gid) == 0:
       self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]
diff --git a/tinygrad/shape/symbolic.py b/tinygrad/shape/symbolic.py
index 2021fc11..e821cb65 100644
--- a/tinygrad/shape/symbolic.py
+++ b/tinygrad/shape/symbolic.py
@@ -34,14 +34,14 @@ class Node:
 
   # *** complex ops ***
 
-  def __floordiv__(self, b:int):
+  def __floordiv__(self, b:int, factoring_allowed=True):
     assert b != 0
     if b < 0: return (self//-b)*-1
     if b == 1: return self
     if isinstance(self, DivNode): return self.a//(self.b*b)    # two divs is one div
     if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
     if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
-    if isinstance(self, SumNode):
+    if isinstance(self, SumNode) and factoring_allowed:
       factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
       nofactor = []
       # ugh, i doubt this is universally right
@@ -65,9 +65,11 @@ class Node:
     for m in muls:
       if m > 1 and b%m == 0:
         return (self//m)//(b//m)
+    # the numerator of div is not allowed to be negative
     if self.min < 0:
       offset = self.min//b
-      return (self+offset*b)//b - offset
+      # factor out an "offset" to make the numerator positive. don't allow factoring again
+      return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
     return create_opnode(DivNode, self, b)
 
   def __mod__(self, b:int):
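The new branch in Node.__floordiv__ handles a possibly negative numerator by factoring out offset = self.min//b, dividing the shifted (now non-negative) expression with factoring disabled, and adding the offset back. A rough sketch of the arithmetic over plain ints, with a hypothetical helper name standing in for the Node machinery (SumNode, create_opnode, etc. are omitted):

def floordiv_with_offset(x: int, x_min: int, b: int) -> int:
  # mirrors the new branch: shift by -offset*b (a multiple of b) so the numerator is non-negative
  assert b > 0 and x >= x_min
  if x_min < 0:
    offset = x_min // b                      # e.g. (-90)//11 == -9
    return (x + -offset * b) // b + offset   # x + -offset*b >= 0, so no negative numerator
  return x // b

# agrees with plain floor division over the whole range, e.g. idx*-10 for idx in [0, 9]
assert all(floordiv_with_offset(x, -90, 11) == x // 11 for x in range(-90, 1))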