fix flip bug, add new unit tests

George Hotz 2023-03-12 23:55:31 -07:00
parent a4abcf0969
commit c594a0a835
6 changed files with 29 additions and 19 deletions

View File

@@ -2,7 +2,7 @@
Welcome to the tinygrad documentation
=================
-this file will take you on a whirlwind journey from a Tensor to a Byte
+this file will take you on a whirlwind journey from a Tensor all the way down
tinygrad has been aggressively refactored in the 2.5 years it's been worked on.
what you see here is a refined library (with more refining to go still!)
@@ -11,23 +11,21 @@ this documentation will help with entry points and understanding the abstraction
"""
# %%
-# == Boilerplate imports (typing mostly) ==
+# == Boilerplate imports for typing ==
from __future__ import annotations
from typing import Optional, Tuple, Union, Any, Dict, Callable, Type, List
from enum import Enum, auto
from abc import ABC
-import numpy as np
-import torch
# %%
# == Example: Tensor 2+3 ==
-# Let's trace an addition down through the layers of abstraction.
-# We will be using the clang backend
+# let's trace an addition down through the layers of abstraction.
+# we will be using the clang backend
from tinygrad.lazy import Device
Device.DEFAULT = "CLANG"
-# first, 2+3 as a Tensor
+# first, 2+3 as a Tensor, the highest level
from tinygrad.tensor import Tensor
a = Tensor([2])
b = Tensor([3])
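
For reference, a complete runnable form of this first example (assembled from the lines above plus the assert shown as context in the next hunk; the comments are added for illustration):

    from tinygrad.lazy import Device
    from tinygrad.tensor import Tensor
    Device.DEFAULT = "CLANG"
    a = Tensor([2])
    b = Tensor([3])
    result = a + b                  # nothing runs yet: this only builds the lazy graph
    assert result.numpy()[0] == 5.  # .numpy() forces realization on the CLANG backend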
@@ -37,7 +35,7 @@ assert result.numpy()[0] == 5.
# %%
# == Tensor (in tinygrad/tensor.py, code 8/10) ==
-# it's worth just reading tinygrad/tensor.py. it's pretty beautiful
+# it's worth reading tinygrad/tensor.py. it's pretty beautiful
import tinygrad.mlops as mlops
# this is the good old familiar Tensor class
@@ -52,7 +50,7 @@ class Tensor:
# this is where the data (and other tensor properties) actually live
lazydata: LazyBuffer
-# high level ops (hlops) are defined here. example: relu
+# high level ops (hlops) are defined on this class. example: relu
def relu(self): return self.maximum(0)
# log is an mlop, this is the wrapper function in Tensor
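
A sketch of what an mlop looks like, using Log as the example named above (the Function/LazyBuffer method names follow tinygrad's mlops.py of this era, but treat the exact signatures as assumptions):

    # an mlop is an autograd Function: forward builds llops, backward is the derivative
    class Log(Function):
      def forward(self, x:LazyBuffer) -> LazyBuffer:
        self.x = x
        return x.unary_op(UnaryOps.LOG)
      def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
        # d/dx log(x) = 1/x, so grad_in = grad_out / x
        return grad_output.binary_op(BinaryOps.DIV, self.x)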
@@ -88,8 +86,7 @@ class LazyBuffer:
realized: Optional[DeviceBuffer]
# LazyOp (in tinygrad/ops.py, code 4/10)
-# they form an Abstract Syntax Tree for a single GPU kernel
-# LazyOp is an AST node that defines:
+# in a tree they form an Abstract Syntax Tree for a single GPU kernel
class LazyOp:
op: Op # the type of the compute
src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
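
So for the 2+3 example, the AST handed to the backend is conceptually a single node (a hypothetical repr built from the two fields above; the real class also carries an optional arg):

    # one ADD node whose sources are the two input LazyBuffers
    ast = LazyOp(op=BinaryOps.ADD, src=(a.lazydata, b.lazydata))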
@@ -164,6 +161,8 @@ class DeviceBuffer(ABC):
# InterpretedBuffers are a lot simpler than CompiledBuffers
# they are used to implement the CPU(numpy) and TORCH(torch) backends
# it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
+import numpy as np
+import torch
class InterpretedBuffer(DeviceBuffer):
# this is where the data actually lives
# finally some classes you recognize!
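
The heart of an interpreted backend is a dispatch table from Ops to host functions, applied eagerly to the underlying arrays. A minimal sketch of the idea (illustrative only; it mirrors ops_cpu.py in spirit, not line for line):

    import operator
    import numpy as np
    from tinygrad.ops import UnaryOps, BinaryOps

    # each Op maps to a function that runs immediately on numpy arrays
    numpy_fxn_for_op = {
      UnaryOps.LOG: np.log,
      BinaryOps.ADD: operator.add,
      BinaryOps.MUL: operator.mul,
    }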

View File

@@ -218,6 +218,10 @@ class TestOps(unittest.TestCase):
def test_scalar_rsub(self):
helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
+def test_flip_eye_crash(self):
+helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
+lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
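
For intuition about what the new test computes: flip(0) reverses the rows, turning eye(10) into the anti-diagonal identity, and multiplying by eye(10) must leave it unchanged. A numpy-only check of the expected value:

    import numpy as np
    flipped = np.eye(10)[::-1]    # eye(10) with its rows reversed
    assert np.array_equal(np.eye(10) @ flipped, flipped)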
def test_broadcast_full(self):
for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
(torch.div, Tensor.div), (torch.pow, Tensor.pow)]:

View File

@@ -1,6 +1,8 @@
import unittest
import numpy as np
from tinygrad.lazy import Device
from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
def multidevice_test(fxn):
def ret(self):

View File

@@ -23,7 +23,7 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Variable("a", 3, 8)<4, 0, 1, "(a<4)")
self.helper_test_variable(Variable("a", 3, 8)<3, 0, 0, "0")
self.helper_test_variable(Variable("a", 3, 8)<2, 0, 0, "0")
def test_div_becomes_num(self):
assert isinstance(Variable("a", 2, 3)//2, NumNode)
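
This collapses to a constant because the division has the same value everywhere in the variable's range:

    # a is in [2, 3] and 2//2 == 3//2 == 1, so a//2 is the NumNode 1
    assert all(a // 2 == 1 for a in (2, 3))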
@@ -101,7 +101,7 @@ class TestSymbolic(unittest.TestCase):
def test_sum_div_no_factor(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*5, Variable("b", 0, 3)*5]) // 2, 0, 25, "(((a*5)+(b*5))//2)")
def test_mod_factor(self):
# NOTE: even though the mod max is 50, it can't know this without knowing about the mul
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
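
A plain-Python check of the rewrite (independent of tinygrad):

    # a*100 is a multiple of 100, so it drops out of the mod entirely;
    # the reported max stays 99 because %100 alone can't prove a tighter bound
    for a in range(8):
      for b in range(4):
        assert (a*100 + b*50) % 100 == (b*50) % 100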
@@ -126,7 +126,7 @@ class TestSymbolic(unittest.TestCase):
def test_mod_mul_sum(self):
self.helper_test_variable(Variable.sum([Variable("b", 0, 2), Variable("a", 0, 5)*10])%9, 0, 7, "(a+b)")
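
This simplification works because 10 ≡ 1 (mod 9) and the sum never wraps:

    # a*10 + b ≡ a + b (mod 9), and a + b <= 5 + 2 = 7 < 9, so no wraparound
    for a in range(6):
      for b in range(3):
        assert (b + a*10) % 9 == a + b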
def test_sum_0(self):
self.helper_test_variable(Variable.sum([Variable("a", 0, 7)]), 0, 7, "a")
@@ -164,7 +164,7 @@ class TestSymbolic(unittest.TestCase):
def test_div_factor(self):
self.helper_test_variable(Variable.sum([Variable.num(-40), Variable("a", 0, 10)*2, Variable("b", 0, 10)*40]) // 40, -1, 9, "(-1+b)")
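
Brute-forcing the factored form confirms it, including the -1 contributed by the negative constant:

    # b*40 factors out of //40 exactly; the remainder (-40 + a*2)//40 is -1
    # for every a in [0, 10], since -40 <= a*2 - 40 <= -20
    for a in range(11):
      for b in range(11):
        assert (-40 + a*2 + b*40) // 40 == -1 + b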
def test_mul_div(self):
self.helper_test_variable((Variable("a", 0, 10)*4)//4, 0, 10, "a")
@@ -177,6 +177,9 @@ class TestSymbolic(unittest.TestCase):
def test_div_remove(self):
self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
+def test_div_numerator_negative(self):
+self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)

View File

@@ -268,7 +268,7 @@ class GPUCodegen(ASTKernel):
self.prekernel: Set[str] = set()
self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
-if self.lang.half_prekernel: self.prekernel.add(self.lang.half_prekernel+"\n")
+if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs): self.prekernel.add(self.lang.half_prekernel+"\n")
if len(self.lang.gid) == 0:
self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]

View File

@@ -34,14 +34,14 @@ class Node:
# *** complex ops ***
-def __floordiv__(self, b:int):
+def __floordiv__(self, b:int, factoring_allowed=True):
assert b != 0
if b < 0: return (self//-b)*-1
if b == 1: return self
if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
-if isinstance(self, SumNode):
+if isinstance(self, SumNode) and factoring_allowed:
factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
nofactor = []
# ugh, i doubt this is universally right
@@ -65,9 +65,11 @@
for m in muls:
if m > 1 and b%m == 0:
return (self//m)//(b//m)
+# the numerator of div is not allowed to be negative
if self.min < 0:
offset = self.min//b
-return (self+offset*b)//b - offset
+# factor out an "offset" to make the numerator positive. don't allow factoring again
+return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
return create_opnode(DivNode, self, b)
def __mod__(self, b:int):
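
Why the rewrite in __floordiv__ is sound: for any integers x and k, (x + k*b)//b == x//b + k, so subtracting offset*b and adding offset back is exact, and choosing offset = self.min//b makes the shifted numerator non-negative. Passing factoring_allowed=False on the retry keeps the SumNode branch from re-entering this path, matching the new comment. A plain-Python check of the identity:

    b = 11
    for x in range(-90, 1):    # e.g. the range of idx*-10 in the test above
      offset = -90 // b        # offset from the minimum value, -9 here
      assert x - offset*b >= 0                    # numerator is now non-negative
      assert (x - offset*b)//b + offset == x//b   # exact floor-division identity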