diff --git a/docs/abstractions2.py b/docs/abstractions2.py index 34fa4ded..6006d9d8 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -37,7 +37,7 @@ print("******** second, the Device ***********") DEVICE = "CLANG" # NOTE: you can change this! import struct -from tinygrad.dtype import PtrDType, dtypes +from tinygrad.dtype import dtypes from tinygrad.device import Buffer, Device from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps from tinygrad.shape.shapetracker import ShapeTracker @@ -49,12 +49,12 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc # NOTE: a._buf is the same as the return from MallocAllocator.alloc # describe the computation -buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1) -buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2) +buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1) +buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2) ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop())) ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop())) alu = ld_1 + ld_2 -output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) +output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu)) s = UOp(UOps.SINK, dtypes.void, (st_0,)) diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index bf5e03b7..4deca007 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -4,7 +4,7 @@ import unittest from tinygrad import Device from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps from tinygrad.engine.search import Opt, OptOps -from tinygrad.dtype import dtypes, PtrDType +from tinygrad.dtype import dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.codegen.kernel import Kernel @@ -31,7 +31,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=( x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -81,7 +81,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - x18:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( x18, diff --git a/test/test_dtype.py b/test/test_dtype.py index c652d517..d2046a45 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -312,15 +312,15 @@ class TestEqStrDType(unittest.TestCase): def test_ptr_ne(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") # TODO: is this the wrong behavior? - assert PtrDType(dtypes.float32) == dtypes.float32 - assert not (PtrDType(dtypes.float32) != dtypes.float32) - assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32) - assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32)) - #assert PtrDType(dtypes.float32) != dtypes.float32 + assert dtypes.float32.ptr() == dtypes.float32 + assert not (dtypes.float32.ptr() != dtypes.float32) + assert dtypes.float32.ptr() == dtypes.float32.ptr() + assert not (dtypes.float32.ptr() != dtypes.float32.ptr()) + #assert dtypes.float32.ptr() != dtypes.float32 def test_strs(self): if PtrDType is None: raise unittest.SkipTest("no PtrDType support") self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))") - self.assertEqual(str(PtrDType(dtypes.float32)), "PtrDType(dtypes.float)") + self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()") class TestHelpers(unittest.TestCase): signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64) diff --git a/test/test_linearizer.py b/test/test_linearizer.py index c1b9f179..defb580a 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -15,7 +15,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.schedule import BUF_LIMIT, create_schedule from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX -from tinygrad.dtype import DType, PtrDType, dtypes +from tinygrad.dtype import DType, dtypes def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]: if isinstance(r, Tensor): r = [r] @@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase): def test_multioutput(self): dtype, st = dtypes.int, ShapeTracker.from_shape((8,)) - g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), arg=i) for i in range(4)] + g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)] a = UOp(UOps.LOAD, dtype, (g2, st.to_uop())) b = UOp(UOps.LOAD, dtype, (g3, st.to_uop())) out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b)) @@ -107,7 +107,7 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(32, dtype=dtypes.float).realize() st_x = x.lazydata.st - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop())) @@ -143,7 +143,7 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() st_x = x.lazydata.st - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop())) @@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase): x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() - g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(4)] + g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop())) @@ -232,7 +232,7 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize() st = x.lazydata.st - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop())) @@ -283,7 +283,7 @@ class TestLinearizer(unittest.TestCase): # check how it works with one reduce optimized and one unoptimized Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) @@ -314,7 +314,7 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) x = Tensor.randn(4, 32, dtype=dtypes.float).realize() x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize() - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) @@ -350,7 +350,7 @@ class TestLinearizer(unittest.TestCase): # check how multireduce works with multioutput Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) @@ -373,7 +373,7 @@ class TestLinearizer(unittest.TestCase): # if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!) Tensor.manual_seed(0) x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) @@ -397,7 +397,7 @@ class TestLinearizer(unittest.TestCase): def test_complete_unroll_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) @@ -413,7 +413,7 @@ class TestLinearizer(unittest.TestCase): def test_upcast_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) @@ -432,7 +432,7 @@ class TestLinearizer(unittest.TestCase): # make sure the if block of a grouped reduce can be closed early and the result loaded back in Tensor.manual_seed(0) x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop())) @@ -448,7 +448,7 @@ class TestLinearizer(unittest.TestCase): def test_mean_std_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) @@ -466,7 +466,7 @@ class TestLinearizer(unittest.TestCase): def test_mean_std_multireduce_mid_dim(self): Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35)) @@ -486,7 +486,7 @@ class TestLinearizer(unittest.TestCase): # TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch) Tensor.manual_seed(0) x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) @@ -508,7 +508,7 @@ class TestLinearizer(unittest.TestCase): def test_var_multireduce(self): Tensor.manual_seed(0) x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop())) first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) @@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase): @unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet") def test_softmax_multireduce(self): x = Tensor.rand(4, 32).realize() - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop())) max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,))) second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop())) @@ -559,15 +559,15 @@ class TestLinearizer(unittest.TestCase): arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) out = arange+ast_const(dtypes.int, -1, output_shape) - store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out)) + store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out)) sink = UOp(UOps.SINK, src=(store,)) helper_linearizer_ast(sink, [], wanna_output=[real_arange]) with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange) @unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow") def test_indexing_multireduce(self): - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] - g2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2) + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] + g2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2) arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \ View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) # TODO: do this arange broadcast in the scheduler @@ -598,7 +598,7 @@ class TestLinearizer(unittest.TestCase): real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1) ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( @@ -611,10 +611,10 @@ class TestLinearizer(unittest.TestCase): UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( @@ -630,7 +630,7 @@ class TestLinearizer(unittest.TestCase): real_argmax = np.argmax(t.numpy()) ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( @@ -643,10 +643,10 @@ class TestLinearizer(unittest.TestCase): UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501 ast_const(dtypes.bool, True, (200, 1)),)),)), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( @@ -669,7 +669,7 @@ class TestLinearizer(unittest.TestCase): # [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)] ] - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,))) @@ -696,7 +696,7 @@ class TestLinearizer(unittest.TestCase): [Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),] ] - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,))) @@ -730,7 +730,7 @@ class TestLinearizer(unittest.TestCase): ld1 = x.lazydata.st.reshape((N, N, 1)) ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( @@ -738,20 +738,20 @@ class TestLinearizer(unittest.TestCase): UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), ld1.to_uop(),)), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1)), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), ld0.to_uop(),)),)),)), UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501 ast_const(dtypes.float, 0.0, (N, 1, 1)), @@ -763,7 +763,7 @@ class TestLinearizer(unittest.TestCase): wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501 ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( @@ -771,20 +771,20 @@ class TestLinearizer(unittest.TestCase): UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), ld1.to_uop(),)), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( ast_const(dtypes.float, 0.75*N, (N, 1, N)), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), ld0.to_uop(),)),)),)), UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501 ast_const(dtypes.float, 0.0, (1, 1, N)), @@ -799,7 +799,7 @@ class TestLinearizer(unittest.TestCase): wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501 ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( @@ -807,20 +807,20 @@ class TestLinearizer(unittest.TestCase): UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501 UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501 ast_const(dtypes.float, 0.0, (1, 1, 1, 1)), ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),)) @@ -829,7 +829,7 @@ class TestLinearizer(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals") @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_end_local(self): - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)] load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) @@ -931,7 +931,7 @@ class TestLinearizer(unittest.TestCase): # make sure const buffers are differentiated from local and mem buffers ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int VAL = ast_const(DT, 2, ST.arg.shape) - g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(DT), arg=i) for i in range(2)] + g0, g1 = [UOp(UOps.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)] # data1[0] + VAL a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL @@ -1368,10 +1368,10 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 opt = [ Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), @@ -1387,10 +1387,10 @@ class TestLinearizer(unittest.TestCase): Tensor.manual_seed(0) ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))), UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501 opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), @@ -1603,16 +1603,16 @@ class TestFloat4(unittest.TestCase): # from llama 7B shard 4 gpus ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.LOAD, dtypes.half, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501 # TODO: fix this, expected might change but should be positive @@ -1632,19 +1632,19 @@ class TestFloat4(unittest.TestCase): # from float32 stable diffusion red tinybox ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501 for expected, opts in [ @@ -1662,13 +1662,13 @@ class TestFloat4(unittest.TestCase): # from resnet ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(UOps.CAST, dtypes.half, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=( UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.LOAD, dtypes.half, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501 for expected, opts in [ (16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501 @@ -1950,20 +1950,20 @@ class TestKernelOpts(unittest.TestCase): def test_buf_index_not_found_tensor_core(self): ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.CAST, dtypes.float, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.int, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501 UOp(UOps.LOAD, dtypes.int, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501 k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) with self.assertRaises(KernelOptError): @@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared") def test_padto_group(self): Tensor.manual_seed(0) - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501 diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index 0ebf9df7..d417e9f0 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -5,7 +5,6 @@ import unittest from test.helpers import ast_const from tinygrad import Device, dtypes -from tinygrad.dtype import PtrDType from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps from tinygrad.helpers import getenv from tinygrad.shape.shapetracker import ShapeTracker, View @@ -17,7 +16,7 @@ class TestLinearizerDumb(unittest.TestCase): def test_unmerged_ifs(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( @@ -26,10 +25,10 @@ class TestLinearizerDumb(unittest.TestCase): UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), ast_const(dtypes.half, 0.9999950000374996, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -50,17 +49,17 @@ class TestLinearizerDumb(unittest.TestCase): def test_max_simplify_and_cancel(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=( UOp(UOps.CAST, dtypes.int, arg=None, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), @@ -81,11 +80,11 @@ class TestLinearizerDumb(unittest.TestCase): def test_expander_new_srcs(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) @@ -102,7 +101,7 @@ class TestLinearizerDumb(unittest.TestCase): def test_llama_embedding(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( @@ -118,12 +117,12 @@ class TestLinearizerDumb(unittest.TestCase): ast_const(dtypes.int, -1, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) prg = k.to_program() @@ -134,7 +133,7 @@ class TestLinearizerDumb(unittest.TestCase): def test_unaligns_idxs(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( @@ -142,16 +141,16 @@ class TestLinearizerDumb(unittest.TestCase): UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.long, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.CAST, dtypes.long, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.bool, True, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) @@ -166,14 +165,14 @@ class TestLinearizerDumb(unittest.TestCase): def test_unrolled_float4_align(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.long, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), ast_const(dtypes.long, -1, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -182,7 +181,7 @@ class TestLinearizerDumb(unittest.TestCase): ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) @@ -198,15 +197,15 @@ class TestLinearizerDumb(unittest.TestCase): def test_upcasted_stores_out_of_order(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)] k = Kernel(ast, opts=Device[Device.DEFAULT].renderer) diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 7a7eb25b..3e07b7a7 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -5,7 +5,6 @@ from tinygrad.codegen.kernel import Kernel, KernelOptError from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps, TernaryOps from tinygrad.engine.search import Opt, OptOps from tinygrad import Device, dtypes, Tensor -from tinygrad.dtype import PtrDType from tinygrad.helpers import CI from test.external.fuzz_linearizer import compare_linearizer from test.helpers import is_dtype_supported, ast_const @@ -43,30 +42,30 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_1(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) helper_test_lin(Kernel(ast), [], failed_platforms=[]) def test_failure_2(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -74,11 +73,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_3(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=32)] # METAL: AssertionError: Error Domain=AGXMetalG13X Code=3 "Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)" UserInfo={NSLocalizedDescription=Threadgroup memory size (65536) exceeds the maximum threadgroup memory allowed (32768)} @@ -87,7 +86,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_5(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( @@ -98,7 +97,7 @@ class TestLinearizerFailures(unittest.TestCase): ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 1, 4, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] @@ -108,7 +107,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_6(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( @@ -123,11 +122,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_7(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, amt=4)] # test/test_linearizer_failures.py Fatal Python error: Segmentation fault @@ -136,7 +135,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_8(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( @@ -146,10 +145,10 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( x9:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 4096), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), x9,)),)), ast_const(dtypes.float, 0.000244140625, st_src=( @@ -164,15 +163,15 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_9(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 4500, 0, 0, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -180,26 +179,26 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_10(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 50257), strides=(0, 0, 1, 1024), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) helper_test_lin(Kernel(ast), [], failed_platforms=[]) def test_failure_11(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( @@ -209,15 +208,15 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -235,15 +234,15 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), x42:=ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 6, 6), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), @@ -253,7 +252,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(1,), offset=0, mask=None, contiguous=True), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)), ast_const(dtypes.float, 5.425347222222222e-05, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)), @@ -261,33 +260,33 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64,), strides=(0,), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 3, 2, 2), strides=(2304, 36, 12, 2, 6, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)), x42,)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=7, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 64, 3, 3, 2, 2), strides=(576, 9, 3, 1, 0, 0), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 3, 2, 3, 2), strides=(2304, 36, 12, 2, 4, 1), offset=0, mask=None, contiguous=False), View(shape=(512, 64, 6, 6), strides=(2304, 36, 6, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)) helper_test_lin(Kernel(ast), [], failed_platforms=[]) def test_failure_12(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] @@ -297,7 +296,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_12_multireduce(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( @@ -305,12 +304,12 @@ class TestLinearizerFailures(unittest.TestCase): x6:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x6,)), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( @@ -322,19 +321,19 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_13(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(51864, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 51864), strides=(0, 0, 1, 384), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=19584, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=["METAL", "GPU", "CUDA"]) @@ -342,19 +341,19 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_14(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( x5:=UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x5,)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] @@ -364,7 +363,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_15(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( @@ -373,27 +372,27 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 0, 14, 1, 196, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 480, 1, 1), strides=(0, 0, 480, 0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1e-05, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 112, 14, 14, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=16)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 115: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" @@ -402,12 +401,12 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_16(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), ast_const(dtypes.float, 0.0009765625, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) @@ -418,15 +417,15 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_17(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(0, 0, 1, 40, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 240, 28, 28, 1, 1), strides=(188160, 0, 0, 784, 28, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.GROUPTOP, axis=0, amt=16), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.LOCAL, axis=1, amt=4)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 178: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" @@ -435,23 +434,23 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_18(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(1536, 0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1536), strides=(0, 0, 1536, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUPTOP, axis=0, amt=256), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 239: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" @@ -460,15 +459,15 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_19(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(0, 0, 36, 9, 0, 0, -3, -1), offset=8, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 4, 9, 7, 3, 3), strides=(252, 0, 0, 63, 7, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=0), Opt(op=OptOps.GROUP, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=7), Opt(op=OptOps.UPCAST, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=3)] # COMPILE_ERROR on METAL in fuzz_linearizer ast 379: Error Domain=AGXMetalG14X Code=3 "Compiler encountered an internal error" @@ -477,11 +476,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_20(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(4, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 4), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) @@ -491,7 +490,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_21(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(65, 1), offset=0, mask=None, contiguous=True),)), src=()), ast_const(dtypes.float, 1.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(45, 65), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) @@ -503,7 +502,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_22(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( x4:=ast_const(dtypes.float, 0.000244140625, st_src=( @@ -520,65 +519,65 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=7, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=7, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 96, 8, 16), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=8, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=8, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=9, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=9, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=10, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=10, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=11, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=11, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=12, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=12, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=13, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=13, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=14, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=14, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=15, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=15, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 8640, 180, 18, 1), offset=19, mask=((1, 2), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=16, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=16, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 32, 48, 8, 16), strides=(0, 17280, 180, 18, 1), offset=19, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(2, 32, 48, 8, 16), strides=(0, 12288, 128, 16, 1), offset=0, mask=((0, 1), (0, 32), (0, 48), (0, 8), (0, 16)), contiguous=False), View(shape=(1536, 2, 128), strides=(128, 196608, 1), offset=0, mask=None, contiguous=False), View(shape=(32, 96, 8, 16), strides=(12288, 128, 16, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)),)),)),)),)),)),)),)), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=17, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=17, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), ast_const(dtypes.float, 2.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), @@ -586,7 +585,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=18, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=18, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 96, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), x4,)), ast_const(dtypes.float, 1e-05, st_src=( @@ -598,10 +597,10 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_23(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16), Opt(op=OptOps.LOCAL, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -609,10 +608,10 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_24(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -621,7 +620,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_25(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( @@ -636,7 +635,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_26(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( @@ -676,11 +675,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_27(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) all_failing_opts = [ [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=7), Opt(op=OptOps.UPCAST, axis=0, amt=0)], @@ -691,13 +690,13 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_28(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bfloat16), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.bfloat16.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.bfloat16, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=( x5:=UOp(UOps.CAST, dtypes.bfloat16, arg=None, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), x9:=ast_const(dtypes.bfloat16, 230.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()),)),)), @@ -728,17 +727,17 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_29(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 128, 1, 64, 4, 58, 4, 58), strides=(0, 200704, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 128), (0, 1), (0, 64), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(3444736, 0, 0, 232, 1, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 64, 3, 3), strides=(0, 0, 576, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=1), Opt(op=OptOps.PADTO, axis=2, amt=32)] helper_test_lin(Kernel(ast), opts, failed_platforms=[], atol=1.0) @@ -746,17 +745,17 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_30(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(3072, 0, 0, 32, 1, 1024, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 3, 2, 2), strides=(0, 0, 12, 0, 0, 4, 2, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=3, amt=32), Opt(op=OptOps.LOCAL, axis=3, amt=32), Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -765,17 +764,17 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_31(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.float, arg=UnaryOps.EXP2, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 1.4426950408889634, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) @@ -788,17 +787,17 @@ class TestLinearizerFailures(unittest.TestCase): # Memory access fault on tinybox red ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 256, 4, 16, 4, 16), strides=(0, 50176, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 256), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(1048576, 0, 0, 64, 1, 4096, 1088, 17), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=7), Opt(op=OptOps.UNROLL, axis=1, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=16)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[], atol=0.1, rtol=0.05) @@ -807,12 +806,12 @@ class TestLinearizerFailures(unittest.TestCase): # UOps.UNMUL left after linearize ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( x5:=UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)), UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( @@ -831,7 +830,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((0, 26040),), contiguous=False),)), src=()),)),)), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(1,), offset=-26040, mask=((26040, 32640),), contiguous=False),)), src=()),)), ast_const(dtypes.float, -0.18257418583505536, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32640,), strides=(0,), offset=0, mask=((26040, 32640),), contiguous=False),)), src=()),)),)),)), @@ -848,16 +847,16 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_34(self, unroll=False): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(77, 0, 0, 7, 1, 0, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 2, 5), strides=(0, 0, 10, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) @@ -871,7 +870,7 @@ class TestLinearizerFailures(unittest.TestCase): # UOps.UNMUL left after linearize ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uchar), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.uchar, arg=None, src=( UOp(UOps.ALU, dtypes.uint, arg=BinaryOps.ADD, src=( @@ -891,7 +890,7 @@ class TestLinearizerFailures(unittest.TestCase): # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=28 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( @@ -899,13 +898,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uchar), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) @@ -918,16 +917,16 @@ class TestLinearizerFailures(unittest.TestCase): # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=87 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uchar), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(784, 0, 0, 28, 1, 0, 28, 1, 1568), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 32, 24, 24, 1, 5, 5, 256), strides=(18432, 0, 576, 24, 1, 0, 0, 0, 36864), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) for axis in [0,1,3,4]: opts = [Opt(op=OptOps.TC, axis=axis, amt=2)] @@ -939,7 +938,7 @@ class TestLinearizerFailures(unittest.TestCase): # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=127 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( @@ -947,13 +946,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.LOAD, dtypes.uchar, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uchar), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) @@ -966,7 +965,7 @@ class TestLinearizerFailures(unittest.TestCase): # fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 DEBUG=2 FUZZ_NTH=3 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( @@ -984,17 +983,17 @@ class TestLinearizerFailures(unittest.TestCase): # One more resnet crash with a page fault on AMD. Checked on rocm6.1.3, -O1 works, -O2 fails ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 128, 4, 58, 4, 58), strides=(0, 401408, 0, 3136, 0, 56, 0, 1), offset=-57, mask=((0, 1), (0, 256), (0, 1), (0, 128), (0, 4), (1, 57), (0, 4), (1, 57)), contiguous=False), View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(6889472, 0, 0, 464, 2, 53824, 13688, 59), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts=[Opt(op=OptOps.TC, axis=5, amt=2), Opt(op=OptOps.UNROLL, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["AMD", "HIP"]) @@ -1005,11 +1004,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_42(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.UPCAST, axis=0, amt=2), Opt(op=OptOps.PADTO, axis=0, amt=32)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -1019,11 +1018,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_43(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -1033,11 +1032,11 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_44(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4)] k = helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -1050,19 +1049,19 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_45(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 3, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( @@ -1075,7 +1074,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( @@ -1090,7 +1089,7 @@ class TestLinearizerFailures(unittest.TestCase): def test_failure_46(self): ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( @@ -1100,22 +1099,22 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.bool, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 10), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.RECIP, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[]) @@ -1124,7 +1123,7 @@ class TestLinearizerFailures(unittest.TestCase): # upcast an arange, failed with UOP_IS_SYMBOLIC=1 (fixed!) ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( @@ -1140,16 +1139,16 @@ class TestLinearizerFailures(unittest.TestCase): # with UOP_IS_SYMBOLIC=1, generates the wrong IDIV (fixed!) ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 0, 56, 1, 3136, 0, 0, 802816), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 56, 56, 256, 1, 1, 256), strides=(0, 0, 3136, 56, 1, 0, 0, 0, 200704), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @@ -1158,15 +1157,15 @@ class TestLinearizerFailures(unittest.TestCase): # with UOP_IS_SYMBOLIC=1, on METAL it breaks store fusion and has A+B and B+A being two different UOp ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(10, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 10), strides=(0, 1, 6), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=[]) @@ -1175,21 +1174,21 @@ class TestLinearizerFailures(unittest.TestCase): # from BEAM_COMPARE=2 running tinyphysics.onnx model ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.bool, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 20, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.int, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), ast_const(dtypes.bool, True, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 20, 20), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)), @@ -1202,7 +1201,7 @@ class TestLinearizerFailures(unittest.TestCase): # regression test for #7019, training bert on tinybox red ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(1024, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.half, arg=UnaryOps.RECIP, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.ADD, src=( @@ -1224,13 +1223,13 @@ class TestLinearizerFailures(unittest.TestCase): UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(524288, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(12, 1024, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(UOps.ALU, dtypes.half, arg=TernaryOps.WHERE, src=( x6, @@ -1246,17 +1245,17 @@ class TestLinearizerFailures(unittest.TestCase): # CUDA Error 700, an illegal memory access was encountered ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CAST, dtypes.half, arg=None, src=( UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 256), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),)) opts = [Opt(op=OptOps.TC, axis=0, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16)] helper_test_lin(Kernel(ast, opts=Device[Device.DEFAULT].renderer), opts=opts, failed_platforms=["CUDA", "NV"]) diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 0388923b..37f0f304 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -9,7 +9,6 @@ from tinygrad.engine.search import time_linearizer, bufs_from_lin # stuff needed to unpack a kernel from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps -from tinygrad.dtype import PtrDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -27,7 +26,7 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_1(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( @@ -37,15 +36,15 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), x16:=ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( @@ -57,7 +56,7 @@ class TestLinearizerOverflow(unittest.TestCase): ast_const(dtypes.float, 1e-05, st_src=( UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x16,)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)] @@ -67,15 +66,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_2(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)] _test_overflow(ast, opts) @@ -84,15 +83,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_3(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) @@ -101,15 +100,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_4(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) @@ -118,15 +117,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_5(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)] _test_overflow(ast, opts) @@ -135,15 +134,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_6(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)] _test_overflow(ast, opts) @@ -152,15 +151,15 @@ class TestLinearizerOverflow(unittest.TestCase): def test_overflow_7(self): ast = UOp(UOps.SINK, None, arg=None, src=( UOp(UOps.STORE, None, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)] _test_overflow(ast, opts) @@ -170,7 +169,7 @@ class TestLinearizerOverflow(unittest.TestCase): class TestLinearizerOverflowAlt(unittest.TestCase): def test_overflow_1(self): BS = 2 - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() @@ -182,7 +181,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): _test_overflow(ast, opts) def test_overflow_2(self): BS = 2 - g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)] + g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop() in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index 8b562147..25c6b3d6 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -4,7 +4,7 @@ import numpy as np from tinygrad.codegen.uopgraph import full_graph_rewrite from tinygrad.codegen.linearize import linearize_uop from tinygrad.device import Buffer, Device -from tinygrad.dtype import PtrDType, dtypes +from tinygrad.dtype import dtypes from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import dedup, flatten, getenv, prod from tinygrad.renderer.cstyle import CStyleLanguage @@ -30,8 +30,8 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None): @unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle") class TestCStyleFailures(unittest.TestCase): def test_inline_const_alu(self): - a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - b = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1) + a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) + b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) idx = UOp.const(dtypes.int, 0) ld = UOp(UOps.LOAD, dtypes.int, (b, idx)) alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) @@ -45,7 +45,7 @@ class TestCStyleFailures(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") class TestPTXFailures(unittest.TestCase): def test_gated_store_with_alu(self): - a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu)) sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,)) @@ -55,7 +55,7 @@ class TestPTXFailures(unittest.TestCase): @unittest.skip("not still valid?") def test_gated_store_with_if(self): - a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0) val = UOp.const(dtypes.int, 1) if_uop = UOp(UOps.IF, dtypes.void, (gate_alu, val)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 89ac83b5..46cbb1ba 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -9,7 +9,7 @@ import functools from typing import List, Optional, Union, cast from tinygrad import nn, dtypes, Device, Tensor -from tinygrad.dtype import DType, PtrDType +from tinygrad.dtype import DType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites @@ -1611,7 +1611,7 @@ class TestIndexing(unittest.TestCase): self.assertLess(et, 1e3) def test_no_rewrite_elementwise(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop())) sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),)) @@ -1619,7 +1619,7 @@ class TestIndexing(unittest.TestCase): self.assertEqual(rsink.key, sink.key) def test_simple_store_reshape(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) @@ -1632,7 +1632,7 @@ class TestIndexing(unittest.TestCase): verify_ast(rsink) def test_no_reshape_reduceop(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) @@ -1641,7 +1641,7 @@ class TestIndexing(unittest.TestCase): self.assertEqual(sink.key, rsink.key) def test_reshape_many(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) @@ -1660,7 +1660,7 @@ class TestIndexing(unittest.TestCase): sizes = [10*(i+1) for i in range(SZ)] tms: List[float] = [] for sz in sizes: - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) for _ in range(sz): r = r + ast_const(dtypes.int, 2, ()) @@ -1682,14 +1682,14 @@ class TestIndexing(unittest.TestCase): # graph rewrite sink = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=( UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501 UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( - x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), + x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501 UOp(UOps.LOAD, dtypes.int, arg=None, src=( x8, @@ -1709,7 +1709,7 @@ class TestIndexing(unittest.TestCase): a = Tensor.randint(4,).realize() expected_out = a.numpy().sum(0)+1 # LazyBuffer to pre-rewrite AST - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,))) swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(())) @@ -1732,7 +1732,7 @@ class TestIndexing(unittest.TestCase): b = Tensor.randint(4,).realize() expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2 # LazyBuffer to pre-rewrite AST - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,))) ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop())) @@ -1752,7 +1752,7 @@ class TestIndexing(unittest.TestCase): swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501 # there's an EXPAND pushing through the REDUCE_AXIS self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape)) @@ -1766,7 +1766,7 @@ class TestIndexing(unittest.TestCase): def test_permute_rewrite(self): sink = UOp(UOps.STORE, dtypes.void, arg=None, src=( - x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()), + x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()), x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=( x1, @@ -1778,15 +1778,15 @@ class TestIndexing(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()), + UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()), x2,)),)), UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()), + UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)), UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()), + UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=( x11,)),)),)),)),)) diff --git a/test/test_search.py b/test/test_search.py index 7aabda3d..1aefdb82 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -8,7 +8,7 @@ from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer from tinygrad.tensor import Tensor -from tinygrad.dtype import dtypes, PtrDType +from tinygrad.dtype import dtypes from tinygrad.helpers import Context, GlobalCounters from tinygrad.engine.realize import capturing from tinygrad.shape.shapetracker import ShapeTracker @@ -48,7 +48,7 @@ class TestTimeLinearizer(unittest.TestCase): # ast of Tensor.zeros(16).contiguous().realize() ast = UOp(UOps.SINK, src=( UOp(UOps.STORE, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))), ast_const(dtypes.float, 0.0, st_src=( UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),)) @@ -105,7 +105,7 @@ class TestBEAM(unittest.TestCase): # taken from https://github.com/tinygrad/tinygrad/issues/4612 ast = UOp(UOps.SINK, dtypes.void, arg=None, src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=( @@ -115,22 +115,22 @@ class TestBEAM(unittest.TestCase): UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=( UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 UOp(UOps.LOAD, dtypes.float, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501 ast_const(dtypes.float, 1.4285714285714286, st_src=( UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501 diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 400e9625..ad2df4c9 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,7 +2,6 @@ from typing import List import unittest, time from test.helpers import assert_equiv_uops from tinygrad import dtypes, Device -from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import UPat, PatternMatcher @@ -32,7 +31,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase): def test_expand_rewrite(self): sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=( UOp(UOps.STORE, dtypes.void, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1), strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0), offset=0, mask=None, contiguous=False),)), src=()), @@ -40,14 +39,14 @@ class TestGraphRewriteEfficiency(unittest.TestCase): UOp(UOps.CAST, dtypes.float, arg=None, src=( UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=( UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=( View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16, mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False), View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0, mask=None, contiguous=False))), src=()),)), UOp(UOps.LOAD, dtypes.half, arg=None, src=( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()), + UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()), UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=( View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)) @@ -225,7 +224,7 @@ class TestUOpGraph(unittest.TestCase): @unittest.skip("this test isn't valid uops") def test_noop_vectorize_fold(self): - d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0) + d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0) idx = UOp.const(dtypes.int, 0) ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx)) vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,)) @@ -236,9 +235,9 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0) def test_gep_vec_fold(self): - d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) - d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 2) + d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) + d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) def _test_vec(geps, count=4): vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps) @@ -342,8 +341,8 @@ class TestUOpGraph(unittest.TestCase): assert_equiv_uops(uops[-1], wmma) def test_cast_alu_fold(self): - d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0) - d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1) + d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0) + d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) alu = ld.lt(1).cast(dtypes.bool) @@ -352,8 +351,8 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0) def test_double_cast_fold(self): - d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0) - d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1) + d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0) + d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1) idx = UOp.const(dtypes.int, 0) ld = UOp(UOps.LOAD, dtypes.int, (d1, idx)) alu = ld.cast(dtypes.float).cast(dtypes.float) @@ -376,9 +375,9 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.src[1].arg, 6) def test_fold_gated_load(self): - glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1) - glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 2) + glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) + glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) + glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2) idx = UOp.const(dtypes.int, 0) ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False))) ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True))) @@ -390,8 +389,8 @@ class TestUOpGraph(unittest.TestCase): assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int)) def test_fold_gated_load_local(self): - glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) - smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1)) + glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) + smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1)) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int))) barrier = UOp(UOps.BARRIER, dtypes.void, (st, )) @@ -405,7 +404,7 @@ class TestUOpGraph(unittest.TestCase): assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int)) def test_fold_gated_store(self): - glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) idx0 = UOp.const(dtypes.int, 0) idx1 = UOp.const(dtypes.int, 0) val = UOp.const(dtypes.int, 42) @@ -418,13 +417,13 @@ class TestUOpGraph(unittest.TestCase): @unittest.skip("this is a uop type error") def test_asserts_bad_gate(self): - glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) idx = UOp.const(dtypes.int, 0) bad_gate = UOp.const(dtypes.int, 1) with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))]) def test_switched_range_order(self): - glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0) + glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) c0 = UOp.const(dtypes.int, 0) c2 = UOp.const(dtypes.int, 2) cf = UOp.const(dtypes.float, 0.0) @@ -591,21 +590,21 @@ class TestExpander(unittest.TestCase): class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_fold(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)] sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1 def test_two_load_fold(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)] sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2 def test_simple_load_fold_gated(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp(UOps.DEFINE_VAR, dtypes.bool) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) @@ -615,7 +614,7 @@ class TestLoadStoreFolder(unittest.TestCase): self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0]) def test_simple_load_dont_fold_different_gated(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] @@ -624,14 +623,14 @@ class TestLoadStoreFolder(unittest.TestCase): assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3 def test_simple_store_fold(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1 def test_simple_store_fold_gate(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) @@ -642,7 +641,7 @@ class TestLoadStoreFolder(unittest.TestCase): assert str(one_store.src[3]) == str(gate) # huh, why do i need str here? def test_simple_store_dont_fold(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr()) gate = UOp.variable("g1", False, True, dtypes.bool) gate2 = UOp.variable("g2", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] @@ -655,8 +654,8 @@ def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer) class TestIFUOps(unittest.TestCase): def test_create_ifs(self): - gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4)) + gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) gate = valid&(lidx.ne(2)) @@ -674,8 +673,8 @@ class TestIFUOps(unittest.TestCase): self.assertEqual(len(st.src), 3) def test_expand_ifs_one_gate(self): - gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16)) + gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16)) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16)) gate = valid&(lidx.ne(2)) @@ -694,7 +693,7 @@ class TestIFUOps(unittest.TestCase): # this will be fixed with the merge gated stores bounty @unittest.expectedFailure def test_expand_ifs_dumb(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5) lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4)) gate = valid&(lidx.ne(2)) diff --git a/test/test_uops.py b/test/test_uops.py index 256904a6..037121bc 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -4,7 +4,7 @@ import numpy as np from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context -from tinygrad.dtype import dtypes, DType, PtrDType +from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import Program @@ -30,8 +30,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar def _test_single_value(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) - buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)] + buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) + buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) @@ -46,7 +46,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) + buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu)) @@ -59,7 +59,7 @@ def _test_single_value_const(vals, op, dts): def _test_uops_result(output_dtype, uops, res): # uops = [] - buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) + buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) # res = output_fn(uops) out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res)) buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate() @@ -246,7 +246,7 @@ class TestConstantFolding(unittest.TestCase): class TestGatedStoreRewrite(unittest.TestCase): @unittest.expectedFailure def test_tiny_gate_store(self): - gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) + gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0 * UOp.const(dtypes.int, 2) val = UOp.const(dtypes.float, 42.0) @@ -263,8 +263,8 @@ class TestGatedStoreRewrite(unittest.TestCase): @unittest.expectedFailure def test_gate_some_stores(self): - gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) + gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0*UOp.const(dtypes.int, 2) val = UOp.const(dtypes.float, 42.0) @@ -282,8 +282,8 @@ class TestGatedStoreRewrite(unittest.TestCase): # scaled down version of TestLinearizerDumb.test_unmerged_ifs @unittest.expectedFailure def test_merge_ifs_alt(self): - gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) + gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) idx = gidx0*UOp.const(dtypes.int, 2) val = UOp.const(dtypes.float, 42.0) @@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_basic(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16)) st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st,)) sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr)) @@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory") def test_local_indirect(self): uops = [] - smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16)) + smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16)) st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2))) st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42))) barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2)) @@ -325,7 +325,7 @@ class TestLocalAccess(unittest.TestCase): @unittest.skipUnless(getenv("PTX"), "This only tests assembly backends") class TestAssembly(unittest.TestCase): def test_bitshift_left(self): - g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) + g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(UOps.CONST, dtypes.int, (), 2) c2 = UOp(UOps.CONST, dtypes.int, (), 3) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) @@ -337,7 +337,7 @@ class TestAssembly(unittest.TestCase): self.assertEqual(uops[-2].arg, BinaryOps.MUL) def test_bitshift_right(self): - g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0) + g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0) c1 = UOp(UOps.CONST, dtypes.int, (), 2) c2 = UOp(UOps.CONST, dtypes.int, (), 3) l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1)) @@ -361,7 +361,7 @@ class TestUOpMethod(unittest.TestCase): def test_uop_variables(self): a = UOp.variable("a", 1, 10) uop_var = UOp.const(dtypes.int, a) - st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0), + st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0), ShapeTracker.from_shape((2, a)).to_uop())) ast_vars = (st_var+uop_var).variables() self.assertEqual(len(ast_vars), 1) @@ -376,7 +376,7 @@ class TestUOpMethod(unittest.TestCase): self.assertEqual((gidx0*3+1).const_factor(), 1) def test_replace(self): - x = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.void), (), 0) + x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0) self.assertIs(x.replace(arg=None).arg, None) with self.assertRaises(AssertionError): x.replace(field="a") @@ -403,7 +403,7 @@ class TestIndexingOrdering(unittest.TestCase): # NOTE: these tests skip type_verify since they add dtype to STORE @unittest.expectedFailure def test_simple_order(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) uops = to_uops_list([st1, st0], skip_check=True) @@ -412,8 +412,8 @@ class TestIndexingOrdering(unittest.TestCase): @unittest.expectedFailure def test_ordering_multi_output(self): - buf0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) - buf1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1) + buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) + buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1) st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) @@ -430,7 +430,7 @@ class TestIndexingOrdering(unittest.TestCase): assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}" def test_simple_order_with_special(self): - buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0) + buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0) gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4)) st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42))) st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10))) diff --git a/test/test_uops_stats.py b/test/test_uops_stats.py index fbbc87af..8198c6c1 100644 --- a/test/test_uops_stats.py +++ b/test/test_uops_stats.py @@ -5,7 +5,7 @@ from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import lower_schedule_item from tinygrad.codegen.linearize import linearize_uop from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp -from tinygrad.dtype import PtrDType, dtypes +from tinygrad.dtype import dtypes from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError # **************** new FlopCounter **************** @@ -119,7 +119,7 @@ class TestUOpsStats(unittest.TestCase): #MULACC should have the same stats as MUL + ADD def test_mulacc(self): - globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple()) + globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) @@ -129,7 +129,7 @@ class TestUOpsStats(unittest.TestCase): u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD) uops = linearize_uop(u5.sink()) - globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple()) + globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple()) o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1) o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2) u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1)) diff --git a/test/test_viz.py b/test/test_viz.py index 592b1da4..e21b8280 100644 --- a/test/test_viz.py +++ b/test/test_viz.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional import unittest -from tinygrad.dtype import PtrDType, dtypes +from tinygrad.dtype import dtypes from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \ graph_rewrite, contexts, track_rewrites from tinygrad.viz.serve import get_details, get_metadata, uop_to_json @@ -27,7 +27,7 @@ class TestViz(unittest.TestCase): pm = PatternMatcher([ (UPat.var("x")*1, lambda x:x), ]) - a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) + a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) uops = helper_test_viz(a*1, pm) self.assertEqual(len(uops), 1) self.assertEqual(uops[0], a) @@ -37,15 +37,15 @@ class TestViz(unittest.TestCase): (UPat.var("x")+UPat.var("x"), lambda x:x*2), (UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))), ]) - a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) + a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) uops = helper_test_viz(a+a, pm) self.assertEqual(len(uops), 2) self.assertEqual(uops[0], a*2) self.assertEqual(uops[1], graph_rewrite(a+a, pm)) def test_rewrite_with_ctx(self): - a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) - b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0))) + a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) + b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0))) def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]: if x in visited: return None visited[x] = None @@ -85,7 +85,7 @@ class TestViz(unittest.TestCase): self.assertEqual(len(ret), 1) def test_fold_const(self): - a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0))) + a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0))) graph = uop_to_json(a) assert not any(v[0].startswith("CONST") for v in graph.values()) assert len([x for x in graph.values() if "CONST" in x[0]]) == 1 diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index f8cdbe37..b4d2c931 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -3,12 +3,12 @@ from typing import Tuple from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing -from tinygrad.dtype import dtypes, PtrDType +from tinygrad.dtype import dtypes from tinygrad.ops import UOp, UOps, BinaryOps def get_gated_load_uop(valid:UOp, idx:UOp): return UOp(UOps.LOAD, dtypes.float, ( - UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0), + UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), idx, UOp.const(dtypes.float, 0.0), valid diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 003024fe..1a883acb 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -7,7 +7,7 @@ from typing import Tuple # *** fake symobilc uops *** from tinygrad.helpers import DEBUG -from tinygrad.dtype import dtypes, PtrDType, ConstType +from tinygrad.dtype import dtypes, ConstType from tinygrad.codegen.linearize import linearize_uop from tinygrad.codegen.uopgraph import full_graph_rewrite, sym from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite @@ -16,7 +16,7 @@ import functools def render(self) -> Tuple[str, ConstType, ConstType]: # NOTE: we need STORE so the ALU op has children - glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0) + glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0) uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink())) if DEBUG>=5: print_uops(uops) from tinygrad.renderer.cstyle import CStyleLanguage diff --git a/test/unit/test_verify_ast.py b/test/unit/test_verify_ast.py index 3b7ab3ce..b758e31f 100644 --- a/test/unit/test_verify_ast.py +++ b/test/unit/test_verify_ast.py @@ -3,7 +3,6 @@ import unittest from tinygrad import Tensor from tinygrad.codegen.kernel import Kernel -from tinygrad.dtype import PtrDType from tinygrad.helpers import DEBUG from tinygrad.ops import UOp, UOps, ReduceOps, print_uops from tinygrad.codegen.kernel import verify_ast @@ -27,9 +26,9 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel: class TestVerifyAST(unittest.TestCase): def test_tiny_add(self): dtype = dtypes.int - buf_0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 0) - buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 1) - buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 2) + buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0) + buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1) + buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2) a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop())) b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop())) store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b)) @@ -37,7 +36,7 @@ class TestVerifyAST(unittest.TestCase): def test_exactly_one_full_shape(self): dtype = dtypes.int - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i) for i in range(6)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)] a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop())) b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop())) st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b) @@ -47,28 +46,28 @@ class TestVerifyAST(unittest.TestCase): with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1) def test_no_implicit_broadcasting(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop())) b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,))) st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b)) with self.assertRaises(InvalidASTException): helper_test_verify_ast(st) def test_shrink_ok(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop())) b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b) helper_test_verify_ast(st) def test_reduce_store(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r) with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st) def test_reduce_add_store(self): - bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)] + bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)] a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop())) r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,))) st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a) diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 0e7f9669..aa974d8a 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -9,7 +9,7 @@ from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp graph_rewrite, track_rewrites, Variable, sint from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, Program -from tinygrad.dtype import ImageDType, PtrDType +from tinygrad.dtype import ImageDType from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX from tinygrad.shape.shapetracker import ShapeTracker @@ -660,7 +660,7 @@ class Kernel: st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD] local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape)) st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop() - membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) + membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size())) local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn) srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store))) else: @@ -690,7 +690,7 @@ class Kernel: for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \ (1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)]) st_uop = ShapeTracker.from_shape(local_shape).to_uop() - local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) + local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size())) local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start))) grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis)) if op is self.reduceops[-1]: return grouped_reduce diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 90df8f43..781baf47 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: from tinygrad.renderer import Renderer # ***** float4/image store handling ***** def fold_expanded(ex, buf): - if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None + if buf.dtype != dtypes.float.ptr() and buf.dtype != dtypes.half.ptr() and not isinstance(buf.dtype, ImageDType): return None new_srcs = dedup(list(ex.src)) old_new_srcs = new_srcs[:] is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType) @@ -32,7 +32,7 @@ def fold_expanded(ex, buf): offsets_rootsrc[root_src][arg] = i # then rewrite everything we can - lengths = [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) + lengths = [4] if is_image else ([8,4,2] if buf.dtype == dtypes.half.ptr() and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2])) used = set() for rootsrc, offsets in offsets_rootsrc.items(): for o in offsets: diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 429419f6..eaec5abf 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -18,7 +18,7 @@ class DType: assert self.count == 1, f"can't vectorize {self} with size {sz}" if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz) - def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self) + def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return PtrDType(self, local) def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self # dependent typing? @@ -29,10 +29,9 @@ class ImageDType(DType): local: bool = False # images are never local def scalar(self) -> DType: return self.base def vec(self, sz:int): return self.base.vec(sz) - def ptr(self) -> Union[PtrDType, ImageDType]: return self + def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self def __repr__(self): return f"dtypes.{self.name}({self.shape})" -# @dataclass(frozen=True, init=False, repr=False, eq=False) class PtrDType(DType): def __init__(self, dt:DType, local=False): self.base, self.local = dt, local @@ -40,7 +39,7 @@ class PtrDType(DType): def __hash__(self): return super().__hash__() def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count def __ne__(self, dt): return not (self == dt) - def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})" + def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()" class dtypes: @staticmethod diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 34e8158a..45fdad2a 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -102,17 +102,19 @@ class CStyleLanguage(Renderer): def get_kernel_modifier(self, uops:List[UOp]) -> str: return "" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 - buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else - ("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else + buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" - def render_dtype(self, dt:DType) -> str: + def render_dtype(self, dt:DType, mutable=True) -> str: + if isinstance(dt, ImageDType): + return f"{'write_only' if mutable else 'read_only'} image2d_t" if isinstance(dt, PtrDType): - return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "") + return ("" if mutable else "const ") + (self.smem_prefix if dt.local else self.buffer_prefix) +\ + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "") return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "") def __getitem__(self, key): return self.r[key] # hacky helper @@ -126,12 +128,8 @@ class CStyleLanguage(Renderer): depth = 1 c: DefaultDict[str, int] = defaultdict(int) for u in uops: - if u.op is UOps.DEFINE_GLOBAL: - r[u] = f"data{u.arg}" - bufs[u] = (r[u], (u.dtype, False)) - continue - if u.op is UOps.DEFINE_VAR: - r[u] = u.arg[0] + if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR): + r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0] bufs[u] = (r[u], (u.dtype, False)) continue