fix flip bug, add new unit tests

George Hotz 2023-03-12 23:55:31 -07:00
parent a4abcf0969
commit c594a0a835
6 changed files with 29 additions and 19 deletions

View File

@ -2,7 +2,7 @@
 Welcome to the tinygrad documentation
 =================
-this file will take you on a whirlwind journey from a Tensor to a Byte
+this file will take you on a whirlwind journey from a Tensor all the way down
 tinygrad has been aggressively refactored in the 2.5 years it's been worked on.
 what you see here is a refined library (with more refining to go still!)
@ -11,23 +11,21 @@ this documentation will help with entry points and understanding the abstraction
 """
 # %%
-# == Boilerplate imports (typing mostly) ==
+# == Boilerplate imports for typing ==
 from __future__ import annotations
 from typing import Optional, Tuple, Union, Any, Dict, Callable, Type, List
 from enum import Enum, auto
 from abc import ABC
-import numpy as np
-import torch
 # %%
 # == Example: Tensor 2+3 ==
-# Let's trace an addition down through the layers of abstraction.
-# We will be using the clang backend
+# let's trace an addition down through the layers of abstraction.
+# we will be using the clang backend
 from tinygrad.lazy import Device
 Device.DEFAULT = "CLANG"
-# first, 2+3 as a Tensor
+# first, 2+3 as a Tensor, the highest level
 from tinygrad.tensor import Tensor
 a = Tensor([2])
 b = Tensor([3])
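
(For reference, the lines this hunk touches assemble into a standalone script. A minimal sketch, assuming a tinygrad checkout of this vintage; `result = a + b` is inferred from the assert shown in the next hunk's context.)

from tinygrad.lazy import Device
Device.DEFAULT = "CLANG"
from tinygrad.tensor import Tensor
a = Tensor([2])
b = Tensor([3])
result = a + b                   # lazy: nothing has been computed yet
assert result.numpy()[0] == 5.   # realization happens on .numpy()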
@ -37,7 +35,7 @@ assert result.numpy()[0] == 5.
 # %%
 # == Tensor (in tinygrad/tensor.py, code 8/10) ==
-# it's worth just reading tinygrad/tensor.py. it's pretty beautiful
+# it's worth reading tinygrad/tensor.py. it's pretty beautiful
 import tinygrad.mlops as mlops
 # this is the good old familiar Tensor class
@ -52,7 +50,7 @@ class Tensor:
   # this is where the data (and other tensor properties) actually live
   lazydata: LazyBuffer
-  # high level ops (hlops) are defined here. example: relu
+  # high level ops (hlops) are defined on this class. example: relu
   def relu(self): return self.maximum(0)
   # log is an mlop, this is the wrapper function in Tensor
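
(The hlop/mlop split the comments describe, illustrated with a toy class. This is not tinygrad's Tensor, just the pattern: the high-level op is plain Python defined in terms of a lower-level primitive.)

import numpy as np

class ToyTensor:
  def __init__(self, data): self.data = np.array(data, dtype=np.float32)
  # stand-in for the mlop layer: a primitive with a real implementation
  def maximum(self, x): return ToyTensor(np.maximum(self.data, x))
  # hlop layer: pure composition, the same shape as Tensor.relu above
  def relu(self): return self.maximum(0)

assert (ToyTensor([-1., 2.]).relu().data == [0., 2.]).all()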
@ -88,8 +86,7 @@ class LazyBuffer:
   realized: Optional[DeviceBuffer]
 # LazyOp (in tinygrad/ops.py, code 4/10)
-# they form an Abstract Syntax Tree for a single GPU kernel
-# LazyOp is an AST node that defines:
+# in a tree they form an Abstract Syntax Tree for a single GPU kernel
 class LazyOp:
   op: Op                                      # the type of the compute
   src: Tuple[Union[LazyOp, LazyBuffer], ...]  # the sources
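
(What "an Abstract Syntax Tree for a single GPU kernel" looks like concretely, as a toy mirror of the two fields above. The real Op and LazyBuffer types are tinygrad's; everything in this sketch is illustrative.)

from typing import Any, NamedTuple, Tuple
from enum import Enum, auto

class BinaryOps(Enum): ADD = auto(); MUL = auto()  # stand-in for tinygrad's op enums

class ToyLazyOp(NamedTuple):
  op: Any               # the type of the compute
  src: Tuple[Any, ...]  # the sources: nested ToyLazyOps, or buffers at the leaves

# the 2+3 example from earlier lowers to roughly one ADD node over two input buffers
ast = ToyLazyOp(BinaryOps.ADD, ("<LazyBuffer a>", "<LazyBuffer b>"))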
@ -164,6 +161,8 @@ class DeviceBuffer(ABC):
 # InterpretedBuffers are a lot simpler than CompiledBuffers
 # they are used to implement the CPU(numpy) and TORCH(torch) backends
 # it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
+import numpy as np
+import torch
 class InterpretedBuffer(DeviceBuffer):
   # this is where the data actually lives
   # finally some classes you recognize!
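
(The point of an interpreted backend: no code generation, just a dispatch table from ops to existing library calls. A minimal sketch of that pattern with numpy; the table and names are illustrative, not tinygrad's actual ones.)

import numpy as np

fxn_for_op = {"ADD": np.add, "SUB": np.subtract, "MUL": np.multiply}

def exec_op(op, *srcs): return fxn_for_op[op](*srcs)

assert exec_op("ADD", np.array([2.]), np.array([3.]))[0] == 5.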

View File

@ -218,6 +218,10 @@ class TestOps(unittest.TestCase):
   def test_scalar_rsub(self):
     helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
+  def test_flip_eye_crash(self):
+    helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
+                       lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
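
(A quick numpy sanity check of the math the new test encodes: flipping the identity along dim 0 gives the anti-identity, and I @ flip(I) is that same matrix, so the torch reference side is well defined.)

import numpy as np

eye = np.eye(10, dtype=np.float32)
ref = eye @ np.flip(eye, axis=0)            # the torch side of the test, in numpy
assert (ref == np.flip(eye, axis=0)).all()  # I @ flip(I) is just flip(I)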

View File

@ -1,6 +1,8 @@
 import unittest
+import numpy as np
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import dtypes

 def multidevice_test(fxn):
   def ret(self):

View File

@ -23,7 +23,7 @@ class TestSymbolic(unittest.TestCase):
     self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)")
     self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
     self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")

   def test_div_becomes_num(self):
     assert isinstance(Variable("a", 2, 3)//2, NumNode)

@ -101,7 +101,7 @@ class TestSymbolic(unittest.TestCase):
   def test_sum_div_no_factor(self):
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")

   def test_mod_factor(self):
     # NOTE: even though the mod max is 50, it can't know this without knowing about the mul
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")

@ -126,7 +126,7 @@ class TestSymbolic(unittest.TestCase):
   def test_mod_mul_sum(self):
     self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")

   def test_sum_0(self):
     self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")

@ -164,7 +164,7 @@ class TestSymbolic(unittest.TestCase):
   def test_div_factor(self):
     self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")

   def test_mul_div(self):
     self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")

@ -177,6 +177,9 @@ class TestSymbolic(unittest.TestCase):
   def test_div_remove(self):
     self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
+  def test_div_numerator_negative(self):
+    self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")

 class TestSymbolicNumeric(unittest.TestCase):
   def helper_test_numeric(self, f):
     # TODO: why are the negative tests broken? (even if we did support negative variables)
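
(The expected string in test_div_numerator_negative can be checked by brute force: the rewrite must agree with Python's floor division for every idx in the variable's range. The 99 is -offset*b and the -9 is offset, with offset = -90//11; see the Node.__floordiv__ hunk below.)

for idx in range(10):
  assert (idx*-10)//11 == ((idx*-10 + 99)//11) - 9
assert min((i*-10)//11 for i in range(10)) == -9  # the test's expected min
assert max((i*-10)//11 for i in range(10)) == 0   # the test's expected max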

View File

@ -268,7 +268,7 @@ class GPUCodegen(ASTKernel):
     self.prekernel: Set[str] = set()
     self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
-    if self.lang.half_prekernel: self.prekernel.add(self.lang.half_prekernel+"\n")
+    if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs): self.prekernel.add(self.lang.half_prekernel+"\n")
     if len(self.lang.gid) == 0:
       self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]

View File

@ -34,14 +34,14 @@ class Node:
   # *** complex ops ***
-  def __floordiv__(self, b:int):
+  def __floordiv__(self, b:int, factoring_allowed=True):
     assert b != 0
     if b < 0: return (self//-b)*-1
     if b == 1: return self
     if isinstance(self, DivNode): return self.a//(self.b*b)  # two divs is one div
     if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
     if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
-    if isinstance(self, SumNode):
+    if isinstance(self, SumNode) and factoring_allowed:
       factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
       nofactor = []
       # ugh, i doubt this is universally right
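
(The "two divs is one div" rewrite above relies on the identity (x//a)//b == x//(a*b), which holds for non-negative numerators and positive divisors. A brute-force check:)

for x in range(1000):
  for a in (2, 3, 7):
    for b in (2, 5):
      assert (x//a)//b == x//(a*b)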
@ -65,9 +65,11 @@ class Node:
       for m in muls:
         if m > 1 and b%m == 0:
           return (self//m)//(b//m)
+    # the numerator of div is not allowed to be negative
     if self.min < 0:
       offset = self.min//b
-      return (self+offset*b)//b - offset
+      # factor out an "offset" to make the numerator positive. don't allow factoring again
+      return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
     return create_opnode(DivNode, self, b)
   def __mod__(self, b:int):
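
(The fix in words: when the numerator can go negative, take offset = self.min//b, shift the numerator up by -offset*b so it is non-negative, divide with factoring disabled, presumably so the SumNode branch above can't re-split the shifted sum, then add offset back. Note offset is negative here, so the old `self+offset*b` pushed the numerator further negative instead. A self-contained check of the invariant, using the bounds from the new unit test:)

def floordiv_with_offset(x, b, x_min):
  offset = x_min // b                   # the most negative the quotient can get
  assert x + -offset*b >= 0             # shifted numerator is non-negative
  return (x + -offset*b) // b + offset  # equals x // b

for x in range(-90, 1):
  assert floordiv_with_offset(x, 11, -90) == x // 11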