mirror of https://github.com/commaai/tinygrad.git
use ast_const in test_linearizer asts [run_process_replay] (#6407)
This commit is contained in:
parent
750696a026
commit
935b4ddff6
|
@ -1,5 +1,5 @@
|
|||
import sys, time
|
||||
from typing import Callable, Tuple, TypeVar
|
||||
from typing import Callable, Optional, Tuple, TypeVar
|
||||
import numpy as np
|
||||
from test.external.process_replay.helpers import print_diff
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
|
@ -62,9 +62,9 @@ def assert_equiv_uops(u1:UOp, u2:UOp) -> None:
|
|||
print_diff(u1, u2)
|
||||
raise AssertionError("uops aren't equal.")
|
||||
|
||||
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]) -> UOp:
|
||||
return UOp(UOps.CONST, dtype, (ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape).to_uop(),),
|
||||
dtypes.as_const(val, dtype))
|
||||
def ast_const(dtype:DType, val:ConstType, shape:Tuple[sint, ...]=(), st:Optional[ShapeTracker]=None) -> UOp:
|
||||
st = st if st is not None else ShapeTracker.from_shape(()).reshape((1,)*len(shape)).expand(shape)
|
||||
return UOp(UOps.CONST, dtype, (st.to_uop(),), dtypes.as_const(val, dtype))
|
||||
|
||||
T = TypeVar("T")
|
||||
def timeit(fxn:Callable[..., T], *args, **kwargs) -> Tuple[T, float]:
|
||||
|
|
|
@ -555,7 +555,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False)))
|
||||
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
|
||||
arange_axis = (3,)
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (UOp(UOps.CONST, dtypes.int, (arange_input_st.to_uop(),), 1),), (BinaryOps.ADD, arange_axis))
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
|
||||
output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
||||
out = arange+ast_const(dtypes.int, -1, output_shape)
|
||||
store = UOp(UOps.STORE, src=(UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out))
|
||||
|
@ -572,7 +572,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
# TODO: do this arange broadcast in the scheduler
|
||||
arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384))
|
||||
arange_axis = (3,)
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (UOp(UOps.CONST, dtypes.int, (arange_input_st.to_uop(),), 1),), (BinaryOps.ADD, arange_axis))
|
||||
arange = UOp(UOps.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis))
|
||||
arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape))
|
||||
arange = arange+ast_const(dtypes.int, -1, arange_out_shape)
|
||||
# p2: the indexing
|
||||
|
@ -601,8 +601,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=10, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa E501
|
||||
ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10),
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
ast_const(dtypes.int, -1, (1, 20, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
|
||||
|
@ -616,16 +615,12 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.bool, arg=True, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa E501
|
||||
ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=10, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), # noqa E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) # noqa E501
|
||||
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501
|
||||
ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)),
|
||||
ast_const(dtypes.int, -1, (1, 20, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
def test_argmax_multireduce_flat(self):
|
||||
|
@ -638,8 +633,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=200, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),)), # noqa: E501
|
||||
ast_const(dtypes.int, 200, (1, 1)),
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.MUL, src=(
|
||||
ast_const(dtypes.int, -1, (1, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=(
|
||||
|
@ -653,16 +647,12 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.float, arg=None, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=2, src=()),
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.bool, arg=True, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), # noqa: E501
|
||||
ast_const(dtypes.bool, True, (200, 1)),)),)),
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False))), src=()),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=200, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.int, arg=-1, src=(
|
||||
UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) # noqa: E501
|
||||
ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501
|
||||
ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)),
|
||||
ast_const(dtypes.int, -1, (1, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [t, t_max], wanna_output=[real_argmax])
|
||||
|
||||
@unittest.skipIf(CI and Device.DEFAULT in {"AMD"}, "AMD CI doesn't support multiple sync threads yet")
|
||||
|
@ -743,8 +733,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),
|
||||
ast_const(dtypes.float, 0.5*N, (N, 1, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
|
@ -752,8 +741,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
ld1.to_uop(),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
ast_const(dtypes.float, 0.75*N, (N, N, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
|
@ -764,10 +752,9 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)) # noqa: E501
|
||||
|
||||
ast_const(dtypes.float, 0.0, (N, 1, 1)),
|
||||
ast_const(dtypes.float, 1.0, (N, 1, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
|
||||
ld0 = x.lazydata.st.reshape((1, N, N)).expand((N,N,N))
|
||||
|
@ -779,8 +766,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
ast_const(dtypes.float, 0.5*N, (1, 1, N)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
|
@ -788,8 +774,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
ld1.to_uop(),)),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
ast_const(dtypes.float, 0.75*N, (N, 1, N)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1, src=()),
|
||||
|
@ -800,10 +785,9 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3, src=()),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, N), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)),)) # noqa: E501
|
||||
|
||||
ast_const(dtypes.float, 0.0, (1, 1, N)),
|
||||
ast_const(dtypes.float, 1.0, (1, 1, N)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=opts, wanna_output=[wanna_output])
|
||||
|
||||
# pad reduce axis
|
||||
|
@ -818,8 +802,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.5*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.ALU, dtypes.float, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
|
@ -827,8 +810,7 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(N, 1, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(UOps.ALU, dtypes.float, arg=TernaryOps.WHERE, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, arg=BinaryOps.CMPLT, src=(
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.75*N, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)), # noqa: E501
|
||||
ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)),
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=(
|
||||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=1),
|
||||
|
@ -839,11 +821,8 @@ class TestLinearizer(unittest.TestCase):
|
|||
UOp(UOps.LOAD, dtypes.float, src=(
|
||||
UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=3),
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(N, N, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)),)),)),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=0.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)), # noqa: E501
|
||||
UOp(UOps.CONST, dtypes.float, arg=1.0, src=(
|
||||
UOp(UOps.SHAPETRACKER, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1), strides=(0, 0, 0, 0), offset=0, mask=None, contiguous=True),))),)),)),)),)) # noqa: E501
|
||||
|
||||
ast_const(dtypes.float, 0.0, (1, 1, 1, 1)),
|
||||
ast_const(dtypes.float, 1.0, (1, 1, 1, 1)),)),)),))
|
||||
helper_linearizer_ast(ast, [x,a,b], opts=[[Opt(OptOps.PADTO, 0, 32)],], wanna_output=[wanna_output])
|
||||
|
||||
@unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "test requires locals")
|
||||
|
@ -950,13 +929,13 @@ class TestLinearizer(unittest.TestCase):
|
|||
def test_load_cache_const_bufs(self):
|
||||
# make sure const buffers are differentiated from local and mem buffers
|
||||
ST, DT = ShapeTracker(views=(View(shape=((1,)), strides=(0, 0), offset=0, mask=None, contiguous=False),)).to_uop(), dtypes.int
|
||||
VAL = UOp(UOps.CONST, DT, (ST,), 2)
|
||||
VAL = ast_const(DT, 2, ST.arg.shape)
|
||||
g0, g1 = [UOp(UOps.DEFINE_GLOBAL, PtrDType(DT), arg=i) for i in range(2)]
|
||||
|
||||
# data1[0] + VAL
|
||||
a = UOp(UOps.LOAD, DT, (g1, ST)) + VAL
|
||||
# (literal const 1) + VAL
|
||||
b = UOp(UOps.CONST, DT, (ST,), 1) + VAL
|
||||
b = ast_const(DT, 1, ST.arg.shape) + VAL
|
||||
|
||||
store = UOp(UOps.STORE, src=(g0, ST, (a+b)))
|
||||
sink = UOp(UOps.SINK, src=(store,))
|
||||
|
|
Loading…
Reference in New Issue