remove Variable from UOp.DEFINE_VAR (#6393)

now it's just arg = (expr as str, min as UOp.const, max as UOp.const)
This commit is contained in:
chenyu 2024-09-06 05:55:19 -04:00 committed by GitHub
parent 9ed2b8b818
commit 26c5d8346a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 34 additions and 39 deletions

View File

@ -1,7 +1,7 @@
from typing import List
import unittest, time
from test.helpers import assert_equiv_uops
from tinygrad import dtypes, Variable, Device
from tinygrad import dtypes, Device
from tinygrad.dtype import PtrDType
from tinygrad.helpers import DEBUG
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
@ -131,10 +131,10 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, constant_folder)
@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@ -268,21 +268,19 @@ class TestUOpGraph(unittest.TestCase):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half,
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half,
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
@ -290,17 +288,17 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
@ -326,7 +324,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)

View File

@ -27,10 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
def NumNode(val): return UOp.const(dtypes.int, val)
def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this
class TempVar:
def __init__(self, x): self.expr = x
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
@ -306,7 +303,6 @@ class TestSymbolic(unittest.TestCase):
def test_sum_combine_num(self):
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"})
@unittest.expectedFailure
def test_sum_num_hoisted_and_factors_cancel_out(self):
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")

View File

@ -333,7 +333,7 @@ class UOp(MathTrait):
@functools.cached_property
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
@functools.cached_property
@ -366,7 +366,7 @@ class UOp(MathTrait):
@classmethod
def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
if dtype is not None and dtype != (sdtype := dtype.scalar()):
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
@ -388,7 +388,8 @@ class UOp(MathTrait):
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
return sorted(set.union(*st_vars, set([Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.sparents if x.op is UOps.DEFINE_VAR])),
key=lambda v: v.expr)
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
@ -476,7 +477,7 @@ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x:
def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return u.arg[0]
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")

View File

@ -34,7 +34,7 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg[0])
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg))
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:

View File

@ -212,9 +212,9 @@ class PTXRenderer(Renderer):
r[u] = "%" + args[0]
kernel = [f".reg .u32 %{args[0]};"] + kernel
elif uop is UOps.DEFINE_VAR:
bufs.append((args[0].expr, dtype))
r[u] = f"%{args[0].expr}"
kk(*self.render_load(args[0].expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
bufs.append((args[0], dtype))
r[u] = f"%{args[0]}"
kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
elif uop is UOps.LOAD:

View File

@ -147,10 +147,10 @@ class CStyleLanguage(Renderer):
kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
r[u] = args[0]
elif uop is UOps.DEFINE_VAR:
assert args[0].expr not in seen_vars, f"duplicate variable {args[0].expr}"
seen_vars.add(args[0].expr)
bufs[u] = (args[0].expr, (dtype,False))
r[u] = args[0].expr
assert args[0] not in seen_vars, f"duplicate variable {args[0]}"
seen_vars.add(args[0])
bufs[u] = (args[0], (dtype,False))
r[u] = args[0]
elif uop is UOps.LOAD:
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
# NOTE: this relies on the load not happening if it's in the unselected branch

View File

@ -19,7 +19,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self.expr, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }