minor dtype cleanup [pr] (#7124)

* minor dtype cleanup [pr]

* use ptr() function
This commit is contained in:
George Hotz 2024-10-17 17:41:23 +08:00 committed by GitHub
parent 0b2621f63f
commit ded1b38b84
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 407 additions and 415 deletions

View File

@ -37,7 +37,7 @@ print("******** second, the Device ***********")
DEVICE = "CLANG" # NOTE: you can change this! DEVICE = "CLANG" # NOTE: you can change this!
import struct import struct
from tinygrad.dtype import PtrDType, dtypes from tinygrad.dtype import dtypes
from tinygrad.device import Buffer, Device from tinygrad.device import Buffer, Device
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
@ -49,12 +49,12 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# NOTE: a._buf is the same as the return from MallocAllocator.alloc # NOTE: a._buf is the same as the return from MallocAllocator.alloc
# describe the computation # describe the computation
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1) buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2) buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop())) ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop())) ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
alu = ld_1 + ld_2 alu = ld_1 + ld_2
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu)) st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
s = UOp(UOps.SINK, dtypes.void, (st_0,)) s = UOp(UOps.SINK, dtypes.void, (st_0,))

View File

@ -4,7 +4,7 @@ import unittest
from tinygrad import Device from tinygrad import Device
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
from tinygrad.engine.search import Opt, OptOps from tinygrad.engine.search import Opt, OptOps
from tinygrad.dtype import dtypes, PtrDType from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View from tinygrad.shape.view import View
from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.kernel import Kernel
@ -31,7 +31,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=( x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
@ -81,7 +81,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
x18:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
x18, x18,

View File

@ -312,15 +312,15 @@ class TestEqStrDType(unittest.TestCase):
def test_ptr_ne(self): def test_ptr_ne(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support") if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
# TODO: is this the wrong behavior? # TODO: is this the wrong behavior?
assert PtrDType(dtypes.float32) == dtypes.float32 assert dtypes.float32.ptr() == dtypes.float32
assert not (PtrDType(dtypes.float32) != dtypes.float32) assert not (dtypes.float32.ptr() != dtypes.float32)
assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) assert dtypes.float32.ptr() == dtypes.float32.ptr()
assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32)) assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
#assert PtrDType(dtypes.float32) != dtypes.float32 #assert dtypes.float32.ptr() != dtypes.float32
def test_strs(self): def test_strs(self):
if PtrDType is None: raise unittest.SkipTest("no PtrDType support") if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
self.assertEqual(str(PtrDType(dtypes.float32)), "PtrDType(dtypes.float)") self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
class TestHelpers(unittest.TestCase): class TestHelpers(unittest.TestCase):
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)

View File

