mirror of https://github.com/commaai/tinygrad.git
minor dtype cleanup [pr] (#7124)
* minor dtype cleanup [pr] * use ptr() function
This commit is contained in:
parent
0b2621f63f
commit
ded1b38b84
|
@ -37,7 +37,7 @@ print("******** second, the Device ***********")
|
||||||
DEVICE = "CLANG" # NOTE: you can change this!
|
DEVICE = "CLANG" # NOTE: you can change this!
|
||||||
|
|
||||||
import struct
|
import struct
|
||||||
from tinygrad.dtype import PtrDType, dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.device import Buffer, Device
|
from tinygrad.device import Buffer, Device
|
||||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
|
from tinygrad.ops import BinaryOps, MetaOps, UOp, UOps
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
@ -49,12 +49,12 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
|
||||||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||||
|
|
||||||
# describe the computation
|
# describe the computation
|
||||||
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
|
buf_1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 1)
|
||||||
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
|
buf_2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 2)
|
||||||
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
|
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
|
||||||
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
|
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
|
||||||
alu = ld_1 + ld_2
|
alu = ld_1 + ld_2
|
||||||
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
|
output_buf = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||||
st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
|
st_0 = UOp(UOps.STORE, dtypes.void, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
|
||||||
s = UOp(UOps.SINK, dtypes.void, (st_0,))
|
s = UOp(UOps.SINK, dtypes.void, (st_0,))
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ import unittest
|
||||||
from tinygrad import Device
|
from tinygrad import Device
|
||||||
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
|
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
|
||||||
from tinygrad.engine.search import Opt, OptOps
|
from tinygrad.engine.search import Opt, OptOps
|
||||||
from tinygrad.dtype import dtypes, PtrDType
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View
|
from tinygrad.shape.view import View
|
||||||
from tinygrad.codegen.kernel import Kernel
|
from tinygrad.codegen.kernel import Kernel
|
||||||
|
@ -31,7 +31,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
|
||||||
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.imagef((8, 108, 4)), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 3, 4, 3, 3), strides=(0, 0, 0, 0, 0, 432, 1, 48, 4, 144, 16), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 4, 1, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
x19:=UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||||
x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
x20:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 128, 1, 1, 8, 4, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
|
@ -81,7 +81,7 @@ class TestOpenpilotValidhack(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
x18:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
x18:=UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 10, 512), strides=(0, 0, 1), offset=48128, mask=((0, 1), (1, 2), (0, 512)), contiguous=False),)), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
x18,
|
x18,
|
||||||
|
|
|
@ -312,15 +312,15 @@ class TestEqStrDType(unittest.TestCase):
|
||||||
def test_ptr_ne(self):
|
def test_ptr_ne(self):
|
||||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||||
# TODO: is this the wrong behavior?
|
# TODO: is this the wrong behavior?
|
||||||
assert PtrDType(dtypes.float32) == dtypes.float32
|
assert dtypes.float32.ptr() == dtypes.float32
|
||||||
assert not (PtrDType(dtypes.float32) != dtypes.float32)
|
assert not (dtypes.float32.ptr() != dtypes.float32)
|
||||||
assert PtrDType(dtypes.float32) == PtrDType(dtypes.float32)
|
assert dtypes.float32.ptr() == dtypes.float32.ptr()
|
||||||
assert not (PtrDType(dtypes.float32) != PtrDType(dtypes.float32))
|
assert not (dtypes.float32.ptr() != dtypes.float32.ptr())
|
||||||
#assert PtrDType(dtypes.float32) != dtypes.float32
|
#assert dtypes.float32.ptr() != dtypes.float32
|
||||||
def test_strs(self):
|
def test_strs(self):
|
||||||
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
if PtrDType is None: raise unittest.SkipTest("no PtrDType support")
|
||||||
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
self.assertEqual(str(dtypes.imagef((1,2,4))), "dtypes.imagef((1, 2, 4))")
|
||||||
self.assertEqual(str(PtrDType(dtypes.float32)), "PtrDType(dtypes.float)")
|
self.assertEqual(str(dtypes.float32.ptr()), "dtypes.float.ptr()")
|
||||||
|
|
||||||
class TestHelpers(unittest.TestCase):
|
class TestHelpers(unittest.TestCase):
|
||||||
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
signed_ints = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
|
||||||
|
|
|
@ -15,7 +15,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype
|
||||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
|
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
|
||||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
|
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
|
||||||
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
|
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
|
||||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
from tinygrad.dtype import DType, dtypes
|
||||||
|
|
||||||
def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
|
def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
|
||||||
if isinstance(r, Tensor): r = [r]
|
if isinstance(r, Tensor): r = [r]
|
||||||
|
@ -84,7 +84,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
|
|
||||||
def test_multioutput(self):
|
def test_multioutput(self):
|
||||||
dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
|
dtype, st = dtypes.int, ShapeTracker.from_shape((8,))
|
||||||
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), arg=i) for i in range(4)]
|
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), arg=i) for i in range(4)]
|
||||||
a = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
|
a = UOp(UOps.LOAD, dtype, (g2, st.to_uop()))
|
||||||
b = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
|
b = UOp(UOps.LOAD, dtype, (g3, st.to_uop()))
|
||||||
out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b))
|
out0 = UOp(UOps.STORE, dtypes.void, (g0, st.to_uop(), a + b))
|
||||||
|
@ -107,7 +107,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(32, dtype=dtypes.float).realize()
|
x = Tensor.randn(32, dtype=dtypes.float).realize()
|
||||||
st_x = x.lazydata.st
|
st_x = x.lazydata.st
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop()))
|
||||||
|
@ -143,7 +143,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||||
st_x = x.lazydata.st
|
st_x = x.lazydata.st
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop()))
|
||||||
|
@ -205,7 +205,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
x0 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||||
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
x1 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||||
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(4)]
|
g0, g1, g2, g3 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop()))
|
||||||
|
@ -232,7 +232,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
|
x = Tensor.randn(8, 32, 8, 16, dtype=dtypes.float).realize()
|
||||||
st = x.lazydata.st
|
st = x.lazydata.st
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop()))
|
||||||
|
@ -283,7 +283,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# check how it works with one reduce optimized and one unoptimized
|
# check how it works with one reduce optimized and one unoptimized
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
|
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||||
|
@ -314,7 +314,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
x = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
||||||
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
x_p = Tensor.randn(4, 32, dtype=dtypes.float).realize()
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
||||||
first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
first_x_p = UOp(UOps.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
|
@ -350,7 +350,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# check how multireduce works with multioutput
|
# check how multireduce works with multioutput
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||||
|
@ -373,7 +373,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
|
# if we change the shape of store1 to be contiguous, it will match store0 but not the value it's storing (FAIL!)
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop()))
|
||||||
|
@ -397,7 +397,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
def test_complete_unroll_multireduce(self):
|
def test_complete_unroll_multireduce(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||||
|
@ -413,7 +413,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
def test_upcast_multireduce(self):
|
def test_upcast_multireduce(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop()))
|
||||||
|
@ -432,7 +432,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# make sure the if block of a grouped reduce can be closed early and the result loaded back in
|
# make sure the if block of a grouped reduce can be closed early and the result loaded back in
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
|
x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop()))
|
||||||
|
@ -448,7 +448,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
def test_mean_std_multireduce(self):
|
def test_mean_std_multireduce(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
||||||
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
||||||
|
@ -466,7 +466,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
def test_mean_std_multireduce_mid_dim(self):
|
def test_mean_std_multireduce_mid_dim(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,)))
|
||||||
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
|
neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35))
|
||||||
|
@ -486,7 +486,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch)
|
# TODO: Similar error to test_multiout_intermediate_multireduce (implicit expand vs shape mismatch)
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize()
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
||||||
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1))
|
||||||
|
@ -508,7 +508,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
def test_var_multireduce(self):
|
def test_var_multireduce(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
|
x = Tensor.randn(3, 27, 32, dtype=dtypes.float).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
|
# push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop()))
|
||||||
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
first_reduce = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,)))
|
||||||
|
@ -530,7 +530,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||||
def test_softmax_multireduce(self):
|
def test_softmax_multireduce(self):
|
||||||
x = Tensor.rand(4, 32).realize()
|
x = Tensor.rand(4, 32).realize()
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
|
first_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop()))
|
||||||
max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
|
max_x = UOp(UOps.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,)))
|
||||||
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
|
second_x = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop()))
|
||||||
|
@ -559,15 +559,15 @@ class TestLinearizer(unittest.TestCase):
|
||||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
|
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
|
||||||
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
||||||
out = arange+ast_const(dtypes.int, -1, output_shape)
|
out = arange+ast_const(dtypes.int, -1, output_shape)
|
||||||
store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
|
store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
|
||||||
sink = UOp(UOps.SINK, src=(store,))
|
sink = UOp(UOps.SINK, src=(store,))
|
||||||
helper_linearizer_ast(sink, [], wanna_output=[real_arange])
|
helper_linearizer_ast(sink, [], wanna_output=[real_arange])
|
||||||
with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
|
with Context(DEBUG=0, NOOPT=0): np.testing.assert_equal(tiny.numpy(), real_arange)
|
||||||
|
|
||||||
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
|
@unittest.skipIf(CI and Device.DEFAULT in {"PTX", "AMD", "NV"}, "very slow")
|
||||||
def test_indexing_multireduce(self):
|
def test_indexing_multireduce(self):
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
g2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2)
|
g2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2)
|
||||||
arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
|
arange_input_st = ShapeTracker(views=(View(shape=(16385, 32767), strides=(0, 0), offset=0, mask=((0, 16385), (16383, 32767)), contiguous=False), \
|
||||||
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
|
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
|
||||||
# TODO: do this arange broadcast in the scheduler
|
# TODO: do this arange broadcast in the scheduler
|
||||||
|
@ -598,7 +598,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
|
real_argmax = np.argmax(t.numpy(), axis=0, keepdims=False).reshape(1, 20, 1)
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
|
@ -611,10 +611,10 @@ class TestLinearizer(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(20, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
|
||||||
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
|
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
|
@ -630,7 +630,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
real_argmax = np.argmax(t.numpy())
|
real_argmax = np.argmax(t.numpy())
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
|
@ -643,10 +643,10 @@ class TestLinearizer(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
|
||||||
ast_const(dtypes.bool, True, (200, 1)),)),)),
|
ast_const(dtypes.bool, True, (200, 1)),)),)),
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
|
@ -669,7 +669,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
|
# [Opt(OptOps.PADTO, 1, 32), Opt(OptOps.PADTO, 2, 32)]
|
||||||
]
|
]
|
||||||
|
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
|
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,)))
|
||||||
|
@ -696,7 +696,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
|
[Opt(OptOps.PADTO, 0, 32), Opt(OptOps.UPCAST, 0, 8),]
|
||||||
]
|
]
|
||||||
|
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)]
|
||||||
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
x_ld0 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop()))
|
||||||
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
x_ld1 = UOp(UOps.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop()))
|
||||||
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
|
r0 = UOp(UOps.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,)))
|
||||||
|
@ -730,7 +730,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
ld1 = x.lazydata.st.reshape((N, N, 1))
|
ld1 = x.lazydata.st.reshape((N, N, 1))
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
|
@ -738,20 +738,20 @@ class TestLinearizer(unittest.TestCase):
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
ld1.to_uop(),)),
|
ld1.to_uop(),)),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
|
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
ld0.to_uop(),)),)),)),
|
ld0.to_uop(),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||||
|
|
||||||
ast_const(dtypes.float, 0.0, (N, 1, 1)),
|
ast_const(dtypes.float, 0.0, (N, 1, 1)),
|
||||||
|
@ -763,7 +763,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
|
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(axis=0,keepdims=True), a.numpy(), b.numpy())).sum(axis=0),0.0,1.0).reshape(1,1,N) # noqa: E501
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
|
@ -771,20 +771,20 @@ class TestLinearizer(unittest.TestCase):
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
ld1.to_uop(),)),
|
ld1.to_uop(),)),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
|
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
ld0.to_uop(),)),)),)),
|
ld0.to_uop(),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
|
||||||
|
|
||||||
ast_const(dtypes.float, 0.0, (1, 1, N)),
|
ast_const(dtypes.float, 0.0, (1, 1, N)),
|
||||||
|
@ -799,7 +799,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
|
wanna_output = np.where(0.5*17 < (x.numpy()+np.where(0.75*17 < x.numpy().sum(keepdims=True), a.numpy(), b.numpy())).sum(keepdims=True),0.0,1.0).reshape((1,1,1,1))# noqa: E501
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
|
@ -807,20 +807,20 @@ class TestLinearizer(unittest.TestCase):
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||||
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
|
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||||
ast_const(dtypes.float, 0.0, (1, 1, 1, 1)),
|
ast_const(dtypes.float, 0.0, (1, 1, 1, 1)),
|
||||||
ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),))
|
ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),))
|
||||||
|
@ -829,7 +829,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||||
def test_end_local(self):
|
def test_end_local(self):
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)]
|
||||||
load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
|
load = UOp(UOps.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop()))
|
||||||
reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
|
reduce = UOp(UOps.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,)))
|
||||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
|
store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce))
|
||||||
|
@ -931,7 +931,7 @@ class TestLinearizer(unittest.TestCase):
|
||||||
# make sure const buffers are differentiated from local and mem buffers
|
# make sure const buffers are differentiated from local and mem buffers
|
||||||
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
|
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
|
||||||
VAL = ast_const(DT, 2, ST.arg.shape)
|
VAL = ast_const(DT, 2, ST.arg.shape)
|
||||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(DT), arg=i) for i in range(2)]
|
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, DT.ptr(), arg=i) for i in range(2)]
|
||||||
|
|
||||||
# data1[0] + VAL
|
# data1[0] + VAL
|
||||||
a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
|
a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
|
||||||
|
@ -1368,10 +1368,10 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(40, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(240, 40, 1, 1), strides=(1, 240, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||||
opt = [
|
opt = [
|
||||||
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
|
Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=16),
|
||||||
|
@ -1387,10 +1387,10 @@ class TestLinearizer(unittest.TestCase):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)) # noqa: E501
|
||||||
opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
|
opt = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
|
||||||
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
|
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
|
||||||
|
@ -1603,16 +1603,16 @@ class TestFloat4(unittest.TestCase):
|
||||||
# from llama 7B shard 4 gpus
|
# from llama 7B shard 4 gpus
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||||
UOp(UOps.CAST, dtypes.float, src=(
|
UOp(UOps.CAST, dtypes.float, src=(
|
||||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.half, src=(
|
UOp(UOps.LOAD, dtypes.half, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 4096, 0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.half, src=(
|
UOp(UOps.LOAD, dtypes.half, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1024), strides=(0, 0, 1024, 1), offset=0, mask=None, contiguous=False),))),)),)),)),)),)),)) # noqa: E501
|
||||||
|
|
||||||
# TODO: fix this, expected might change but should be positive
|
# TODO: fix this, expected might change but should be positive
|
||||||
|
@ -1632,19 +1632,19 @@ class TestFloat4(unittest.TestCase):
|
||||||
# from float32 stable diffusion red tinybox
|
# from float32 stable diffusion red tinybox
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 256, 4, 514, 4, 514), strides=(0, 0, 0, 262144, 0, 512, 0, 1), offset=-513, mask=((0, 1), (0, 1), (0, 1), (0, 256), (0, 4), (1, 513), (0, 4), (1, 513)), contiguous=False), View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 0, 2056, 1, 4227136, 1058840, 515), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 256, 3, 3), strides=(0, 0, 2304, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
||||||
|
|
||||||
for expected, opts in [
|
for expected, opts in [
|
||||||
|
@ -1662,13 +1662,13 @@ class TestFloat4(unittest.TestCase):
|
||||||
# from resnet
|
# from resnet
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501
|
||||||
UOp(UOps.CAST, dtypes.half, src=(
|
UOp(UOps.CAST, dtypes.half, src=(
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=(
|
||||||
UOp(UOps.CAST, dtypes.float, src=(
|
UOp(UOps.CAST, dtypes.float, src=(
|
||||||
UOp(UOps.LOAD, dtypes.half, src=(
|
UOp(UOps.LOAD, dtypes.half, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(256, 64, 3, 56, 2, 3, 56, 2), strides=(1806336, 28224, 3, 504, 0, 1, 9, 0), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 56), (0, 1), (0, 3), (0, 56), (0, 1)), contiguous=False), View(shape=(256, 64, 3, 115, 3, 115), strides=(7225344, 112896, 37632, 336, 112, 1), offset=0, mask=((0, 256), (0, 64), (0, 3), (0, 112), (0, 3), (0, 112)), contiguous=False), View(shape=(256, 64, 456, 456), strides=(7617600, 119025, 345, 1), offset=0, mask=((0, 256), (0, 64), (0, 345), (0, 345)), contiguous=False), View(shape=(1, 256, 1, 64, 4, 114, 4, 114), strides=(0, 13307904, 0, 207936, 51984, 456, 114, 1), offset=0, mask=None, contiguous=True)))),)),)),)),)),)),)) # noqa: E501
|
||||||
for expected, opts in [
|
for expected, opts in [
|
||||||
(16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
|
(16, [Opt(op=OptOps.LOCAL, axis=1, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=3, amt=4)]), # noqa: E501
|
||||||
|
@ -1950,20 +1950,20 @@ class TestKernelOpts(unittest.TestCase):
|
||||||
def test_buf_index_not_found_tensor_core(self):
|
def test_buf_index_not_found_tensor_core(self):
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.CAST, dtypes.float, src=(
|
UOp(UOps.CAST, dtypes.float, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.int, src=(
|
UOp(UOps.LOAD, dtypes.int, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.int, src=(
|
UOp(UOps.LOAD, dtypes.int, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, src=(
|
UOp(UOps.LOAD, dtypes.float, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)) # noqa: E501
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
with self.assertRaises(KernelOptError):
|
with self.assertRaises(KernelOptError):
|
||||||
|
@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared")
|
||||||
def test_padto_group(self):
|
def test_padto_group(self):
|
||||||
Tensor.manual_seed(0)
|
Tensor.manual_seed(0)
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
ld0 = UOp(UOps.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
ld1 = UOp(UOps.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501
|
||||||
store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
|
store = UOp(UOps.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(UOps.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501
|
||||||
|
|
|
@ -5,7 +5,6 @@
|
||||||
import unittest
|
import unittest
|
||||||
from test.helpers import ast_const
|
from test.helpers import ast_const
|
||||||
from tinygrad import Device, dtypes
|
from tinygrad import Device, dtypes
|
||||||
from tinygrad.dtype import PtrDType
|
|
||||||
from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
|
from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps
|
||||||
from tinygrad.helpers import getenv
|
from tinygrad.helpers import getenv
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||||
|
@ -17,7 +16,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_unmerged_ifs(self):
|
def test_unmerged_ifs(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
|
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MAX, src=(
|
||||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||||
|
@ -26,10 +25,10 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 512, 4, 9, 4, 9), strides=(0, 25088, 0, 49, 0, 7, 0, 1), offset=-8, mask=((0, 1), (0, 64), (0, 1), (0, 512), (0, 4), (1, 8), (0, 4), (1, 8)), contiguous=False), View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(663552, 0, 0, 36, 1, 1296, 360, 10), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),
|
||||||
ast_const(dtypes.half, 0.9999950000374996, st_src=(
|
ast_const(dtypes.half, 0.9999950000374996, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(64, 1, 512, 7, 7, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
|
@ -50,17 +49,17 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_max_simplify_and_cancel(self):
|
def test_max_simplify_and_cancel(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.CAST, dtypes.int, arg=None, src=(
|
UOp(UOps.CAST, dtypes.int, arg=None, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
ast_const(dtypes.bool, True, st_src=(
|
ast_const(dtypes.bool, True, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||||
|
@ -81,11 +80,11 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_expander_new_srcs(self):
|
def test_expander_new_srcs(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
|
opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.LOCAL, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=0)]
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
|
@ -102,7 +101,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_llama_embedding(self):
|
def test_llama_embedding(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
UOp(UOps.CAST, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||||
|
@ -118,12 +117,12 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
ast_const(dtypes.int, -1, st_src=(
|
ast_const(dtypes.int, -1, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
ast_const(dtypes.bool, True, st_src=(
|
ast_const(dtypes.bool, True, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 32000, 1), strides=(1, 4096, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),))
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
prg = k.to_program()
|
prg = k.to_program()
|
||||||
|
@ -134,7 +133,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_unaligns_idxs(self):
|
def test_unaligns_idxs(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
|
@ -142,16 +141,16 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||||
UOp(UOps.CAST, dtypes.long, arg=None, src=(
|
UOp(UOps.CAST, dtypes.long, arg=None, src=(
|
||||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||||
ast_const(dtypes.bool, True, st_src=(
|
ast_const(dtypes.bool, True, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 5), strides=(0, 0, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
|
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0), Opt(op=OptOps.LOCAL, axis=0, amt=3)]
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
|
@ -166,14 +165,14 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_unrolled_float4_align(self):
|
def test_unrolled_float4_align(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPNE, src=(
|
||||||
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.long, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.long), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.long.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),
|
||||||
ast_const(dtypes.long, -1, st_src=(
|
ast_const(dtypes.long, -1, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
|
@ -182,7 +181,7 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
ast_const(dtypes.float, 0.0, st_src=(
|
ast_const(dtypes.float, 0.0, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 6), strides=(6, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
|
opts = [Opt(op=OptOps.UNROLL, axis=0, amt=0)]
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
|
@ -198,15 +197,15 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||||
def test_upcasted_stores_out_of_order(self):
|
def test_upcasted_stores_out_of_order(self):
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(0, 0, 0, 0, 0, 0, 1, 0, 4, 48, 16), offset=0, mask=None, contiguous=False),)), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 4, 1, 4, 3, 3), strides=(260, 13, 1, 0, 0, 0, 65, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
|
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
|
||||||
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
k = Kernel(ast, opts=Device[Device.DEFAULT].renderer)
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -9,7 +9,6 @@ from tinygrad.engine.search import time_linearizer, bufs_from_lin
|
||||||
|
|
||||||
# stuff needed to unpack a kernel
|
# stuff needed to unpack a kernel
|
||||||
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
|
from tinygrad.ops import UOp, UOps, BinaryOps, UnaryOps
|
||||||
from tinygrad.dtype import PtrDType
|
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View
|
from tinygrad.shape.view import View
|
||||||
|
|
||||||
|
@ -27,7 +26,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_1(self):
|
def test_overflow_1(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MAX, src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
|
@ -37,15 +36,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, 64), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False), View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),
|
||||||
x16:=ast_const(dtypes.float, 0.0, st_src=(
|
x16:=ast_const(dtypes.float, 0.0, st_src=(
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
|
UOp(UOps.ALU, dtypes.float, arg=UnaryOps.SQRT, src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
|
@ -57,7 +56,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
ast_const(dtypes.float, 1e-05, st_src=(
|
ast_const(dtypes.float, 1e-05, st_src=(
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(64, 1, 64, 112, 112, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||||
x16,)),)),))
|
x16,)),)),))
|
||||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
|
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0)]
|
||||||
|
@ -67,15 +66,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_2(self):
|
def test_overflow_2(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 512, 1, 32, 4, 34, 4, 34), strides=(0, 32768, 0, 1024, 0, 32, 0, 1), offset=-33, mask=((0, 1), (0, 512), (0, 1), (0, 32), (0, 4), (1, 33), (0, 4), (1, 33)), contiguous=False), View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(591872, 0, 0, 136, 1, 18496, 4760, 35), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 32, 3, 3), strides=(0, 0, 288, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
|
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -84,15 +83,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_3(self):
|
def test_overflow_3(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 16, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 16), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
|
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -101,15 +100,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_4(self):
|
def test_overflow_4(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 4, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 4), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
|
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -118,15 +117,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_5(self):
|
def test_overflow_5(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 2, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 2), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
|
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=2)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -135,15 +134,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_6(self):
|
def test_overflow_6(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
|
opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=3, amt=2)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -152,15 +151,15 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
def test_overflow_7(self):
|
def test_overflow_7(self):
|
||||||
ast = UOp(UOps.SINK, None, arg=None, src=(
|
ast = UOp(UOps.SINK, None, arg=None, src=(
|
||||||
UOp(UOps.STORE, None, arg=None, src=(
|
UOp(UOps.STORE, None, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(1, 3, 1, 128, 4, 130, 4, 130), strides=(0, 2097152, 0, 16384, 0, 128, 0, 1), offset=-129, mask=((0, 1), (0, 3), (0, 1), (0, 128), (0, 4), (1, 129), (0, 4), (1, 129)), contiguous=False), View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(34611200, 0, 0, 520, 1, 270400, 68120, 131), offset=0, mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
UOp(UOps.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 128, 3, 3), strides=(0, 0, 1152, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),))
|
||||||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
|
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
|
@ -170,7 +169,7 @@ class TestLinearizerOverflow(unittest.TestCase):
|
||||||
class TestLinearizerOverflowAlt(unittest.TestCase):
|
class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||||
def test_overflow_1(self):
|
def test_overflow_1(self):
|
||||||
BS = 2
|
BS = 2
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
||||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||||
|
@ -182,7 +181,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||||
_test_overflow(ast, opts)
|
_test_overflow(ast, opts)
|
||||||
def test_overflow_2(self):
|
def test_overflow_2(self):
|
||||||
BS = 2
|
BS = 2
|
||||||
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=i) for i in range(3)]
|
g0, g1, g2 = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)]
|
||||||
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
in_st_1 = ShapeTracker(views=(View(shape=(1, BS, 1, 3, 8, 230, 8, 230), strides=(0, 150528, 0, 50176, 0, 224, 0, 1), offset=-675, mask=((0, 1), (0, BS), (0, 1), (0, 3), (0, 8), (3, 227), (0, 8), (3, 227)), contiguous=False),
|
||||||
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(10156800, 0, 0, 3680, 2, 3385600, 425040, 231), offset=0, mask=None, contiguous=False))).to_uop()
|
||||||
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop()
|
||||||
|
|
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
||||||
from tinygrad.codegen.linearize import linearize_uop
|
from tinygrad.codegen.linearize import linearize_uop
|
||||||
from tinygrad.device import Buffer, Device
|
from tinygrad.device import Buffer, Device
|
||||||
from tinygrad.dtype import PtrDType, dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.engine.realize import CompiledRunner
|
from tinygrad.engine.realize import CompiledRunner
|
||||||
from tinygrad.helpers import dedup, flatten, getenv, prod
|
from tinygrad.helpers import dedup, flatten, getenv, prod
|
||||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||||
|
@ -30,8 +30,8 @@ def _test_uop_result(inputs:List[Tensor], stores:List[UOp], local_size=None):
|
||||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
|
@unittest.skipIf(not isinstance(Device[Device.DEFAULT].renderer, CStyleLanguage), "uops are for cstyle")
|
||||||
class TestCStyleFailures(unittest.TestCase):
|
class TestCStyleFailures(unittest.TestCase):
|
||||||
def test_inline_const_alu(self):
|
def test_inline_const_alu(self):
|
||||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
b = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1)
|
b = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
|
ld = UOp(UOps.LOAD, dtypes.int, (b, idx))
|
||||||
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1))
|
||||||
|
@ -45,7 +45,7 @@ class TestCStyleFailures(unittest.TestCase):
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local")
|
||||||
class TestPTXFailures(unittest.TestCase):
|
class TestPTXFailures(unittest.TestCase):
|
||||||
def test_gated_store_with_alu(self):
|
def test_gated_store_with_alu(self):
|
||||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||||
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
|
gated_alu_store = UOp(UOps.STORE, dtypes.void, (a, lidx0, UOp.const(dtypes.int, 1), gate_alu))
|
||||||
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
|
sink = UOp(UOps.SINK, dtypes.void, (gated_alu_store,))
|
||||||
|
@ -55,7 +55,7 @@ class TestPTXFailures(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skip("not still valid?")
|
@unittest.skip("not still valid?")
|
||||||
def test_gated_store_with_if(self):
|
def test_gated_store_with_if(self):
|
||||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
a = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||||
val = UOp.const(dtypes.int, 1)
|
val = UOp.const(dtypes.int, 1)
|
||||||
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu, val))
|
if_uop = UOp(UOps.IF, dtypes.void, (gate_alu, val))
|
||||||
|
|
|
@ -9,7 +9,7 @@ import functools
|
||||||
from typing import List, Optional, Union, cast
|
from typing import List, Optional, Union, cast
|
||||||
|
|
||||||
from tinygrad import nn, dtypes, Device, Tensor
|
from tinygrad import nn, dtypes, Device, Tensor
|
||||||
from tinygrad.dtype import DType, PtrDType
|
from tinygrad.dtype import DType
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.shape.view import View
|
from tinygrad.shape.view import View
|
||||||
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite, track_rewrites
|
||||||
|
@ -1611,7 +1611,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
self.assertLess(et, 1e3)
|
self.assertLess(et, 1e3)
|
||||||
|
|
||||||
def test_no_rewrite_elementwise(self):
|
def test_no_rewrite_elementwise(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
|
||||||
|
@ -1619,7 +1619,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
self.assertEqual(rsink.key, sink.key)
|
self.assertEqual(rsink.key, sink.key)
|
||||||
|
|
||||||
def test_simple_store_reshape(self):
|
def test_simple_store_reshape(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||||
|
@ -1632,7 +1632,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
verify_ast(rsink)
|
verify_ast(rsink)
|
||||||
|
|
||||||
def test_no_reshape_reduceop(self):
|
def test_no_reshape_reduceop(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||||
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
sink = UOp(UOps.SINK, dtypes.void, (UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
|
||||||
|
@ -1641,7 +1641,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
self.assertEqual(sink.key, rsink.key)
|
self.assertEqual(sink.key, rsink.key)
|
||||||
|
|
||||||
def test_reshape_many(self):
|
def test_reshape_many(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||||
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
r = UOp(UOps.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
|
||||||
|
@ -1660,7 +1660,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
sizes = [10*(i+1) for i in range(SZ)]
|
sizes = [10*(i+1) for i in range(SZ)]
|
||||||
tms: List[float] = []
|
tms: List[float] = []
|
||||||
for sz in sizes:
|
for sz in sizes:
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1)))
|
||||||
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
for _ in range(sz): r = r + ast_const(dtypes.int, 2, ())
|
||||||
|
@ -1682,14 +1682,14 @@ class TestIndexing(unittest.TestCase):
|
||||||
# graph rewrite
|
# graph rewrite
|
||||||
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
sink = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
UOp(UOps.VIEW, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=( # noqa E501
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||||
x8:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
x8:=UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), # noqa E501
|
||||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||||
x8,
|
x8,
|
||||||
|
@ -1709,7 +1709,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
a = Tensor.randint(4,).realize()
|
a = Tensor.randint(4,).realize()
|
||||||
expected_out = a.numpy().sum(0)+1
|
expected_out = a.numpy().sum(0)+1
|
||||||
# LazyBuffer to pre-rewrite AST
|
# LazyBuffer to pre-rewrite AST
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
ld = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,)))
|
||||||
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
swizzle_r = UOp(UOps.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(()))
|
||||||
|
@ -1732,7 +1732,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
b = Tensor.randint(4,).realize()
|
b = Tensor.randint(4,).realize()
|
||||||
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
|
expected_out = a.numpy().sum(0)+b.numpy().sum(0)+2
|
||||||
# LazyBuffer to pre-rewrite AST
|
# LazyBuffer to pre-rewrite AST
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), i) for i in range(3)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
ld1 = UOp(UOps.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop()))
|
||||||
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
r1 = UOp(UOps.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,)))
|
||||||
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
ld2 = UOp(UOps.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop()))
|
||||||
|
@ -1752,7 +1752,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
swizzle = UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501
|
||||||
# there's an EXPAND pushing through the REDUCE_AXIS
|
# there's an EXPAND pushing through the REDUCE_AXIS
|
||||||
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
self.assertGreater(prod(swizzle.st.shape), prod(swizzle.src[0].st.shape))
|
||||||
|
@ -1766,7 +1766,7 @@ class TestIndexing(unittest.TestCase):
|
||||||
|
|
||||||
def test_permute_rewrite(self):
|
def test_permute_rewrite(self):
|
||||||
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
sink = UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
x1:=UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
x1:=UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(1, ('METAL', 16384, dtypes.float)), src=()),
|
||||||
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
x2:=UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 16), strides=(0, 512, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||||
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
UOp(UOps.CONTIGUOUS, dtypes.float, arg=None, src=(
|
||||||
x1,
|
x1,
|
||||||
|
@ -1778,15 +1778,15 @@ class TestIndexing(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 512, 16, 0, 0, 0, 0, 4, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
x11:=UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(2, ('METAL', 16384, dtypes.float)), src=()),
|
||||||
x2,)),)),
|
x2,)),)),
|
||||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 32, 32, 1, 1, 4, 4, 4, 4, 1, 1), strides=(0, 0, 0, 0, 0, 64, 1, 16, 4, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(8, ('METAL', 256, dtypes.float)), src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 4, 1, 4, 4), strides=(64, 0, 16, 0, 4, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)),)),
|
||||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.BUFFER, PtrDType(dtypes.float), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
UOp(UOps.BUFFER, dtypes.float.ptr(), arg=(10, ('METAL', 16, dtypes.float)), src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||||
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
UOp(UOps.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(1, 16, 32, 32), strides=(0, 1, 512, 16), offset=0, mask=None, contiguous=False),)), src=(
|
||||||
x11,)),)),)),)),))
|
x11,)),)),)),)),))
|
||||||
|
|
|
@ -8,7 +8,7 @@ from tinygrad.engine.schedule import create_schedule
|
||||||
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
|
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
|
||||||
from tinygrad.device import Device, Buffer
|
from tinygrad.device import Device, Buffer
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.dtype import dtypes, PtrDType
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.helpers import Context, GlobalCounters
|
from tinygrad.helpers import Context, GlobalCounters
|
||||||
from tinygrad.engine.realize import capturing
|
from tinygrad.engine.realize import capturing
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
@ -48,7 +48,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||||
# ast of Tensor.zeros(16).contiguous().realize()
|
# ast of Tensor.zeros(16).contiguous().realize()
|
||||||
ast = UOp(UOps.SINK, src=(
|
ast = UOp(UOps.SINK, src=(
|
||||||
UOp(UOps.STORE, src=(
|
UOp(UOps.STORE, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(1,), offset=0, mask=None, contiguous=True),))),
|
||||||
ast_const(dtypes.float, 0.0, st_src=(
|
ast_const(dtypes.float, 0.0, st_src=(
|
||||||
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
|
UOp(UOps.VIEW, arg=ShapeTracker(views=(View(shape=(16,), strides=(0,), offset=0, mask=None, contiguous=False),))),)),)),))
|
||||||
|
@ -105,7 +105,7 @@ class TestBEAM(unittest.TestCase):
|
||||||
# taken from https://github.com/tinygrad/tinygrad/issues/4612
|
# taken from https://github.com/tinygrad/tinygrad/issues/4612
|
||||||
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
ast = UOp(UOps.SINK, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
|
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.MUL, src=(
|
||||||
|
@ -115,22 +115,22 @@ class TestBEAM(unittest.TestCase):
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=0, mask=((0, 64128),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-64128, mask=((64128, 128256),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=3, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-128256, mask=((128256, 192384),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=4, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=4, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-192384, mask=((192384, 256512),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=5, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=5, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-256512, mask=((256512, 320640),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=6, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=6, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(384768,), strides=(1,), offset=-320640, mask=((320640, 384768),), contiguous=False), View(shape=(1, 501, 256), strides=(0, 1, 501), offset=256512, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||||
ast_const(dtypes.float, 1.4285714285714286, st_src=(
|
ast_const(dtypes.float, 1.4285714285714286, st_src=(
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 501, 256), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)) # noqa: E501
|
||||||
|
|
|
@ -2,7 +2,6 @@ from typing import List
|
||||||
import unittest, time
|
import unittest, time
|
||||||
from test.helpers import assert_equiv_uops
|
from test.helpers import assert_equiv_uops
|
||||||
from tinygrad import dtypes, Device
|
from tinygrad import dtypes, Device
|
||||||
from tinygrad.dtype import PtrDType
|
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||||
from tinygrad.ops import UPat, PatternMatcher
|
from tinygrad.ops import UPat, PatternMatcher
|
||||||
|
@ -32,7 +31,7 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||||
def test_expand_rewrite(self):
|
def test_expand_rewrite(self):
|
||||||
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
|
sink = UOp(UOps.SINK, dtypes.void, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=(
|
||||||
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
UOp(UOps.STORE, dtypes.void, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 4, 64, 8, 16, 1, 1, 3, 3, 4, 1),
|
||||||
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
|
strides=(1179648, 9216, 1, 147456, 576, 0, 0, 64, 192, 36864, 0),
|
||||||
offset=0, mask=None, contiguous=False),)), src=()),
|
offset=0, mask=None, contiguous=False),)), src=()),
|
||||||
|
@ -40,14 +39,14 @@ class TestGraphRewriteEfficiency(unittest.TestCase):
|
||||||
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
UOp(UOps.CAST, dtypes.float, arg=None, src=(
|
||||||
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
UOp(UOps.ALU, dtypes.half, arg=BinaryOps.MUL, src=(
|
||||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=1, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||||
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
|
View(shape=(1, 1024, 1, 64, 4, 17, 4, 17), strides=(0, 14400, 0, 225, 0, 15, 0, 1), offset=-16,
|
||||||
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
|
mask=((0, 1), (0, 1024), (0, 1), (0, 64), (0, 4), (1, 16), (0, 4), (1, 16)), contiguous=False),
|
||||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
|
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(0, 73984, 4734976, 0, 4624, 295936, 68, 18, 1224, 0, 1), offset=0,
|
||||||
mask=None, contiguous=False))), src=()),)),
|
mask=None, contiguous=False))), src=()),)),
|
||||||
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
UOp(UOps.LOAD, dtypes.half, arg=None, src=(
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.half), arg=2, src=()),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.half.ptr(), arg=2, src=()),
|
||||||
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
UOp(UOps.VIEW, dtypes.void, arg=ShapeTracker(views=(
|
||||||
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
|
View(shape=(2, 4, 64, 8, 16, 16, 15, 3, 3, 4, 15), strides=(7200, 0, 230400, 900, 0, 14400, 15, 0, 0, 225, 1), offset=0,
|
||||||
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
|
mask=None, contiguous=False),)), src=()),)),)),)),)),)),))
|
||||||
|
@ -225,7 +224,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skip("this test isn't valid uops")
|
@unittest.skip("this test isn't valid uops")
|
||||||
def test_noop_vectorize_fold(self):
|
def test_noop_vectorize_fold(self):
|
||||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
|
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
|
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(2), (ld,))
|
||||||
|
@ -236,9 +235,9 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
|
self.assertEqual(len([x for x in uops if x.op is UOps.VECTORIZE]), 0)
|
||||||
|
|
||||||
def test_gep_vec_fold(self):
|
def test_gep_vec_fold(self):
|
||||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
|
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||||
d2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 2)
|
d2 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 2)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
def _test_vec(geps, count=4):
|
def _test_vec(geps, count=4):
|
||||||
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
vec = UOp(UOps.VECTORIZE, dtypes.float.vec(count), geps)
|
||||||
|
@ -342,8 +341,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
assert_equiv_uops(uops[-1], wmma)
|
assert_equiv_uops(uops[-1], wmma)
|
||||||
|
|
||||||
def test_cast_alu_fold(self):
|
def test_cast_alu_fold(self):
|
||||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.bool), arg=0)
|
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0)
|
||||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
|
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||||
alu = ld.lt(1).cast(dtypes.bool)
|
alu = ld.lt(1).cast(dtypes.bool)
|
||||||
|
@ -352,8 +351,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
|
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 0)
|
||||||
|
|
||||||
def test_double_cast_fold(self):
|
def test_double_cast_fold(self):
|
||||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0)
|
d0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0)
|
||||||
d1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1)
|
d1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=1)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
ld = UOp(UOps.LOAD, dtypes.int, (d1, idx))
|
||||||
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
alu = ld.cast(dtypes.float).cast(dtypes.float)
|
||||||
|
@ -376,9 +375,9 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
self.assertEqual(out.src[1].arg, 6)
|
self.assertEqual(out.src[1].arg, 6)
|
||||||
|
|
||||||
def test_fold_gated_load(self):
|
def test_fold_gated_load(self):
|
||||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
glbl1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1)
|
glbl1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1)
|
||||||
glbl2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 2)
|
glbl2 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 2)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
ld0 = UOp(UOps.LOAD, dtypes.int, (glbl1, idx, UOp.const(dtypes.int, 2), UOp.const(dtypes.bool, False)))
|
||||||
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
ld1 = UOp(UOps.LOAD, dtypes.int, (glbl2, idx, UOp.const(dtypes.int, 3), UOp.const(dtypes.bool, True)))
|
||||||
|
@ -390,8 +389,8 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
assert_equiv_uops(ld0, UOp.load(glbl2, idx, dtype=dtypes.int))
|
||||||
|
|
||||||
def test_fold_gated_load_local(self):
|
def test_fold_gated_load_local(self):
|
||||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
smem = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.int, local=True), (), ("temp", 1))
|
smem = UOp(UOps.DEFINE_LOCAL, dtypes.int.ptr(local=True), (), ("temp", 1))
|
||||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||||
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
st = UOp(UOps.STORE, dtypes.void, (smem, lidx, UOp.load(glbl0, lidx, dtype=dtypes.int)))
|
||||||
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
barrier = UOp(UOps.BARRIER, dtypes.void, (st, ))
|
||||||
|
@ -405,7 +404,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
assert_equiv_uops(ld0, UOp.load(smem, lidx+2, barrier, dtype=dtypes.int))
|
||||||
|
|
||||||
def test_fold_gated_store(self):
|
def test_fold_gated_store(self):
|
||||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
idx0 = UOp.const(dtypes.int, 0)
|
idx0 = UOp.const(dtypes.int, 0)
|
||||||
idx1 = UOp.const(dtypes.int, 0)
|
idx1 = UOp.const(dtypes.int, 0)
|
||||||
val = UOp.const(dtypes.int, 42)
|
val = UOp.const(dtypes.int, 42)
|
||||||
|
@ -418,13 +417,13 @@ class TestUOpGraph(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.skip("this is a uop type error")
|
@unittest.skip("this is a uop type error")
|
||||||
def test_asserts_bad_gate(self):
|
def test_asserts_bad_gate(self):
|
||||||
glbl0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl0 = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
idx = UOp.const(dtypes.int, 0)
|
idx = UOp.const(dtypes.int, 0)
|
||||||
bad_gate = UOp.const(dtypes.int, 1)
|
bad_gate = UOp.const(dtypes.int, 1)
|
||||||
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
with self.assertRaises(AssertionError): to_uops_list([UOp(UOps.STORE, dtypes.void, (glbl0, idx, UOp.const(dtypes.int, 42), bad_gate))])
|
||||||
|
|
||||||
def test_switched_range_order(self):
|
def test_switched_range_order(self):
|
||||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
c0 = UOp.const(dtypes.int, 0)
|
c0 = UOp.const(dtypes.int, 0)
|
||||||
c2 = UOp.const(dtypes.int, 2)
|
c2 = UOp.const(dtypes.int, 2)
|
||||||
cf = UOp.const(dtypes.float, 0.0)
|
cf = UOp.const(dtypes.float, 0.0)
|
||||||
|
@ -591,21 +590,21 @@ class TestExpander(unittest.TestCase):
|
||||||
|
|
||||||
class TestLoadStoreFolder(unittest.TestCase):
|
class TestLoadStoreFolder(unittest.TestCase):
|
||||||
def test_simple_load_fold(self):
|
def test_simple_load_fold(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
|
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(4)]
|
||||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||||
sink = float4_rewrite(sink)
|
sink = float4_rewrite(sink)
|
||||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 1
|
||||||
|
|
||||||
def test_two_load_fold(self):
|
def test_two_load_fold(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
|
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i))) for i in range(8)]
|
||||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||||
sink = float4_rewrite(sink)
|
sink = float4_rewrite(sink)
|
||||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
|
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 2
|
||||||
|
|
||||||
def test_simple_load_fold_gated(self):
|
def test_simple_load_fold_gated(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
|
gate = UOp(UOps.DEFINE_VAR, dtypes.bool)
|
||||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||||
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
|
||||||
|
@ -615,7 +614,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||||
self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0])
|
self.assertListEqual(list(single_load.src[2].arg), [0.0, 1.0, 2.0, 3.0])
|
||||||
|
|
||||||
def test_simple_load_dont_fold_different_gated(self):
|
def test_simple_load_dont_fold_different_gated(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||||
|
@ -624,14 +623,14 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||||
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
|
assert len([x for x in sink.sparents if x.op is UOps.LOAD]) == 3
|
||||||
|
|
||||||
def test_simple_store_fold(self):
|
def test_simple_store_fold(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
|
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i))) for i in range(4)]
|
||||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||||
sink = float4_rewrite(sink)
|
sink = float4_rewrite(sink)
|
||||||
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
assert len([x for x in sink.sparents if x.op is UOps.STORE]) == 1
|
||||||
|
|
||||||
def test_simple_store_fold_gate(self):
|
def test_simple_store_fold_gate(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||||
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
sink = UOp(UOps.SINK, dtypes.void, tuple(load))
|
||||||
|
@ -642,7 +641,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||||
assert str(one_store.src[3]) == str(gate) # huh, why do i need str here?
|
assert str(one_store.src[3]) == str(gate) # huh, why do i need str here?
|
||||||
|
|
||||||
def test_simple_store_dont_fold(self):
|
def test_simple_store_dont_fold(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr())
|
||||||
gate = UOp.variable("g1", False, True, dtypes.bool)
|
gate = UOp.variable("g1", False, True, dtypes.bool)
|
||||||
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
gate2 = UOp.variable("g2", False, True, dtypes.bool)
|
||||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||||
|
@ -655,8 +654,8 @@ def gate_rewrite(sink): return graph_rewrite(sink, sym + expander + reducer)
|
||||||
|
|
||||||
class TestIFUOps(unittest.TestCase):
|
class TestIFUOps(unittest.TestCase):
|
||||||
def test_create_ifs(self):
|
def test_create_ifs(self):
|
||||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 4))
|
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 4))
|
||||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||||
gate = valid&(lidx.ne(2))
|
gate = valid&(lidx.ne(2))
|
||||||
|
@ -674,8 +673,8 @@ class TestIFUOps(unittest.TestCase):
|
||||||
self.assertEqual(len(st.src), 3)
|
self.assertEqual(len(st.src), 3)
|
||||||
|
|
||||||
def test_expand_ifs_one_gate(self):
|
def test_expand_ifs_one_gate(self):
|
||||||
gbuf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
gbuf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
sbuf = UOp(UOps.DEFINE_LOCAL, PtrDType(dtypes.float, local=True), (), ("smem", 16))
|
sbuf = UOp(UOps.DEFINE_LOCAL, dtypes.float.ptr(local=True), (), ("smem", 16))
|
||||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 4)).lt(1)
|
||||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 16))
|
||||||
gate = valid&(lidx.ne(2))
|
gate = valid&(lidx.ne(2))
|
||||||
|
@ -694,7 +693,7 @@ class TestIFUOps(unittest.TestCase):
|
||||||
# this will be fixed with the merge gated stores bounty
|
# this will be fixed with the merge gated stores bounty
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_expand_ifs_dumb(self):
|
def test_expand_ifs_dumb(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
valid = UOp(UOps.SPECIAL, dtypes.int, (), ("gidx0", 10)).lt(5)
|
||||||
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
lidx = UOp(UOps.SPECIAL, dtypes.int, (), ("lidx0", 4))
|
||||||
gate = valid&(lidx.ne(2))
|
gate = valid&(lidx.ne(2))
|
||||||
|
|
|
@ -4,7 +4,7 @@ import numpy as np
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||||
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
from tinygrad.helpers import CI, DEBUG, getenv, Context
|
||||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
from tinygrad.dtype import dtypes, DType
|
||||||
from tinygrad.device import Buffer, Device
|
from tinygrad.device import Buffer, Device
|
||||||
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
from tinygrad.ops import UOps, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, ReduceOps, KernelInfo, exec_alu, spec # noqa F401
|
||||||
from tinygrad.renderer import Program
|
from tinygrad.renderer import Program
|
||||||
|
@ -30,8 +30,8 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar
|
||||||
def _test_single_value(vals, op, dts):
|
def _test_single_value(vals, op, dts):
|
||||||
uops = []
|
uops = []
|
||||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
|
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)]
|
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)]
|
||||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||||
|
@ -46,7 +46,7 @@ def _test_single_value(vals, op, dts):
|
||||||
def _test_single_value_const(vals, op, dts):
|
def _test_single_value_const(vals, op, dts):
|
||||||
uops = []
|
uops = []
|
||||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
|
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||||
|
@ -59,7 +59,7 @@ def _test_single_value_const(vals, op, dts):
|
||||||
|
|
||||||
def _test_uops_result(output_dtype, uops, res):
|
def _test_uops_result(output_dtype, uops, res):
|
||||||
# uops = []
|
# uops = []
|
||||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
|
buf_store = uop(uops, UOps.DEFINE_GLOBAL, output_dtype.ptr(), (), 0)
|
||||||
# res = output_fn(uops)
|
# res = output_fn(uops)
|
||||||
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
out = uop(uops, UOps.STORE, dtypes.void, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||||
|
@ -246,7 +246,7 @@ class TestConstantFolding(unittest.TestCase):
|
||||||
class TestGatedStoreRewrite(unittest.TestCase):
|
class TestGatedStoreRewrite(unittest.TestCase):
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_tiny_gate_store(self):
|
def test_tiny_gate_store(self):
|
||||||
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
gmem = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
|
@ -263,8 +263,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_gate_some_stores(self):
|
def test_gate_some_stores(self):
|
||||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
|
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
|
@ -282,8 +282,8 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||||
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
# scaled down version of TestLinearizerDumb.test_unmerged_ifs
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_merge_ifs_alt(self):
|
def test_merge_ifs_alt(self):
|
||||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
gmem0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
|
gmem1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||||
idx = gidx0*UOp.const(dtypes.int, 2)
|
idx = gidx0*UOp.const(dtypes.int, 2)
|
||||||
val = UOp.const(dtypes.float, 42.0)
|
val = UOp.const(dtypes.float, 42.0)
|
||||||
|
@ -305,7 +305,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||||
def test_local_basic(self):
|
def test_local_basic(self):
|
||||||
uops = []
|
uops = []
|
||||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.float32, local=True), (), ('smem', 16))
|
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.float32.ptr(local=True), (), ('smem', 16))
|
||||||
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
|
st = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), uop(uops, UOps.CONST, dtypes.float32, (), 42.0)))
|
||||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
|
barr = uop(uops, UOps.BARRIER, dtypes.void, (st,))
|
||||||
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
|
sres = uop(uops, UOps.LOAD, dtypes.float32, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 0), barr))
|
||||||
|
@ -314,7 +314,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_shared, "test requires shared memory")
|
||||||
def test_local_indirect(self):
|
def test_local_indirect(self):
|
||||||
uops = []
|
uops = []
|
||||||
smem = uop(uops, UOps.DEFINE_LOCAL, PtrDType(dtypes.int32, local=True), (), ('smem', 16))
|
smem = uop(uops, UOps.DEFINE_LOCAL, dtypes.int32.ptr(local=True), (), ('smem', 16))
|
||||||
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
st1 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 1), uop(uops, UOps.CONST, dtypes.int32, (), 2)))
|
||||||
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
st2 = uop(uops, UOps.STORE, dtypes.void, (smem, uop(uops, UOps.CONST, dtypes.int32, (), 2), uop(uops, UOps.CONST, dtypes.int32, (), 42)))
|
||||||
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
|
barr = uop(uops, UOps.BARRIER, dtypes.void, (st1,st2))
|
||||||
|
@ -325,7 +325,7 @@ class TestLocalAccess(unittest.TestCase):
|
||||||
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
|
@unittest.skipUnless(getenv("PTX"), "This only tests assembly backends")
|
||||||
class TestAssembly(unittest.TestCase):
|
class TestAssembly(unittest.TestCase):
|
||||||
def test_bitshift_left(self):
|
def test_bitshift_left(self):
|
||||||
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
|
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||||
|
@ -337,7 +337,7 @@ class TestAssembly(unittest.TestCase):
|
||||||
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
|
self.assertEqual(uops[-2].arg, BinaryOps.MUL)
|
||||||
|
|
||||||
def test_bitshift_right(self):
|
def test_bitshift_right(self):
|
||||||
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
|
g1 = UOp(UOps.DEFINE_GLOBAL, dtypes.int32.ptr(), (), 0)
|
||||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||||
|
@ -361,7 +361,7 @@ class TestUOpMethod(unittest.TestCase):
|
||||||
def test_uop_variables(self):
|
def test_uop_variables(self):
|
||||||
a = UOp.variable("a", 1, 10)
|
a = UOp.variable("a", 1, 10)
|
||||||
uop_var = UOp.const(dtypes.int, a)
|
uop_var = UOp.const(dtypes.int, a)
|
||||||
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
|
st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0),
|
||||||
ShapeTracker.from_shape((2, a)).to_uop()))
|
ShapeTracker.from_shape((2, a)).to_uop()))
|
||||||
ast_vars = (st_var+uop_var).variables()
|
ast_vars = (st_var+uop_var).variables()
|
||||||
self.assertEqual(len(ast_vars), 1)
|
self.assertEqual(len(ast_vars), 1)
|
||||||
|
@ -376,7 +376,7 @@ class TestUOpMethod(unittest.TestCase):
|
||||||
self.assertEqual((gidx0*3+1).const_factor(), 1)
|
self.assertEqual((gidx0*3+1).const_factor(), 1)
|
||||||
|
|
||||||
def test_replace(self):
|
def test_replace(self):
|
||||||
x = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.void), (), 0)
|
x = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0)
|
||||||
self.assertIs(x.replace(arg=None).arg, None)
|
self.assertIs(x.replace(arg=None).arg, None)
|
||||||
with self.assertRaises(AssertionError): x.replace(field="a")
|
with self.assertRaises(AssertionError): x.replace(field="a")
|
||||||
|
|
||||||
|
@ -403,7 +403,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
# NOTE: these tests skip type_verify since they add dtype to STORE
|
# NOTE: these tests skip type_verify since they add dtype to STORE
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_simple_order(self):
|
def test_simple_order(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
uops = to_uops_list([st1, st0], skip_check=True)
|
uops = to_uops_list([st1, st0], skip_check=True)
|
||||||
|
@ -412,8 +412,8 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
|
|
||||||
@unittest.expectedFailure
|
@unittest.expectedFailure
|
||||||
def test_ordering_multi_output(self):
|
def test_ordering_multi_output(self):
|
||||||
buf0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
buf0 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
buf1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 1)
|
buf1 = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 1)
|
||||||
st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0_0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf0, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1_0 = UOp(UOps.STORE, dtypes.float, (buf0, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0_1 = UOp(UOps.STORE, dtypes.float.vec(4), (buf1, UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
|
@ -430,7 +430,7 @@ class TestIndexingOrdering(unittest.TestCase):
|
||||||
assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
assert stores[2].src[1] < stores[3].src[1], f"stored at idx {stores[1].src[1].arg} AFTER {stores[0].src[1].arg}"
|
||||||
|
|
||||||
def test_simple_order_with_special(self):
|
def test_simple_order_with_special(self):
|
||||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0)
|
buf = UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), 0)
|
||||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), ('gidx0', 4))
|
||||||
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
st0 = UOp(UOps.STORE, dtypes.float.vec(4), (buf, gidx0+UOp.const(dtypes.int, 0), UOp.const(dtypes.float.vec(4), 42)))
|
||||||
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
st1 = UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, 4), UOp.const(dtypes.float, 10)))
|
||||||
|
|
|
@ -5,7 +5,7 @@ from tinygrad.engine.schedule import create_schedule
|
||||||
from tinygrad.engine.realize import lower_schedule_item
|
from tinygrad.engine.realize import lower_schedule_item
|
||||||
from tinygrad.codegen.linearize import linearize_uop
|
from tinygrad.codegen.linearize import linearize_uop
|
||||||
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
|
from tinygrad.ops import BinaryOps, TernaryOps, flops_mem, UOps, UOp
|
||||||
from tinygrad.dtype import PtrDType, dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
|
||||||
|
|
||||||
# **************** new FlopCounter ****************
|
# **************** new FlopCounter ****************
|
||||||
|
@ -119,7 +119,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||||
|
|
||||||
#MULACC should have the same stats as MUL + ADD
|
#MULACC should have the same stats as MUL + ADD
|
||||||
def test_mulacc(self):
|
def test_mulacc(self):
|
||||||
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
|
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||||
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
||||||
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
||||||
|
@ -129,7 +129,7 @@ class TestUOpsStats(unittest.TestCase):
|
||||||
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||||
uops = linearize_uop(u5.sink())
|
uops = linearize_uop(u5.sink())
|
||||||
|
|
||||||
globl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), tuple())
|
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), tuple())
|
||||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||||
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
||||||
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
import unittest
|
import unittest
|
||||||
from tinygrad.dtype import PtrDType, dtypes
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
|
from tinygrad.ops import TRACK_MATCH_STATS, BinaryOps, TrackedPatternMatcher as PatternMatcher, UOp, UOps, UPat, \
|
||||||
graph_rewrite, contexts, track_rewrites
|
graph_rewrite, contexts, track_rewrites
|
||||||
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
from tinygrad.viz.serve import get_details, get_metadata, uop_to_json
|
||||||
|
@ -27,7 +27,7 @@ class TestViz(unittest.TestCase):
|
||||||
pm = PatternMatcher([
|
pm = PatternMatcher([
|
||||||
(UPat.var("x")*1, lambda x:x),
|
(UPat.var("x")*1, lambda x:x),
|
||||||
])
|
])
|
||||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||||
uops = helper_test_viz(a*1, pm)
|
uops = helper_test_viz(a*1, pm)
|
||||||
self.assertEqual(len(uops), 1)
|
self.assertEqual(len(uops), 1)
|
||||||
self.assertEqual(uops[0], a)
|
self.assertEqual(uops[0], a)
|
||||||
|
@ -37,15 +37,15 @@ class TestViz(unittest.TestCase):
|
||||||
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
(UPat.var("x")+UPat.var("x"), lambda x:x*2),
|
||||||
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
(UPat.var("x", dtypes.int)*2, lambda x:x.alu(BinaryOps.SHL, UOp.const(dtypes.int, 1))),
|
||||||
])
|
])
|
||||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||||
uops = helper_test_viz(a+a, pm)
|
uops = helper_test_viz(a+a, pm)
|
||||||
self.assertEqual(len(uops), 2)
|
self.assertEqual(len(uops), 2)
|
||||||
self.assertEqual(uops[0], a*2)
|
self.assertEqual(uops[0], a*2)
|
||||||
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
|
self.assertEqual(uops[1], graph_rewrite(a+a, pm))
|
||||||
|
|
||||||
def test_rewrite_with_ctx(self):
|
def test_rewrite_with_ctx(self):
|
||||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||||
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 1), UOp.const(dtypes.int, 0)))
|
b = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1), UOp.const(dtypes.int, 0)))
|
||||||
def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]:
|
def store_load(visited:Dict[UOp, None], x:UOp) -> Optional[UOp]:
|
||||||
if x in visited: return None
|
if x in visited: return None
|
||||||
visited[x] = None
|
visited[x] = None
|
||||||
|
@ -85,7 +85,7 @@ class TestViz(unittest.TestCase):
|
||||||
self.assertEqual(len(ret), 1)
|
self.assertEqual(len(ret), 1)
|
||||||
|
|
||||||
def test_fold_const(self):
|
def test_fold_const(self):
|
||||||
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0), UOp.const(dtypes.int, 0)))
|
a = UOp(UOps.LOAD, dtypes.int, (UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), (), 0), UOp.const(dtypes.int, 0)))
|
||||||
graph = uop_to_json(a)
|
graph = uop_to_json(a)
|
||||||
assert not any(v[0].startswith("CONST") for v in graph.values())
|
assert not any(v[0].startswith("CONST") for v in graph.values())
|
||||||
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
assert len([x for x in graph.values() if "CONST" in x[0]]) == 1
|
||||||
|
|
|
@ -3,12 +3,12 @@ from typing import Tuple
|
||||||
|
|
||||||
from tinygrad.codegen.linearize import linearize_uop
|
from tinygrad.codegen.linearize import linearize_uop
|
||||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
from tinygrad.codegen.uopgraph import full_graph_rewrite, is_increasing
|
||||||
from tinygrad.dtype import dtypes, PtrDType
|
from tinygrad.dtype import dtypes
|
||||||
from tinygrad.ops import UOp, UOps, BinaryOps
|
from tinygrad.ops import UOp, UOps, BinaryOps
|
||||||
|
|
||||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||||
return UOp(UOps.LOAD, dtypes.float, (
|
return UOp(UOps.LOAD, dtypes.float, (
|
||||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=0),
|
UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0),
|
||||||
idx,
|
idx,
|
||||||
UOp.const(dtypes.float, 0.0),
|
UOp.const(dtypes.float, 0.0),
|
||||||
valid
|
valid
|
||||||
|
|
|
@ -7,7 +7,7 @@ from typing import Tuple
|
||||||
# *** fake symobilc uops ***
|
# *** fake symobilc uops ***
|
||||||
|
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
from tinygrad.dtype import dtypes, ConstType
|
||||||
from tinygrad.codegen.linearize import linearize_uop
|
from tinygrad.codegen.linearize import linearize_uop
|
||||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
|
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
|
||||||
|
@ -16,7 +16,7 @@ import functools
|
||||||
|
|
||||||
def render(self) -> Tuple[str, ConstType, ConstType]:
|
def render(self) -> Tuple[str, ConstType, ConstType]:
|
||||||
# NOTE: we need STORE so the ALU op has children
|
# NOTE: we need STORE so the ALU op has children
|
||||||
glbl = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0)
|
glbl = UOp(UOps.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0)
|
||||||
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
uops = linearize_uop(full_graph_rewrite(UOp(UOps.STORE, dtypes.void, (glbl, UOp.const(dtypes.int, 0), self)).sink()))
|
||||||
if DEBUG>=5: print_uops(uops)
|
if DEBUG>=5: print_uops(uops)
|
||||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||||
|
|
|
@ -3,7 +3,6 @@ import unittest
|
||||||
|
|
||||||
from tinygrad import Tensor
|
from tinygrad import Tensor
|
||||||
from tinygrad.codegen.kernel import Kernel
|
from tinygrad.codegen.kernel import Kernel
|
||||||
from tinygrad.dtype import PtrDType
|
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
|
from tinygrad.ops import UOp, UOps, ReduceOps, print_uops
|
||||||
from tinygrad.codegen.kernel import verify_ast
|
from tinygrad.codegen.kernel import verify_ast
|
||||||
|
@ -27,9 +26,9 @@ def helper_test_verify_ast(*stores:UOp) -> Kernel:
|
||||||
class TestVerifyAST(unittest.TestCase):
|
class TestVerifyAST(unittest.TestCase):
|
||||||
def test_tiny_add(self):
|
def test_tiny_add(self):
|
||||||
dtype = dtypes.int
|
dtype = dtypes.int
|
||||||
buf_0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 0)
|
buf_0 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 0)
|
||||||
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 1)
|
buf_1 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 1)
|
||||||
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), 2)
|
buf_2 = UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), 2)
|
||||||
a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
|
a = UOp(UOps.LOAD, dtype, (buf_1, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
|
b = UOp(UOps.LOAD, dtype, (buf_2, ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
|
store = UOp(UOps.STORE, dtypes.void, (buf_0, ShapeTracker.from_shape((32, 1)).to_uop(), a+b))
|
||||||
|
@ -37,7 +36,7 @@ class TestVerifyAST(unittest.TestCase):
|
||||||
|
|
||||||
def test_exactly_one_full_shape(self):
|
def test_exactly_one_full_shape(self):
|
||||||
dtype = dtypes.int
|
dtype = dtypes.int
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i) for i in range(6)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtype.ptr(), (), i) for i in range(6)]
|
||||||
a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
|
a = UOp(UOps.LOAD, dtype, (bufs[2], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
|
b = UOp(UOps.LOAD, dtype, (bufs[3], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
|
st0 = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), a+b)
|
||||||
|
@ -47,28 +46,28 @@ class TestVerifyAST(unittest.TestCase):
|
||||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st0, st1)
|
||||||
|
|
||||||
def test_no_implicit_broadcasting(self):
|
def test_no_implicit_broadcasting(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((4, 32)).to_uop()))
|
||||||
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
b = a + UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.MAX, (1,)))
|
||||||
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
st = UOp(UOps.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((4, 32)).to_uop(), b))
|
||||||
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
with self.assertRaises(InvalidASTException): helper_test_verify_ast(st)
|
||||||
|
|
||||||
def test_shrink_ok(self):
|
def test_shrink_ok(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
|
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)).to_uop()))
|
||||||
b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
|
b = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker((View((32, 32), strides=(0, 1), offset=0, mask=None, contiguous=False),)).to_uop()))
|
||||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), a+b)
|
||||||
helper_test_verify_ast(st)
|
helper_test_verify_ast(st)
|
||||||
|
|
||||||
def test_reduce_store(self):
|
def test_reduce_store(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r)
|
||||||
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
with self.assertRaisesRegex(InvalidASTException, "implicit expand"): helper_test_verify_ast(st)
|
||||||
|
|
||||||
def test_reduce_add_store(self):
|
def test_reduce_add_store(self):
|
||||||
bufs = [UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), i) for i in range(2)]
|
bufs = [UOp(UOps.DEFINE_GLOBAL, dtypes.float.ptr(), (), i) for i in range(2)]
|
||||||
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
a = UOp(UOps.LOAD, dtypes.float, (bufs[1], ShapeTracker.from_shape((32, 1)).to_uop()))
|
||||||
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
r = UOp(UOps.REDUCE_AXIS, dtypes.float, (a,), (ReduceOps.SUM, (0,)))
|
||||||
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
st = UOp.store(bufs[0], ShapeTracker.from_shape((32, 1)).to_uop(), r+a)
|
||||||
|
|
|
@ -9,7 +9,7 @@ from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp
|
||||||
graph_rewrite, track_rewrites, Variable, sint
|
graph_rewrite, track_rewrites, Variable, sint
|
||||||
from tinygrad.device import Device
|
from tinygrad.device import Device
|
||||||
from tinygrad.renderer import Renderer, TensorCore, Program
|
from tinygrad.renderer import Renderer, TensorCore, Program
|
||||||
from tinygrad.dtype import ImageDType, PtrDType
|
from tinygrad.dtype import ImageDType
|
||||||
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
|
from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
|
||||||
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
|
||||||
from tinygrad.shape.shapetracker import ShapeTracker
|
from tinygrad.shape.shapetracker import ShapeTracker
|
||||||
|
@ -660,7 +660,7 @@ class Kernel:
|
||||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||||
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
|
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
|
||||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in, True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
membuf = UOp(UOps.DEFINE_LOCAL, tc.dtype_in.ptr(local=True), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
|
||||||
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
|
||||||
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
|
||||||
else:
|
else:
|
||||||
|
@ -690,7 +690,7 @@ class Kernel:
|
||||||
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
for i in range(self.first_reduce, self.first_reduce+self.group_for_reduces)]) + \
|
||||||
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
(1,) * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
|
||||||
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
|
||||||
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(op.dtype, True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
local_buffer = UOp(UOps.DEFINE_LOCAL, op.dtype.ptr(local=True), (), (f"temp{self.reduceops.index(op)+1}", st_uop.arg.real_size()))
|
||||||
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
|
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
|
||||||
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
|
grouped_reduce = UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
|
||||||
if op is self.reduceops[-1]: return grouped_reduce
|
if op is self.reduceops[-1]: return grouped_reduce
|
||||||
|
|
|
@ -13,7 +13,7 @@ if TYPE_CHECKING: from tinygrad.renderer import Renderer
|
||||||
# ***** float4/image store handling *****
|
# ***** float4/image store handling *****
|
||||||
|
|
||||||
def fold_expanded(ex, buf):
|
def fold_expanded(ex, buf):
|
||||||
if buf.dtype != PtrDType(dtypes.float) and buf.dtype != PtrDType(dtypes.half) and not isinstance(buf.dtype, ImageDType): return None
|
if buf.dtype != dtypes.float.ptr() and buf.dtype != dtypes.half.ptr() and not isinstance(buf.dtype, ImageDType): return None
|
||||||
new_srcs = dedup(list(ex.src))
|
new_srcs = dedup(list(ex.src))
|
||||||
old_new_srcs = new_srcs[:]
|
old_new_srcs = new_srcs[:]
|
||||||
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
|
is_load, is_image = new_srcs[0].op is UOps.LOAD, isinstance(buf.dtype, ImageDType)
|
||||||
|
@ -32,7 +32,7 @@ def fold_expanded(ex, buf):
|
||||||
offsets_rootsrc[root_src][arg] = i
|
offsets_rootsrc[root_src][arg] = i
|
||||||
|
|
||||||
# then rewrite everything we can
|
# then rewrite everything we can
|
||||||
lengths = [4] if is_image else ([8,4,2] if buf.dtype == PtrDType(dtypes.half) and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
|
lengths = [4] if is_image else ([8,4,2] if buf.dtype == dtypes.half.ptr() and getenv("ALLOW_HALF8") else ([16,8,4,2] if AMX else [4,2]))
|
||||||
used = set()
|
used = set()
|
||||||
for rootsrc, offsets in offsets_rootsrc.items():
|
for rootsrc, offsets in offsets_rootsrc.items():
|
||||||
for o in offsets:
|
for o in offsets:
|
||||||
|
|
|
@ -18,7 +18,7 @@ class DType:
|
||||||
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
assert self.count == 1, f"can't vectorize {self} with size {sz}"
|
||||||
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
if sz == 1 or self.name == 'void': return self # void doesn't vectorize, and sz=1 is scalar
|
||||||
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
return DType(self.priority, self.itemsize*sz, f"{INVERSE_DTYPES_DICT[self.name]}{sz}", None, sz)
|
||||||
def ptr(self) -> Union[PtrDType, ImageDType]: return PtrDType(self)
|
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return PtrDType(self, local)
|
||||||
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
def scalar(self) -> DType: return DTYPES_DICT[self.name[:-len(str(self.count))]] if self.count > 1 else self
|
||||||
|
|
||||||
# dependent typing?
|
# dependent typing?
|
||||||
|
@ -29,10 +29,9 @@ class ImageDType(DType):
|
||||||
local: bool = False # images are never local
|
local: bool = False # images are never local
|
||||||
def scalar(self) -> DType: return self.base
|
def scalar(self) -> DType: return self.base
|
||||||
def vec(self, sz:int): return self.base.vec(sz)
|
def vec(self, sz:int): return self.base.vec(sz)
|
||||||
def ptr(self) -> Union[PtrDType, ImageDType]: return self
|
def ptr(self, local=False) -> Union[PtrDType, ImageDType]: return self
|
||||||
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
def __repr__(self): return f"dtypes.{self.name}({self.shape})"
|
||||||
|
|
||||||
# @dataclass(frozen=True, init=False, repr=False, eq=False)
|
|
||||||
class PtrDType(DType):
|
class PtrDType(DType):
|
||||||
def __init__(self, dt:DType, local=False):
|
def __init__(self, dt:DType, local=False):
|
||||||
self.base, self.local = dt, local
|
self.base, self.local = dt, local
|
||||||
|
@ -40,7 +39,7 @@ class PtrDType(DType):
|
||||||
def __hash__(self): return super().__hash__()
|
def __hash__(self): return super().__hash__()
|
||||||
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
def __eq__(self, dt): return self.priority==dt.priority and self.itemsize==dt.itemsize and self.name==dt.name and self.count==dt.count
|
||||||
def __ne__(self, dt): return not (self == dt)
|
def __ne__(self, dt): return not (self == dt)
|
||||||
def __repr__(self): return f"PtrDType({super().__repr__()}, local=True)" if self.local else f"PtrDType({super().__repr__()})"
|
def __repr__(self): return f"{super().__repr__()}.ptr(local=True)" if self.local else f"{super().__repr__()}.ptr()"
|
||||||
|
|
||||||
class dtypes:
|
class dtypes:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
@ -102,17 +102,19 @@ class CStyleLanguage(Renderer):
|
||||||
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
|
def get_kernel_modifier(self, uops:List[UOp]) -> str: return ""
|
||||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||||
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
|
buftypes = [(name, self.render_dtype(dtype, mutable)+self.buffer_suffix if isinstance(dtype, (ImageDType, PtrDType)) else
|
||||||
("" if mutable else "const ")+self.render_dtype(dtype)+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
|
||||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||||
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
|
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
|
||||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||||
|
|
||||||
def render_dtype(self, dt:DType) -> str:
|
def render_dtype(self, dt:DType, mutable=True) -> str:
|
||||||
|
if isinstance(dt, ImageDType):
|
||||||
|
return f"{'write_only' if mutable else 'read_only'} image2d_t"
|
||||||
if isinstance(dt, PtrDType):
|
if isinstance(dt, PtrDType):
|
||||||
return (self.smem_prefix if dt.local else self.buffer_prefix) + self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
|
return ("" if mutable else "const ") + (self.smem_prefix if dt.local else self.buffer_prefix) +\
|
||||||
|
self.render_dtype(dt.base) + ("*" if isinstance(dt, PtrDType) else "")
|
||||||
return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
|
return self.type_map.get(scalar:=dt.scalar(), scalar.name) + (str(dt.count) if (dt.count) > 1 else "")
|
||||||
|
|
||||||
def __getitem__(self, key): return self.r[key] # hacky helper
|
def __getitem__(self, key): return self.r[key] # hacky helper
|
||||||
|
@ -126,12 +128,8 @@ class CStyleLanguage(Renderer):
|
||||||
depth = 1
|
depth = 1
|
||||||
c: DefaultDict[str, int] = defaultdict(int)
|
c: DefaultDict[str, int] = defaultdict(int)
|
||||||
for u in uops:
|
for u in uops:
|
||||||
if u.op is UOps.DEFINE_GLOBAL:
|
if u.op in (UOps.DEFINE_GLOBAL, UOps.DEFINE_VAR):
|
||||||
r[u] = f"data{u.arg}"
|
r[u] = f"data{u.arg}" if u.op is UOps.DEFINE_GLOBAL else u.arg[0]
|
||||||
bufs[u] = (r[u], (u.dtype, False))
|
|
||||||
continue
|
|
||||||
if u.op is UOps.DEFINE_VAR:
|
|
||||||
r[u] = u.arg[0]
|
|
||||||
bufs[u] = (r[u], (u.dtype, False))
|
bufs[u] = (r[u], (u.dtype, False))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue