mirror of https://github.com/commaai/tinygrad.git

commit c594a0a835
parent a4abcf0969

    fix flip bug, add new unit tests
@@ -2,7 +2,7 @@
 Welcome to the tinygrad documentation
 =================
 
-this file will take you on a whirlwind journey from a Tensor to a Byte
+this file will take you on a whirlwind journey from a Tensor all the way down
 tinygrad has been aggressively refactored in the 2.5 years it's been worked on.
 what you see here is a refined library (with more refining to go still!)
 
@@ -11,23 +11,21 @@ this documentation will help with entry points and understanding the abstraction
 """
 
 # %%
-# == Boilerplate imports (typing mostly) ==
+# == Boilerplate imports for typing ==
 from __future__ import annotations
 from typing import Optional, Tuple, Union, Any, Dict, Callable, Type, List
 from enum import Enum, auto
 from abc import ABC
-import numpy as np
-import torch
 
 # %%
 # == Example: Tensor 2+3 ==
-# Let's trace an addition down through the layers of abstraction.
-# We will be using the clang backend
+# let's trace an addition down through the layers of abstraction.
 
+# we will be using the clang backend
 from tinygrad.lazy import Device
 Device.DEFAULT = "CLANG"
 
-# first, 2+3 as a Tensor
+# first, 2+3 as a Tensor, the highest level
 from tinygrad.tensor import Tensor
 a = Tensor([2])
 b = Tensor([3])
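For context, the lines elided between this hunk and the next perform the traced addition (the next hunk's header shows the assert). A minimal standalone version of the example, using only what the doc itself sets up:

    from tinygrad.lazy import Device
    from tinygrad.tensor import Tensor

    Device.DEFAULT = "CLANG"        # run on the clang (C code) backend, as in the doc
    a, b = Tensor([2]), Tensor([3])
    result = a + b                  # nothing is computed yet; this only builds the lazy graph
    assert result.numpy()[0] == 5.  # .numpy() forces realization and copies the data out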
@@ -37,7 +35,7 @@ assert result.numpy()[0] == 5.
 
 # %%
 # == Tensor (in tinygrad/tensor.py, code 8/10) ==
-# it's worth just reading tinygrad/tensor.py. it's pretty beautiful
+# it's worth reading tinygrad/tensor.py. it's pretty beautiful
 import tinygrad.mlops as mlops
 
 # this is the good old familiar Tensor class
@@ -52,7 +50,7 @@ class Tensor:
   # this is where the data (and other tensor properties) actually live
   lazydata: LazyBuffer
 
-  # high level ops (hlops) are defined here. example: relu
+  # high level ops (hlops) are defined on this class. example: relu
   def relu(self): return self.maximum(0)
 
   # log is an mlop, this is the wrapper function in Tensor
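A quick illustration of the hlop comment above: since relu is defined as self.maximum(0), the two calls agree on any input (a small sketch, not part of the diff):

    from tinygrad.tensor import Tensor

    x = Tensor([-2.0, 0.0, 3.0])
    # relu is just maximum(0), so the results match elementwise
    assert (x.relu().numpy() == x.maximum(0).numpy()).all()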
@@ -88,8 +86,7 @@ class LazyBuffer:
   realized: Optional[DeviceBuffer]
 
 # LazyOp (in tinygrad/ops.py, code 4/10)
-# they form an Abstract Syntax Tree for a single GPU kernel
-# LazyOp is an AST node that defines:
+# in a tree they form an Abstract Syntax Tree for a single GPU kernel
 class LazyOp:
   op: Op # the type of the compute
   src: Tuple[Union[LazyOp, LazyBuffer], ...] # the sources
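To visualize the "tree" the reworded comment refers to: a LazyOp holds an op and its sources, and a source can itself be a LazyOp, so a fused kernel is a small AST. A toy stand-in (not tinygrad's real classes or op names) makes the nesting concrete:

    from dataclasses import dataclass
    from typing import Any, Tuple

    @dataclass(frozen=True)
    class ToyLazyOp:        # same two fields as the LazyOp in the hunk: op and src
      op: str
      src: Tuple[Any, ...]

    # something like (x*y).sum() fused into one kernel: a reduce whose source is a binary op
    ast = ToyLazyOp("SUM", (ToyLazyOp("MUL", ("buf_x", "buf_y")),))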
@@ -164,6 +161,8 @@ class DeviceBuffer(ABC):
 # InterpretedBuffers are a lot simpler than CompiledBuffers
 # they are used to implement the CPU(numpy) and TORCH(torch) backends
 # it's worth reading CPUBuffer (in tinygrad/runtime/ops_cpu.py, code 8/10)
+import numpy as np
+import torch
 class InterpretedBuffer(DeviceBuffer):
   # this is where the data actually lives
   # finally some classes you recognize!
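The interpreted backends dispatch each op straight to a numpy (or torch) call rather than generating code. A rough sketch of that idea, with an illustrative op table rather than the real one from tinygrad/runtime/ops_cpu.py:

    import numpy as np

    # illustrative mapping only; the real table lives in tinygrad/runtime/ops_cpu.py
    numpy_fxn_for_op = {"ADD": np.add, "MUL": np.multiply, "RELU": lambda x: np.maximum(x, 0)}

    out = numpy_fxn_for_op["ADD"](np.array([2.0]), np.array([3.0]))
    assert out[0] == 5.0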
@@ -218,6 +218,10 @@ class TestOps(unittest.TestCase):
   def test_scalar_rsub(self):
     helper_test_op([(45,65)], lambda x: 2-x, lambda x: 2-x)
 
+  def test_flip_eye_crash(self):
+    helper_test_op([], lambda: (torch.eye(10)@torch.eye(10).flip(0)),
+                       lambda: (Tensor.eye(10)@Tensor.eye(10).flip(0)), forward_only=True)
+
   def test_broadcast_full(self):
     for torch_op, tinygrad_op in [(torch.add, Tensor.add), (torch.sub, Tensor.sub), (torch.mul, Tensor.mul),
                                   (torch.div, Tensor.div), (torch.pow, Tensor.pow)]:
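The new test exercises the flip bug named in the commit message: flipping an identity matrix along axis 0 and multiplying used to crash. The same computation in plain numpy, for reference on what the forward pass should produce:

    import numpy as np

    eye = np.eye(10, dtype=np.float32)
    flipped = np.flip(eye, axis=0)              # the "exchange" matrix, ones on the anti-diagonal
    assert np.allclose(eye @ flipped, flipped)  # multiplying by the identity leaves it unchanged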
@@ -1,6 +1,8 @@
 import unittest
+import numpy as np
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
+from tinygrad.helpers import dtypes
 
 def multidevice_test(fxn):
   def ret(self):
@@ -177,6 +177,9 @@ class TestSymbolic(unittest.TestCase):
   def test_div_remove(self):
     self.helper_test_variable(Variable.sum([Variable("idx0", 0, 127)*4, Variable("idx2", 0, 3)])//4, 0, 127, "idx0")
 
+  def test_div_numerator_negative(self):
+    self.helper_test_variable((Variable("idx", 0, 9)*-10)//11, -9, 0, "((((idx*-10)+99)//11)+-9)")
+
 class TestSymbolicNumeric(unittest.TestCase):
   def helper_test_numeric(self, f):
     # TODO: why are the negative tests broken? (even if we did support negative variables)
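A quick arithmetic check that the expected string in the new test is consistent with floor division over the variable's whole range (idx in [0, 9], so idx*-10 spans [-90, 0], giving the stated bounds -9 and 0):

    # sweep idx over its declared range and compare against the rewritten expression
    for idx in range(10):
        assert (idx*-10)//11 == ((((idx*-10)+99)//11)+-9)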
@@ -268,7 +268,7 @@ class GPUCodegen(ASTKernel):
     self.prekernel: Set[str] = set()
     self.kernel: List[str] = ["const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"] if any(hasattr(buf._buf, "IMAGE") for buf in self.bufs if buf is not None) else []
 
-    if self.lang.half_prekernel: self.prekernel.add(self.lang.half_prekernel+"\n")
+    if self.lang.half_prekernel and any(x.dtype == dtypes.float16 for x in self.bufs): self.prekernel.add(self.lang.half_prekernel+"\n")
 
     if len(self.lang.gid) == 0:
       self.kernel += [f"for (int idx{i} = 0; idx{i} < {self.output_shape[i]}; idx{i}++) {{\n" for i in range(0, len(self.output_shape))]
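The change guards the half-precision prelude so it is only emitted when some buffer actually uses float16. A minimal sketch of that guard with made-up stand-in types (the real code checks self.bufs and tinygrad's dtypes):

    from dataclasses import dataclass

    @dataclass
    class FakeBuf:              # stand-in for a real buffer; only the dtype matters here
      dtype: str

    def prekernel_for(half_prekernel: str, bufs: list) -> set:
      # only add the fp16 prelude when a float16 buffer is actually present
      return {half_prekernel + "\n"} if half_prekernel and any(b.dtype == "float16" for b in bufs) else set()

    assert prekernel_for("#pragma OPENCL EXTENSION cl_khr_fp16 : enable", [FakeBuf("float32")]) == set()
    assert len(prekernel_for("#pragma OPENCL EXTENSION cl_khr_fp16 : enable", [FakeBuf("float16")])) == 1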
@@ -34,14 +34,14 @@ class Node:
 
   # *** complex ops ***
 
-  def __floordiv__(self, b:int):
+  def __floordiv__(self, b:int, factoring_allowed=True):
     assert b != 0
     if b < 0: return (self//-b)*-1
     if b == 1: return self
     if isinstance(self, DivNode): return self.a//(self.b*b) # two divs is one div
     if isinstance(self, MulNode) and self.b % b == 0: return self.a*(self.b//b)
     if isinstance(self, MulNode) and b % self.b == 0: return self.a//(b//self.b)
-    if isinstance(self, SumNode):
+    if isinstance(self, SumNode) and factoring_allowed:
       factors, tmp_nofactor = partition(self.nodes, lambda x: (isinstance(x, (MulNode, NumNode))) and x.b%b == 0)
       nofactor = []
       # ugh, i doubt this is universally right
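Plain-integer spot checks of the identities these rewrite rules rely on (for non-negative numerators), separate from the Node classes themselves:

    for x in range(100):
        assert (x // 3) // 5 == x // 15   # DivNode: two divs collapse into one
        assert (x * 6) // 3 == x * 2      # MulNode where self.b % b == 0
        assert (x * 3) // 6 == x // 2     # MulNode where b % self.b == 0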
@@ -65,9 +65,11 @@ class Node:
       for m in muls:
         if m > 1 and b%m == 0:
           return (self//m)//(b//m)
+    # the numerator of div is not allowed to be negative
     if self.min < 0:
       offset = self.min//b
-      return (self+offset*b)//b - offset
+      # factor out an "offset" to make the numerator positive. don't allow factoring again
+      return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
     return create_opnode(DivNode, self, b)
 
   def __mod__(self, b:int):
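The fixed branch leans on the floor-division identity (x + k*b)//b == x//b + k for any integer k: shift the numerator up by a multiple of b until it is non-negative, divide there (with factoring disabled, presumably to avoid re-entering the SumNode factoring path on the shifted sum), then shift the result back down. A standalone sketch with plain ints, mirroring the new test's range (idx*-10 over [-90, 0], b = 11):

    def floordiv_with_offset(x: int, x_min: int, b: int) -> int:
        assert b > 0 and x >= x_min
        offset = x_min // b                # most negative multiple of b reachable in range
        shifted = x + -offset * b          # shifted >= 0, so the "positive numerator" path applies
        return shifted // b + offset       # undo the shift

    for x in range(-90, 1):
        assert floordiv_with_offset(x, -90, 11) == x // 11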