@ -15,7 +15,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
from tinygrad.dtype import DType, PtrDType, dtypes from tinygrad.dtype import DType, dtypes
def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]: def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
if isinstance(r, Tensor): r = [r] if isinstance(r, Tensor): r = [r]
@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase):
def test_multioutput(self): def test_multioutput(self):
dtype, st = dtypes.int, ShapeTracker.from_shape((8,)) dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), arg=i) for i in range(4)] g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
a = UOp(UOps.LOAD, dtype, (g2, st.to_uop())) a = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
b = UOp(UOps.LOAD, dtype, (g3, st.to_uop())) b = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b)) out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b))
@ -107,7 +107,7 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(32, dtype=dtypes.float).realize() x = Tensor.randn(32, dtype=dtypes.float).realize()
st_x = x.lazydata.st st_x = x.lazydata.st
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
@ -143,7 +143,7 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
st_x = x.lazydata.st st_x = x.lazydata.st
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(4)] g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
@ -232,7 +232,7 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize() x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
st = x.lazydata.st st = x.lazydata.st
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
@ -283,7 +283,7 @@ class TestLinearizer(unittest.TestCase):
# check how it works with one reduce optimized and one unoptimized # check how it works with one reduce optimized and one unoptimized
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
@ -314,7 +314,7 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(4, 32, dtype=dtypes.float).realize() x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize() x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
@ -350,7 +350,7 @@ class TestLinearizer(unittest.TestCase):
# check how multireduce works with multioutput # check how multireduce works with multioutput
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
@ -373,7 +373,7 @@ class TestLinearizer(unittest.TestCase):
# if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!) # if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
@ -397,7 +397,7 @@ class TestLinearizer(unittest.TestCase):
def test_complete_unroll_multireduce(self): def test_complete_unroll_multireduce(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
@ -413,7 +413,7 @@ class TestLinearizer(unittest.TestCase):
def test_upcast_multireduce(self): def test_upcast_multireduce(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
@ -432,7 +432,7 @@ class TestLinearizer(unittest.TestCase):
# make sure the if block of a grouped reduce can be closed early and the result loaded back in # make sure the if block of a grouped reduce can be closed early and the result loaded back in
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
@ -448,7 +448,7 @@ class TestLinearizer(unittest.TestCase):
def test_mean_std_multireduce(self): def test_mean_std_multireduce(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
@ -466,7 +466,7 @@ class TestLinearizer(unittest.TestCase):
def test_mean_std_multireduce_mid_dim(self): def test_mean_std_multireduce_mid_dim(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35)) neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
@ -486,7 +486,7 @@ class TestLinearizer(unittest.TestCase):
# TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch) # TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch)
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
@ -508,7 +508,7 @@ class TestLinearizer(unittest.TestCase):
def test_var_multireduce(self): def test_var_multireduce(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize() x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
def test_softmax_multireduce(self): def test_softmax_multireduce(self):
x = Tensor.rand(4, 32).realize() x = Tensor.rand(4, 32).realize()
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop())) first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,))) max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop())) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
@ -559,15 +559,15 @@ class TestLinearizer(unittest.TestCase):
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
out = arange+ast_const(dtypes.int, -1, output_shape) out = arange+ast_const(dtypes.int, -1, output_shape)
store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out)) store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
sink = UOp(UOps.SINK, src=(store,)) sink = UOp(UOps.SINK, src=(store,))
helper_linearizer_ast(sink, [], wanna_output=[real_arange]) helper_linearizer_ast(sink, [], wanna_output=[real_arange])
with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange) with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow") @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
def test_indexing_multireduce(self): def test_indexing_multireduce(self):
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
g2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2) g2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2)
arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \ arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
# TODO: do this arange broadcast in the scheduler # TODO: do this arange broadcast in the scheduler
@ -598,7 +598,7 @@ class TestLinearizer(unittest.TestCase):
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1) real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
@ -611,10 +611,10 @@ class TestLinearizer(unittest.TestCase):
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
@ -630,7 +630,7 @@ class TestLinearizer(unittest.TestCase):
real_argmax = np.argmax(t.numpy()) real_argmax = np.argmax(t.numpy())
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
@ -643,10 +643,10 @@ class TestLinearizer(unittest.TestCase):
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
ast_const(dtypes.bool, True, (200, 1)),)),)), ast_const(dtypes.bool, True, (200, 1)),)),)),
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
@ -669,7 +669,7 @@ class TestLinearizer(unittest.TestCase):
# [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)] # [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
] ]
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,))) r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
@ -696,7 +696,7 @@ class TestLinearizer(unittest.TestCase):
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),] [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
] ]
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,))) r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
@ -730,7 +730,7 @@ class TestLinearizer(unittest.TestCase):
ld1 = x.lazydata.st.reshape((N, N, 1)) ld1 = x.lazydata.st.reshape((N, N, 1))
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
@ -738,20 +738,20 @@ class TestLinearizer(unittest.TestCase):
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
ld1.to_uop(),)), ld1.to_uop(),)),
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1)), ast_const(dtypes.float, 0.75*N, (N, N, 1)),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
ld0.to_uop(),)),)),)), ld0.to_uop(),)),)),)),
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
ast_const(dtypes.float, 0.0, (N, 1, 1)), ast_const(dtypes.float, 0.0, (N, 1, 1)),
@ -763,7 +763,7 @@ class TestLinearizer(unittest.TestCase):
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501 wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
@ -771,20 +771,20 @@ class TestLinearizer(unittest.TestCase):
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
ld1.to_uop(),)), ld1.to_uop(),)),
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
ast_const(dtypes.float, 0.75*N, (N, 1, N)), ast_const(dtypes.float, 0.75*N, (N, 1, N)),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
ld0.to_uop(),)),)),)), ld0.to_uop(),)),)),)),
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
ast_const(dtypes.float, 0.0, (1, 1, N)), ast_const(dtypes.float, 0.0, (1, 1, N)),
@ -799,7 +799,7 @@ class TestLinearizer(unittest.TestCase):
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501 wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
@ -807,20 +807,20 @@ class TestLinearizer(unittest.TestCase):
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)), ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
ast_const(dtypes.float, 0.0, (1, 1, 1, 1)), ast_const(dtypes.float, 0.0, (1, 1, 1, 1)),
ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),)) ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),))
@ -829,7 +829,7 @@ class TestLinearizer(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_end_local(self): def test_end_local(self):
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
@ -931,7 +931,7 @@ class TestLinearizer(unittest.TestCase):
# make sure const buffers are differentiated from local and mem buffers # make sure const buffers are differentiated from local and mem buffers
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
VAL = ast_const(DT, 2, ST.arg.shape) VAL = ast_const(DT, 2, ST.arg.shape)
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(DT), arg=i) for i in range(2)] g0, g1 = [UOp(UOps.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)]
# data1[0] + VAL # data1[0] + VAL
a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
@ -1368,10 +1368,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
opt = [ opt = [
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
@ -1387,10 +1387,10 @@ class TestLinearizer(unittest.TestCase):
Tensor.manual_seed(0) Tensor.manual_seed(0)
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
@ -1603,16 +1603,16 @@ class TestFloat4(unittest.TestCase):
# from llama 7B shard 4 gpus # from llama 7B shard 4 gpus
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.CAST, dtypes.float, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, src=( UOp(UOps.LOAD, dtypes.half, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
UOp(UOps.LOAD, dtypes.half, src=( UOp(UOps.LOAD, dtypes.half, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
# TODO: fix this, expected might change but should be positive # TODO: fix this, expected might change but should be positive
@ -1632,19 +1632,19 @@ class TestFloat4(unittest.TestCase):
# from float32 stable diffusion red tinybox # from float32 stable diffusion red tinybox
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
for expected, opts in [ for expected, opts in [
@ -1662,13 +1662,13 @@ class TestFloat4(unittest.TestCase):
# from resnet # from resnet
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
UOp(UOps.CAST, dtypes.half, src=( UOp(UOps.CAST, dtypes.half, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.CAST, dtypes.float, src=(
UOp(UOps.LOAD, dtypes.half, src=( UOp(UOps.LOAD, dtypes.half, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
for expected, opts in [ for expected, opts in [
(16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501 (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
@ -1950,20 +1950,20 @@ class TestKernelOpts(unittest.TestCase):
def test_buf_index_not_found_tensor_core(self): def test_buf_index_not_found_tensor_core(self):
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.CAST, dtypes.float, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.int, src=( UOp(UOps.LOAD, dtypes.int, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
UOp(UOps.LOAD, dtypes.int, src=( UOp(UOps.LOAD, dtypes.int, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, src=( UOp(UOps.LOAD, dtypes.float, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501 UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError): with self.assertRaises(KernelOptError):
@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
def test_padto_group(self): def test_padto_group(self):
Tensor.manual_seed(0) Tensor.manual_seed(0)
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501 store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501

View File

@ -5,7 +5,6 @@
import unittest import unittest
from test.helpers import ast_const from test.helpers import ast_const
from tinygrad import Device, dtypes from tinygrad import Device, dtypes
from tinygrad.dtype import PtrDType
from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
from tinygrad.helpers import getenv from tinygrad.helpers import getenv
from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.shape.shapetracker import ShapeTracker, View
@ -17,7 +16,7 @@ class TestLinearizerDumb(unittest.TestCase):
def test_unmerged_ifs(self): def test_unmerged_ifs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
@ -26,10 +25,10 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
ast_const(dtypes.half, 0.9999950000374996, st_src=( ast_const(dtypes.half, 0.9999950000374996, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
@ -50,17 +49,17 @@ class TestLinearizerDumb(unittest.TestCase):
def test_max_simplify_and_cancel(self): def test_max_simplify_and_cancel(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
UOp(UOps.CAST, dtypes.int, arg=None, src=( UOp(UOps.CAST, dtypes.int, arg=None, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=( ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
@ -81,11 +80,11 @@ class TestLinearizerDumb(unittest.TestCase):
def test_expander_new_srcs(self): def test_expander_new_srcs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
@ -102,7 +101,7 @@ class TestLinearizerDumb(unittest.TestCase):
def test_llama_embedding(self): def test_llama_embedding(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.CAST, dtypes.half, arg=None, src=(
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
@ -118,12 +117,12 @@ class TestLinearizerDumb(unittest.TestCase):
ast_const(dtypes.int, -1, st_src=( ast_const(dtypes.int, -1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.int, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
ast_const(dtypes.bool, True, st_src=( ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
prg = k.to_program() prg = k.to_program()
@ -134,7 +133,7 @@ class TestLinearizerDumb(unittest.TestCase):
def test_unaligns_idxs(self): def test_unaligns_idxs(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
@ -142,16 +141,16 @@ class TestLinearizerDumb(unittest.TestCase):
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.long, arg=None, src=( UOp(UOps.LOAD, dtypes.long, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.CAST, dtypes.long, arg=None, src=( UOp(UOps.CAST, dtypes.long, arg=None, src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
ast_const(dtypes.bool, True, st_src=( ast_const(dtypes.bool, True, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)] opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
@ -166,14 +165,14 @@ class TestLinearizerDumb(unittest.TestCase):
def test_unrolled_float4_align(self): def test_unrolled_float4_align(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
UOp(UOps.LOAD, dtypes.long, arg=None, src=( UOp(UOps.LOAD, dtypes.long, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
ast_const(dtypes.long, -1, st_src=( ast_const(dtypes.long, -1, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
@ -182,7 +181,7 @@ class TestLinearizerDumb(unittest.TestCase):
ast_const(dtypes.float, 0.0, st_src=( ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)] opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
@ -198,15 +197,15 @@ class TestLinearizerDumb(unittest.TestCase):
def test_upcasted_stores_out_of_order(self): def test_upcasted_stores_out_of_order(self):
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)] opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)

File diff suppressed because it is too large Load Diff

View File

@ -9,7 +9,6 @@ from tinygrad.engine.search import time_linearizer, bufs_from_lin
# stuff needed to unpack a kernel # stuff needed to unpack a kernel
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
from tinygrad.dtype import PtrDType
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View from tinygrad.shape.view import View
@ -27,7 +26,7 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_1(self): def test_overflow_1(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
@ -37,15 +36,15 @@ class TestLinearizerOverflow(unittest.TestCase):
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
x16:=ast_const(dtypes.float, 0.0, st_src=( x16:=ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
@ -57,7 +56,7 @@ class TestLinearizerOverflow(unittest.TestCase):
ast_const(dtypes.float, 1e-05, st_src=( ast_const(dtypes.float, 1e-05, st_src=(
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
x16,)),)),)) x16,)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)] opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
@ -67,15 +66,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_2(self): def test_overflow_2(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -84,15 +83,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_3(self): def test_overflow_3(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -101,15 +100,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_4(self): def test_overflow_4(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -118,15 +117,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_5(self): def test_overflow_5(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -135,15 +134,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_6(self): def test_overflow_6(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -152,15 +151,15 @@ class TestLinearizerOverflow(unittest.TestCase):
def test_overflow_7(self): def test_overflow_7(self):
ast = UOp(UOps.SINK, None, arg=None, src=( ast = UOp(UOps.SINK, None, arg=None, src=(
UOp(UOps.STORE, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
_test_overflow(ast, opts) _test_overflow(ast, opts)
@ -170,7 +169,7 @@ class TestLinearizerOverflow(unittest.TestCase):
class TestLinearizerOverflowAlt(unittest.TestCase): class TestLinearizerOverflowAlt(unittest.TestCase):
def test_overflow_1(self): def test_overflow_1(self):
BS = 2 BS = 2
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
@ -182,7 +181,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
_test_overflow(ast, opts) _test_overflow(ast, opts)
def test_overflow_2(self): def test_overflow_2(self):
BS = 2 BS = 2
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()

View File

@ -4,7 +4,7 @@ import numpy as np
from tinygrad.codegen.uopgraph import full_graph_rewrite from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.linearize import linearize_uop
from tinygrad.device import Buffer, Device from tinygrad.device import Buffer, Device
from tinygrad.dtype import PtrDType, dtypes from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CompiledRunner from tinygrad.engine.realize import CompiledRunner
from tinygrad.helpers import dedup, flatten, getenv, prod from tinygrad.helpers import dedup, flatten, getenv, prod
from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.renderer.cstyle import CStyleLanguage
@ -30,8 +30,8 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle") @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
class TestCStyleFailures(unittest.TestCase): class TestCStyleFailures(unittest.TestCase):
def test_inline_const_alu(self): def test_inline_const_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
b = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1) b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (b, idx)) ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
@ -45,7 +45,7 @@ class TestCStyleFailures(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
class TestPTXFailures(unittest.TestCase): class TestPTXFailures(unittest.TestCase):
def test_gated_store_with_alu(self): def test_gated_store_with_alu(self):
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu)) gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,)) sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
@ -55,7 +55,7 @@ class TestPTXFailures(unittest.TestCase):
@unittest.skip("not still valid?") @unittest.skip("not still valid?")
def test_gated_store_with_if(self): def test_gated_store_with_if(self):
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
val = UOp.const(dtypes.int, 1) val = UOp.const(dtypes.int, 1)
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu, val)) if_uop = UOp(UOps.IF, dtypes.void, (gate_alu, val))

View File

@ -9,7 +9,7 @@ import functools
from typing import List, Optional, Union, cast from typing import List, Optional, Union, cast
from tinygrad import nn, dtypes, Device, Tensor from tinygrad import nn, dtypes, Device, Tensor
from tinygrad.dtype import DType, PtrDType from tinygrad.dtype import DType
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View from tinygrad.shape.view import View
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
@ -1611,7 +1611,7 @@ class TestIndexing(unittest.TestCase):
self.assertLess(et, 1e3) self.assertLess(et, 1e3)
def test_no_rewrite_elementwise(self): def test_no_rewrite_elementwise(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop())) ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),)) sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
@ -1619,7 +1619,7 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(rsink.key, sink.key) self.assertEqual(rsink.key, sink.key)
def test_simple_store_reshape(self): def test_simple_store_reshape(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
@ -1632,7 +1632,7 @@ class TestIndexing(unittest.TestCase):
verify_ast(rsink) verify_ast(rsink)
def test_no_reshape_reduceop(self): def test_no_reshape_reduceop(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
@ -1641,7 +1641,7 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(sink.key, rsink.key) self.assertEqual(sink.key, rsink.key)
def test_reshape_many(self): def test_reshape_many(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
@ -1660,7 +1660,7 @@ class TestIndexing(unittest.TestCase):
sizes = [10*(i+1) for i in range(SZ)] sizes = [10*(i+1) for i in range(SZ)]
tms: List[float] = [] tms: List[float] = []
for sz in sizes: for sz in sizes:
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ()) for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
@ -1682,14 +1682,14 @@ class TestIndexing(unittest.TestCase):
# graph rewrite # graph rewrite
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=( sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501 UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
UOp(UOps.LOAD, dtypes.int, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
UOp(UOps.LOAD, dtypes.int, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=(
x8, x8,
@ -1709,7 +1709,7 @@ class TestIndexing(unittest.TestCase):
a = Tensor.randint(4,).realize() a = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+1 expected_out = a.numpy().sum(0)+1
# LazyBuffer to pre-rewrite AST # LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,))) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(())) swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
@ -1732,7 +1732,7 @@ class TestIndexing(unittest.TestCase):
b = Tensor.randint(4,).realize() b = Tensor.randint(4,).realize()
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2 expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
# LazyBuffer to pre-rewrite AST # LazyBuffer to pre-rewrite AST
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,))) r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop())) ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
@ -1752,7 +1752,7 @@ class TestIndexing(unittest.TestCase):
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
# there's an EXPAND pushing through the REDUCE_AXIS # there's an EXPAND pushing through the REDUCE_AXIS
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape)) self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
@ -1766,7 +1766,7 @@ class TestIndexing(unittest.TestCase):
def test_permute_rewrite(self): def test_permute_rewrite(self):
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=( sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()), x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()), x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=( UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
x1, x1,
@ -1778,15 +1778,15 @@ class TestIndexing(unittest.TestCase):
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=( x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()), UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
x2,)),)), x2,)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()), UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()), UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
x11,)),)),)),)),)) x11,)),)),)),)),))

View File

@ -8,7 +8,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
from tinygrad.device import Device, Buffer from tinygrad.device import Device, Buffer
from tinygrad.tensor import Tensor from tinygrad.tensor import Tensor
from tinygrad.dtype import dtypes, PtrDType from tinygrad.dtype import dtypes
from tinygrad.helpers import Context, GlobalCounters from tinygrad.helpers import Context, GlobalCounters
from tinygrad.engine.realize import capturing from tinygrad.engine.realize import capturing
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
@ -48,7 +48,7 @@ class TestTimeLinearizer(unittest.TestCase):
# ast of Tensor.zeros(16).contiguous().realize() # ast of Tensor.zeros(16).contiguous().realize()
ast = UOp(UOps.SINK, src=( ast = UOp(UOps.SINK, src=(
UOp(UOps.STORE, src=( UOp(UOps.STORE, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
ast_const(dtypes.float, 0.0, st_src=( ast_const(dtypes.float, 0.0, st_src=(
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),)) UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
@ -105,7 +105,7 @@ class TestBEAM(unittest.TestCase):
# taken from https://github.com/tinygrad/tinygrad/issues/4612 # taken from https://github.com/tinygrad/tinygrad/issues/4612
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
@ -115,22 +115,22 @@ class TestBEAM(unittest.TestCase):
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
UOp(UOps.LOAD, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
ast_const(dtypes.float, 1.4285714285714286, st_src=( ast_const(dtypes.float, 1.4285714285714286, st_src=(
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501

View File

@ -2,7 +2,6 @@ from typing import List
import unittest, time import unittest, time
from test.helpers import assert_equiv_uops from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Device from tinygrad import dtypes, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
from tinygrad.ops import UPat, PatternMatcher from tinygrad.ops import UPat, PatternMatcher
@ -32,7 +31,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
def test_expand_rewrite(self): def test_expand_rewrite(self):
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=( sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
UOp(UOps.STORE, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0), strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
offset=0, mask=None, contiguous=False),)), src=()), offset=0, mask=None, contiguous=False),)), src=()),
@ -40,14 +39,14 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.CAST, dtypes.float, arg=None, src=(
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16, View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False), mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0, View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
mask=None, contiguous=False))), src=()),)), mask=None, contiguous=False))), src=()),)),
UOp(UOps.LOAD, dtypes.half, arg=None, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=(
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0, View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
@ -225,7 +224,7 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("this test isn't valid uops") @unittest.skip("this test isn't valid uops")
def test_noop_vectorize_fold(self): def test_noop_vectorize_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0) d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx)) ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,)) vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
@ -236,9 +235,9 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0) self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
def test_gep_vec_fold(self): def test_gep_vec_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 2) d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
def _test_vec(geps, count=4): def _test_vec(geps, count=4):
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps) vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
@ -342,8 +341,8 @@ class TestUOpGraph(unittest.TestCase):
assert_equiv_uops(uops[-1], wmma) assert_equiv_uops(uops[-1], wmma)
def test_cast_alu_fold(self): def test_cast_alu_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0) d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1) d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.lt(1).cast(dtypes.bool) alu = ld.lt(1).cast(dtypes.bool)
@ -352,8 +351,8 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
def test_double_cast_fold(self): def test_double_cast_fold(self):
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0) d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1) d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
alu = ld.cast(dtypes.float).cast(dtypes.float) alu = ld.cast(dtypes.float).cast(dtypes.float)
@ -376,9 +375,9 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(out.src[1].arg, 6) self.assertEqual(out.src[1].arg, 6)
def test_fold_gated_load(self): def test_fold_gated_load(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1) glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 2) glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False))) ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True))) ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
@ -390,8 +389,8 @@ class TestUOpGraph(unittest.TestCase):
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
def test_fold_gated_load_local(self): def test_fold_gated_load_local(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1)) smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
barrier = UOp(UOps.BARRIER, dtypes.void, (st, )) barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
@ -405,7 +404,7 @@ class TestUOpGraph(unittest.TestCase):
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
def test_fold_gated_store(self): def test_fold_gated_store(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx0 = UOp.const(dtypes.int, 0) idx0 = UOp.const(dtypes.int, 0)
idx1 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0)
val = UOp.const(dtypes.int, 42) val = UOp.const(dtypes.int, 42)
@ -418,13 +417,13 @@ class TestUOpGraph(unittest.TestCase):
@unittest.skip("this is a uop type error") @unittest.skip("this is a uop type error")
def test_asserts_bad_gate(self): def test_asserts_bad_gate(self):
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
idx = UOp.const(dtypes.int, 0) idx = UOp.const(dtypes.int, 0)
bad_gate = UOp.const(dtypes.int, 1) bad_gate = UOp.const(dtypes.int, 1)
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
def test_switched_range_order(self): def test_switched_range_order(self):
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
c0 = UOp.const(dtypes.int, 0) c0 = UOp.const(dtypes.int, 0)
c2 = UOp.const(dtypes.int, 2) c2 = UOp.const(dtypes.int, 2)
cf = UOp.const(dtypes.float, 0.0) cf = UOp.const(dtypes.float, 0.0)
@ -591,21 +590,21 @@ class TestExpander(unittest.TestCase):
class TestLoadStoreFolder(unittest.TestCase): class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_fold(self): def test_simple_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)] load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink) sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1 assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
def test_two_load_fold(self): def test_two_load_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)] load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
sink = float4_rewrite(sink) sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2 assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
def test_simple_load_fold_gated(self): def test_simple_load_fold_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp(UOps.DEFINE_VAR, dtypes.bool) gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
@ -615,7 +614,7 @@ class TestLoadStoreFolder(unittest.TestCase):
self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0]) self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0])
def test_simple_load_dont_fold_different_gated(self): def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool) gate = UOp.variable("g1", False, True, dtypes.bool)
gate2 = UOp.variable("g2", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
@ -624,14 +623,14 @@ class TestLoadStoreFolder(unittest.TestCase):
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3 assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
def test_simple_store_fold(self): def test_simple_store_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)] load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = UOp(UOps.SINK, dtypes.void, tuple(load))
sink = float4_rewrite(sink) sink = float4_rewrite(sink)
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1 assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
def test_simple_store_fold_gate(self): def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool) gate = UOp.variable("g1", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = UOp(UOps.SINK, dtypes.void, tuple(load))
@ -642,7 +641,7 @@ class TestLoadStoreFolder(unittest.TestCase):
assert str(one_store.src[3]) == str(gate) # huh, why do i need str here? assert str(one_store.src[3]) == str(gate) # huh, why do i need str here?
def test_simple_store_dont_fold(self): def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
gate = UOp.variable("g1", False, True, dtypes.bool) gate = UOp.variable("g1", False, True, dtypes.bool)
gate2 = UOp.variable("g2", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool)
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
@ -655,8 +654,8 @@ def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer)
class TestIFUOps(unittest.TestCase): class TestIFUOps(unittest.TestCase):
def test_create_ifs(self): def test_create_ifs(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4)) sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2)) gate = valid&(lidx.ne(2))
@ -674,8 +673,8 @@ class TestIFUOps(unittest.TestCase):
self.assertEqual(len(st.src), 3) self.assertEqual(len(st.src), 3)
def test_expand_ifs_one_gate(self): def test_expand_ifs_one_gate(self):
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16)) sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
gate = valid&(lidx.ne(2)) gate = valid&(lidx.ne(2))
@ -694,7 +693,7 @@ class TestIFUOps(unittest.TestCase):
# this will be fixed with the merge gated stores bounty # this will be fixed with the merge gated stores bounty
@unittest.expectedFailure @unittest.expectedFailure
def test_expand_ifs_dumb(self): def test_expand_ifs_dumb(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
gate = valid&(lidx.ne(2)) gate = valid&(lidx.ne(2))

View File

@ -4,7 +4,7 @@ import numpy as np
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.helpers import CI, DEBUG, getenv, Context
from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.dtype import dtypes, DType
from tinygrad.device import Buffer, Device from tinygrad.device import Buffer, Device
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
from tinygrad.renderer import Program from tinygrad.renderer import Program
@ -30,8 +30,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar
def _test_single_value(vals, op, dts): def _test_single_value(vals, op, dts):
uops = [] uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)] buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op) alu = uop(uops, UOps.ALU, output_dtype, loads, op)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@ -46,7 +46,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts): def _test_single_value_const(vals, op, dts):
uops = [] uops = []
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op) alu = uop(uops, UOps.ALU, output_dtype, loads, op)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
@ -59,7 +59,7 @@ def _test_single_value_const(vals, op, dts):
def _test_uops_result(output_dtype, uops, res): def _test_uops_result(output_dtype, uops, res):
# uops = [] # uops = []
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
# res = output_fn(uops) # res = output_fn(uops)
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res)) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
@ -246,7 +246,7 @@ class TestConstantFolding(unittest.TestCase):
class TestGatedStoreRewrite(unittest.TestCase): class TestGatedStoreRewrite(unittest.TestCase):
@unittest.expectedFailure @unittest.expectedFailure
def test_tiny_gate_store(self): def test_tiny_gate_store(self):
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0 * UOp.const(dtypes.int, 2) idx = gidx0 * UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
@ -263,8 +263,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
@unittest.expectedFailure @unittest.expectedFailure
def test_gate_some_stores(self): def test_gate_some_stores(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2) idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
@ -282,8 +282,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
# scaled down version of TestLinearizerDumb.test_unmerged_ifs # scaled down version of TestLinearizerDumb.test_unmerged_ifs
@unittest.expectedFailure @unittest.expectedFailure
def test_merge_ifs_alt(self): def test_merge_ifs_alt(self):
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
idx = gidx0*UOp.const(dtypes.int, 2) idx = gidx0*UOp.const(dtypes.int, 2)
val = UOp.const(dtypes.float, 42.0) val = UOp.const(dtypes.float, 42.0)
@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_basic(self): def test_local_basic(self):
uops = [] uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16)) smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0))) st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,)) barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr)) sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
def test_local_indirect(self): def test_local_indirect(self):
uops = [] uops = []
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16)) smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2))) st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42))) st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2)) barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
@ -325,7 +325,7 @@ class TestLocalAccess(unittest.TestCase):
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends") @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
class TestAssembly(unittest.TestCase): class TestAssembly(unittest.TestCase):
def test_bitshift_left(self): def test_bitshift_left(self):
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(UOps.CONST, dtypes.int, (), 2) c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3) c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
@ -337,7 +337,7 @@ class TestAssembly(unittest.TestCase):
self.assertEqual(uops[-2].arg, BinaryOps.MUL) self.assertEqual(uops[-2].arg, BinaryOps.MUL)
def test_bitshift_right(self): def test_bitshift_right(self):
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
c1 = UOp(UOps.CONST, dtypes.int, (), 2) c1 = UOp(UOps.CONST, dtypes.int, (), 2)
c2 = UOp(UOps.CONST, dtypes.int, (), 3) c2 = UOp(UOps.CONST, dtypes.int, (), 3)
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
@ -361,7 +361,7 @@ class TestUOpMethod(unittest.TestCase):
def test_uop_variables(self): def test_uop_variables(self):
a = UOp.variable("a", 1, 10) a = UOp.variable("a", 1, 10)
uop_var = UOp.const(dtypes.int, a) uop_var = UOp.const(dtypes.int, a)
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0), st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
ShapeTracker.from_shape((2, a)).to_uop())) ShapeTracker.from_shape((2, a)).to_uop()))
ast_vars = (st_var+uop_var).variables() ast_vars = (st_var+uop_var).variables()
self.assertEqual(len(ast_vars), 1) self.assertEqual(len(ast_vars), 1)
@ -376,7 +376,7 @@ class TestUOpMethod(unittest.TestCase):
self.assertEqual((gidx0*3+1).const_factor(), 1) self.assertEqual((gidx0*3+1).const_factor(), 1)
def test_replace(self): def test_replace(self):
x = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.void), (), 0) x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
self.assertIs(x.replace(arg=None).arg, None) self.assertIs(x.replace(arg=None).arg, None)
with self.assertRaises(AssertionError): x.replace(field="a") with self.assertRaises(AssertionError): x.replace(field="a")
@ -403,7 +403,7 @@ class TestIndexingOrdering(unittest.TestCase):
# NOTE: these tests skip type_verify since they add dtype to STORE # NOTE: these tests skip type_verify since they add dtype to STORE
@unittest.expectedFailure @unittest.expectedFailure
def test_simple_order(self): def test_simple_order(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
uops = to_uops_list([st1, st0], skip_check=True) uops = to_uops_list([st1, st0], skip_check=True)
@ -412,8 +412,8 @@ class TestIndexingOrdering(unittest.TestCase):
@unittest.expectedFailure @unittest.expectedFailure
def test_ordering_multi_output(self): def test_ordering_multi_output(self):
buf0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
buf1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
@ -430,7 +430,7 @@ class TestIndexingOrdering(unittest.TestCase):
assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
def test_simple_order_with_special(self): def test_simple_order_with_special(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))

View File

@ -5,7 +5,7 @@ from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import lower_schedule_item from tinygrad.engine.realize import lower_schedule_item
from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.linearize import linearize_uop
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
from tinygrad.dtype import PtrDType, dtypes from tinygrad.dtype import dtypes
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
# **************** new FlopCounter **************** # **************** new FlopCounter ****************
@ -119,7 +119,7 @@ class TestUOpsStats(unittest.TestCase):
#MULACC should have the same stats as MUL + ADD #MULACC should have the same stats as MUL + ADD
def test_mulacc(self): def test_mulacc(self):
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple()) globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
@ -129,7 +129,7 @@ class TestUOpsStats(unittest.TestCase):
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
uops = linearize_uop(u5.sink()) uops = linearize_uop(u5.sink())
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple()) globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))

View File

@ -1,6 +1,6 @@
from typing import Dict, List, Optional from typing import Dict, List, Optional
import unittest import unittest
from tinygrad.dtype import PtrDType, dtypes from tinygrad.dtype import dtypes
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \ from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
graph_rewrite, contexts, track_rewrites graph_rewrite, contexts, track_rewrites
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
@ -27,7 +27,7 @@ class TestViz(unittest.TestCase):
pm = PatternMatcher([ pm = PatternMatcher([
(UPat.var("x")*1, lambda x:x), (UPat.var("x")*1, lambda x:x),
]) ])
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a*1, pm) uops = helper_test_viz(a*1, pm)
self.assertEqual(len(uops), 1) self.assertEqual(len(uops), 1)
self.assertEqual(uops[0], a) self.assertEqual(uops[0], a)
@ -37,15 +37,15 @@ class TestViz(unittest.TestCase):
(UPat.var("x")+UPat.var("x"), lambda x:x*2), (UPat.var("x")+UPat.var("x"), lambda x:x*2),
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))), (UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
]) ])
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
uops = helper_test_viz(a+a, pm) uops = helper_test_viz(a+a, pm)
self.assertEqual(len(uops), 2) self.assertEqual(len(uops), 2)
self.assertEqual(uops[0], a*2) self.assertEqual(uops[0], a*2)
self.assertEqual(uops[1], graph_rewrite(a+a, pm)) self.assertEqual(uops[1], graph_rewrite(a+a, pm))
def test_rewrite_with_ctx(self): def test_rewrite_with_ctx(self):
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0))) b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]: def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]:
if x in visited: return None if x in visited: return None
visited[x] = None visited[x] = None
@ -85,7 +85,7 @@ class TestViz(unittest.TestCase):
self.assertEqual(len(ret), 1) self.assertEqual(len(ret), 1)
def test_fold_const(self): def test_fold_const(self):
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
graph = uop_to_json(a) graph = uop_to_json(a)
assert not any(v[0].startswith("CONST") for v in graph.values()) assert not any(v[0].startswith("CONST") for v in graph.values())
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 assert len([x for x in graph.values() if "CONST" in x[0]]) == 1

