remove Variable from UOp.DEFINE_VAR (#6393)

now it's just arg = (expr as str, min as UOp.const, max as UOp.const)
This commit is contained in:
chenyu 2024-09-06 05:55:19 -04:00 committed by GitHub
parent 9ed2b8b818
commit 26c5d8346a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 34 additions and 39 deletions

View File

@ -1,7 +1,7 @@
from typing import List from typing import List
import unittest, time import unittest, time
from test.helpers import assert_equiv_uops from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Variable, Device from tinygrad import dtypes, Device
from tinygrad.dtype import PtrDType from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
@ -131,10 +131,10 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0) self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self): def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs: for out in outs:
sink = graph_rewrite(out, constant_folder) sink = graph_rewrite(out, constant_folder)
@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(out.arg, 3.0) self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self): def test_where_same_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c0 = UOp(UOps.CONST, dtypes.int, arg=0) c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]: for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc) assert_equiv_uops(uops[0], acc)
@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]: for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc) assert_equiv_uops(uops[0], acc)
@ -268,21 +268,19 @@ class TestUOpGraph(unittest.TestCase):
for i in [4, 8]: for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma) assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]: for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) + tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma) assert_equiv_uops(uops[-1], wmma)
@ -290,17 +288,17 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]: for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma) assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]: for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i))) tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma]) uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma) assert_equiv_uops(uops[-1], wmma)
@ -326,7 +324,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self): def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1))) v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c2 = UOp(UOps.CONST, dtypes.int, arg=2) c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4) c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)

View File

@ -27,10 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
def NumNode(val): return UOp.const(dtypes.int, val) def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax): def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
class TempVar:
def __init__(self, x): self.expr = x
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
class Node: class Node:
@staticmethod @staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops) def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@ -306,7 +303,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_combine_num(self): def test_sum_combine_num(self):
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"}) self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"})
@unittest.expectedFailure
def test_sum_num_hoisted_and_factors_cancel_out(self): def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1") self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")

View File

@ -333,7 +333,7 @@ class UOp(MathTrait):
@functools.cached_property @functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]: def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX # NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \ return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src) self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property @functools.cached_property
@ -366,7 +366,7 @@ class UOp(MathTrait):
@classmethod @classmethod
def _const(cls, dtype:Optional[DType], b:ConstType|Variable): def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp # TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max)))) if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if dtype is not None and dtype != (sdtype := dtype.scalar()): if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count))) return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
@ -388,7 +388,8 @@ class UOp(MathTrait):
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR]) def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]: def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS] st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr) return sorted(set.union(*st_vars, set([Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.sparents if x.op is UOps.DEFINE_VAR])),
key=lambda v: v.expr)
def const_factor(self) -> int: def const_factor(self) -> int:
"""largest known int that divides self""" """largest known int that divides self"""
if self.op is UOps.CONST: return self.arg if self.op is UOps.CONST: return self.arg
@ -476,7 +477,7 @@ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x:
def uop_alu_resolve(u:UOp) -> sint: def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return u.arg[0] if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src))) if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}") raise RuntimeError(f"ALU resolve fail @ {u.op}")

View File

@ -34,7 +34,7 @@ class Program:
if not self._ran_post_init and self.uops is not None: if not self._ran_post_init and self.uops is not None:
# single pass through the uops # single pass through the uops
for u in self.uops: for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg[0]) if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg))
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg) if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL]) if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL: if u.op is UOps.SPECIAL:

View File

@ -212,9 +212,9 @@ class PTXRenderer(Renderer):
r[u] = "%" + args[0] r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is UOps.DEFINE_VAR: elif uop is UOps.DEFINE_VAR:
bufs.append((args[0].expr, dtype)) bufs.append((args[0], dtype))
r[u] = f"%{args[0].expr}" r[u] = f"%{args[0]}"
kk(*self.render_load(args[0].expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param")) kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True) elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP: r[u] = r[src[0]][u.arg] elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
elif uop is UOps.LOAD: elif uop is UOps.LOAD:

View File

@ -147,10 +147,10 @@ class CStyleLanguage(Renderer):
kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */") kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
r[u] = args[0] r[u] = args[0]
elif uop is UOps.DEFINE_VAR: elif uop is UOps.DEFINE_VAR:
assert args[0].expr not in seen_vars, f"duplicate variable {args[0].expr}" assert args[0] not in seen_vars, f"duplicate variable {args[0]}"
seen_vars.add(args[0].expr) seen_vars.add(args[0])
bufs[u] = (args[0].expr, (dtype,False)) bufs[u] = (args[0], (dtype,False))
r[u] = args[0].expr r[u] = args[0]
elif uop is UOps.LOAD: elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL) val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch # NOTE: this relies on the load not happening if it's in the unselected branch

View File

@ -19,7 +19,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx), ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)), LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \ Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))), UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self.expr, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)), SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) } AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }