mirror of https://github.com/commaai/tinygrad.git
remove Variable from UOp.DEFINE_VAR (#6393)
now it's just arg = (expr as str, min as UOp.const, max as UOp.const)
This commit is contained in:
parent
9ed2b8b818
commit
26c5d8346a
|
@ -1,7 +1,7 @@
|
|||
from typing import List
|
||||
import unittest, time
|
||||
from test.helpers import assert_equiv_uops
|
||||
from tinygrad import dtypes, Variable, Device
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps, UOps, UOp, KernelInfo
|
||||
|
@ -131,10 +131,10 @@ class TestGraphRewrite(unittest.TestCase):
|
|||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('a', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('b', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('c', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('d', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
|
@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
def test_where_same_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
|
@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
|
@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
|
@ -268,21 +268,19 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half,
|
||||
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half,
|
||||
arg=(Variable(f'tmp{j}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(f'tmp{j}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1))) for j in range(i//2)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
@ -290,17 +288,17 @@ class TestUOpGraph(unittest.TestCase):
|
|||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(f'tmp{i}', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0), UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=('acc', UOp.const(dtypes.half, 0), UOp.const(dtypes.half, 1)))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
@ -326,7 +324,7 @@ class TestUOpGraph(unittest.TestCase):
|
|||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=('tmp', UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
|
|
|
@ -27,10 +27,7 @@ def render(self) -> Tuple[str, ConstType, ConstType]:
|
|||
|
||||
def NumNode(val): return UOp.const(dtypes.int, val)
|
||||
def Variable(expr, nmin, nmax):
|
||||
# TODO: fix DEFINE_VAR to not need this
|
||||
class TempVar:
|
||||
def __init__(self, x): self.expr = x
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(expr, UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
|
||||
class Node:
|
||||
@staticmethod
|
||||
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
|
@ -306,7 +303,6 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_sum_combine_num(self):
|
||||
self.helper_test_variable(Node.sum([NumNode(29), Variable("a", 0, 10), NumNode(-23)]), 6, 16, {"(6+a)", "(a+6)"})
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_sum_num_hoisted_and_factors_cancel_out(self):
|
||||
self.helper_test_variable(Node.sum([Variable("a", 0, 1) * -4 + 1, Variable("a", 0, 1) * 4]), 1, 1, "1")
|
||||
|
||||
|
|
|
@ -333,7 +333,7 @@ class UOp(MathTrait):
|
|||
@functools.cached_property
|
||||
def cmp_tuple(self) -> Tuple[int, Any, Optional[DType], Tuple[UOp, ...]]:
|
||||
# NOTE: this sort of DEFINE_VAR shouldn't have to be here. only for PTX
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0].expr) if self.op is not UOps.ALU else \
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg[0]) if self.op is not UOps.ALU else \
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
@functools.cached_property
|
||||
|
@ -366,7 +366,7 @@ class UOp(MathTrait):
|
|||
@classmethod
|
||||
def _const(cls, dtype:Optional[DType], b:ConstType|Variable):
|
||||
# TODO: fix dtype of b.max after Variable is just an UOp
|
||||
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
|
||||
if isinstance(b, Variable): return cls(UOps.DEFINE_VAR, dtype, arg=(b.expr, cls.const(dtypes.int, b.min), cls.const(dtypes.int, cast(int,b.max))))
|
||||
if dtype is not None and dtype != (sdtype := dtype.scalar()):
|
||||
return cls(UOps.VECTORIZE, dtype, src=tuple(cls(UOps.CONST, sdtype, arg=dtypes.as_const(b, sdtype)) for _ in range(dtype.count)))
|
||||
return cls(UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b)
|
||||
|
@ -388,7 +388,8 @@ class UOp(MathTrait):
|
|||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, set([x.arg[0] for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
|
||||
return sorted(set.union(*st_vars, set([Variable(x.arg[0], x.arg[1], x.arg[2]) for x in self.sparents if x.op is UOps.DEFINE_VAR])),
|
||||
key=lambda v: v.expr)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
|
@ -476,7 +477,7 @@ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x:
|
|||
|
||||
def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.CONST: return u.arg
|
||||
if u.op is UOps.DEFINE_VAR: return u.arg[0]
|
||||
if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg)
|
||||
if u.op is UOps.ALU: return exec_alu(u.arg, cast(DType,u.dtype), tuple(map(uop_alu_resolve, u.src)))
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ class Program:
|
|||
if not self._ran_post_init and self.uops is not None:
|
||||
# single pass through the uops
|
||||
for u in self.uops:
|
||||
if u.op is UOps.DEFINE_VAR: self.vars.append(u.arg[0])
|
||||
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1].arg, u.arg[2].arg))
|
||||
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
|
||||
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
|
||||
if u.op is UOps.SPECIAL:
|
||||
|
|
|
@ -212,9 +212,9 @@ class PTXRenderer(Renderer):
|
|||
r[u] = "%" + args[0]
|
||||
kernel = [f".reg .u32 %{args[0]};"] + kernel
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args[0].expr, dtype))
|
||||
r[u] = f"%{args[0].expr}"
|
||||
kk(*self.render_load(args[0].expr, ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
|
||||
bufs.append((args[0], dtype))
|
||||
r[u] = f"%{args[0]}"
|
||||
kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.CONST: r[u] = const(args, dtype, mov=True)
|
||||
elif uop is UOps.GEP: r[u] = r[src[0]][u.arg]
|
||||
elif uop is UOps.LOAD:
|
||||
|
|
|
@ -147,10 +147,10 @@ class CStyleLanguage(Renderer):
|
|||
kk(f"int {args[0]} = {self.code_for_workitem[args[0][0]](args[0][-1])}; /* {args[1]} */")
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
assert args[0].expr not in seen_vars, f"duplicate variable {args[0].expr}"
|
||||
seen_vars.add(args[0].expr)
|
||||
bufs[u] = (args[0].expr, (dtype,False))
|
||||
r[u] = args[0].expr
|
||||
assert args[0] not in seen_vars, f"duplicate variable {args[0]}"
|
||||
seen_vars.add(args[0])
|
||||
bufs[u] = (args[0], (dtype,False))
|
||||
r[u] = args[0]
|
||||
elif uop is UOps.LOAD:
|
||||
val = self.render_load(dtype, r[src[0]], src[0].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
# NOTE: this relies on the load not happening if it's in the unselected branch
|
||||
|
|
|
@ -19,7 +19,7 @@ render_ops: Any = { NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self
|
|||
ModNode: lambda self, ops, ctx: self.a.render(ops, ctx)%variable_to_uop(self.b, ctx),
|
||||
LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
|
||||
Variable: lambda self,ops,ctx: ctx[self] if ctx is not None and self in ctx else \
|
||||
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
|
||||
UOp(UOps.DEFINE_VAR, dtypes.int, arg=(self.expr, UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))),
|
||||
SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a+b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)),
|
||||
AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: a*b.render(ops, ctx), self.nodes[1:], self.nodes[0].render(ops,ctx)) }
|
||||
|
||||
|
|
Loading…
Reference in New Issue