View File

@ -3,12 +3,12 @@ from typing import Tuple
from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
from tinygrad.dtype import dtypes, PtrDType from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps from tinygrad.ops import UOp, UOps, BinaryOps
def get_gated_load_uop(valid:UOp, idx:UOp): def get_gated_load_uop(valid:UOp, idx:UOp):
return UOp(UOps.LOAD, dtypes.float, ( return UOp(UOps.LOAD, dtypes.float, (
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
idx, idx,
UOp.const(dtypes.float, 0.0), UOp.const(dtypes.float, 0.0),
valid valid

View File

@ -7,7 +7,7 @@ from typing import Tuple
# *** fake symobilc uops *** # *** fake symobilc uops ***
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType, ConstType from tinygrad.dtype import dtypes, ConstType
from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
@ -16,7 +16,7 @@ import functools
def render(self) -> Tuple[str, ConstType, ConstType]: def render(self) -> Tuple[str, ConstType, ConstType]:
# NOTE: we need STORE so the ALU op has children # NOTE: we need STORE so the ALU op has children
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0) glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink())) uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
if DEBUG>=5: print_uops(uops) if DEBUG>=5: print_uops(uops)
from tinygrad.renderer.cstyle import CStyleLanguage from tinygrad.renderer.cstyle import CStyleLanguage

View File

@ -3,7 +3,6 @@ import unittest
from tinygrad import Tensor from tinygrad import Tensor
from tinygrad.codegen.kernel import Kernel from tinygrad.codegen.kernel import Kernel
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.ops import UOp, UOps, ReduceOps, print_uops from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
from tinygrad.codegen.kernel import verify_ast from tinygrad.codegen.kernel import verify_ast
@ -27,9 +26,9 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
class TestVerifyAST(unittest.TestCase): class TestVerifyAST(unittest.TestCase):
def test_tiny_add(self): def test_tiny_add(self):
dtype = dtypes.int dtype = dtypes.int
buf_0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 0) buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0)
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 1) buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 2) buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2)
a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop())) a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop())) b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b)) store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
@ -37,7 +36,7 @@ class TestVerifyAST(unittest.TestCase):
def test_exactly_one_full_shape(self): def test_exactly_one_full_shape(self):
dtype = dtypes.int dtype = dtypes.int
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i) for i in range(6)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop())) a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop())) b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b) st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
@ -47,28 +46,28 @@ class TestVerifyAST(unittest.TestCase):
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1) with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
def test_no_implicit_broadcasting(self): def test_no_implicit_broadcasting(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop())) a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,))) b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b)) st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
def test_shrink_ok(self): def test_shrink_ok(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop())) a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
helper_test_verify_ast(st) helper_test_verify_ast(st)
def test_reduce_store(self): def test_reduce_store(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
def test_reduce_add_store(self): def test_reduce_add_store(self):
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)

View File

@ -9,7 +9,7 @@ from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp
graph_rewrite, track_rewrites, Variable, sint graph_rewrite, track_rewrites, Variable, sint
from tinygrad.device import Device from tinygrad.device import Device
from tinygrad.renderer import Renderer, TensorCore, Program from tinygrad.renderer import Renderer, TensorCore, Program
from tinygrad.dtype import ImageDType, PtrDType from tinygrad.dtype import ImageDType
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.shapetracker import ShapeTracker
@ -660,7 +660,7 @@ class Kernel:
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD] st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape)) local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop() st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn) local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store))) srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else: else:
@ -690,7 +690,7 @@ class Kernel:
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \ for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
st_uop = ShapeTracker.from_shape(local_shape).to_uop() st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start))) local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis)) grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
if op is self.reduceops[-1]: return grouped_reduce if op is self.reduceops[-1]: return grouped_reduce

View File

@ -13,7 +13,7 @@ if TYPE_CHECKING: from tinygrad.renderer import Renderer
# ***** float4/image store handling ***** # ***** float4/image store handling *****
def fold_expanded(ex, buf): def fold_expanded(ex, buf):
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None if buf.dtype != dtypes.float.ptr() and buf.dtype != dtypes.half.ptr() and not isinstance(buf.dtype, ImageDType): return None
new_srcs = dedup(list(ex.src)) new_srcs = dedup(list(ex.src))
old_new_srcs = new_srcs[:] old_new_srcs = new_srcs[:]
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType) is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
@ -32,7 +32,7 @@ def fold_expanded(ex, buf):
offsets_rootsrc[root_src][arg] = i offsets_rootsrc[root_src][arg] = i
# then rewrite everything we can # then rewrite everything we can
lengths = [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) lengths = [4] if is_image else ([8,4,2] if buf.dtype == dtypes.half.ptr() and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
used = set() used = set()
for rootsrc, offsets in offsets_rootsrc.items(): for rootsrc, offsets in offsets_rootsrc.items():
for o in offsets: for o in offsets:

View File

@ -18,7 +18,7 @@ class DType:
assert self.count == 1, f"can't vectorize {self} with size {sz}" assert self.count == 1, f"can't vectorize {self} with size {sz}"
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self) def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return PtrDType(self, local)
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
# dependent typing? # dependent typing?
@ -29,10 +29,9 @@ class ImageDType(DType):
local: bool = False # images are never local local: bool = False # images are never local
def scalar(self) -> DType: return self.base def scalar(self) -> DType: return self.base
def vec(self, sz:int): return self.base.vec(sz) def vec(self, sz:int): return self.base.vec(sz)
def ptr(self) -> Union[PtrDType, ImageDType]: return self def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
def __repr__(self): return f"dtypes.{self.name}({self.shape})" def __repr__(self): return f"dtypes.{self.name}({self.shape})"
# @dataclass(frozen=True, init=False, repr=False, eq=False)
class PtrDType(DType): class PtrDType(DType):
def __init__(self, dt:DType, local=False): def __init__(self, dt:DType, local=False):
self.base, self.local = dt, local self.base, self.local = dt, local
@ -40,7 +39,7 @@ class PtrDType(DType):
def __hash__(self): return super().__hash__() def __hash__(self): return super().__hash__()
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
def __ne__(self, dt): return not (self == dt) def __ne__(self, dt): return not (self == dt)
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})" def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
class dtypes: class dtypes:
@staticmethod @staticmethod

View File

@ -102,17 +102,19 @@ class CStyleLanguage(Renderer):
def get_kernel_modifier(self, uops:List[UOp]) -> str: return "" def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] + prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
def render_dtype(self, dt:DType) -> str: def render_dtype(self, dt:DType, mutable=True) -> str:
if isinstance(dt, ImageDType):
return f"{'write_only' if mutable else 'read_only'} image2d_t"
if isinstance(dt, PtrDType): if isinstance(dt, PtrDType):
return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "") return ("" if mutable else "const ") + (self.smem_prefix if dt.local else self.buffer_prefix) +\
self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "") return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
def __getitem__(self, key): return self.r[key] # hacky helper def __getitem__(self, key): return self.r[key] # hacky helper
@ -126,12 +128,8 @@ class CStyleLanguage(Renderer):
depth = 1 depth = 1
c: DefaultDict[str, int] = defaultdict(int) c: DefaultDict[str, int] = defaultdict(int)
for u in uops: for u in uops:
if u.op is UOps.DEFINE_GLOBAL: if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR):
r[u] = f"data{u.arg}" r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0]
bufs[u] = (r[u], (u.dtype, False))
continue
if u.op is UOps.DEFINE_VAR:
r[u] = u.arg[0]
bufs[u] = (r[u], (u.dtype, False)) bufs[u] = (r[u], (u.dtype, False))
continue continue