clean up test_uop_symbolic.py (#7058)

enable more tests and remove dead tests
This commit is contained in:
chenyu 2024-10-14 15:35:58 -04:00 committed by GitHub
parent 8094340221
commit a99e42cf2f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 17 additions and 52 deletions

View File

@ -1,7 +1,6 @@
#!/usr/bin/env python
import unittest, pickle
from typing import Tuple
#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
# TODO: fix all the @unittest.expectedFailure
@ -10,8 +9,8 @@ from typing import Tuple
from tinygrad.helpers import DEBUG
from tinygrad.dtype import dtypes, PtrDType, ConstType
from tinygrad.codegen.linearize import linearize_uop
from tinygrad.codegen.uopgraph import full_graph_rewrite
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
from tinygrad import Variable
import functools
@ -41,7 +40,6 @@ def MulNode(x, y): return x*y
# *** leave tests the same
@unittest.skip("not supported on uops yet")
class TestSymbolicPickle(unittest.TestCase):
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
@ -82,10 +80,6 @@ class TestSymbolic(unittest.TestCase):
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
# # bool divided by int is not allowed
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
# self.helper_test_variable(expr//4, 0, 0, "0")
def test_lt_factors(self):
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
@ -94,33 +88,19 @@ class TestSymbolic(unittest.TestCase):
def test_div_reduction(self):
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
@unittest.expectedFailure
def test_var_becomes_num(self):
self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2")
@unittest.expectedFailure
def test_equality(self):
idx1 = Variable("idx1", 0, 3)
idx2 = Variable("idx2", 0, 3)
assert idx1 == idx1
assert idx1 != idx2
assert idx1*4 == idx1*4
assert idx1*4 != idx1*3
assert idx1*4 != idx1+4
assert idx1*4 != idx2*4
assert idx1+idx2 == idx1+idx2
assert idx1+idx2 == idx2+idx1
assert idx1+idx2 != idx2
assert idx1*idx2 == idx2*idx1
#def test_numnode_eq_int(self):
# n1 = NumNode(1)
# n2 = NumNode(2)
# assert n1 == 1
# assert n2 == 2
# assert n1 != n2
# assert hash(n1) == hash(1)
# assert hash(n2) == hash(2)
assert idx1 is idx1
assert idx1 is not idx2
assert idx1*4 is idx1*4
assert idx1*4 is not idx1*3
assert idx1*4 is not idx1+4
assert idx1*4 is not idx2*4
assert idx1+idx2 is idx1+idx2
# assert idx1+idx2 is idx2+idx1
assert idx1+idx2 is not idx2
# assert idx1*idx2 is idx2*idx1
def test_factorize(self):
a = Variable("a", 0, 8)
@ -215,7 +195,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
def test_mod_to_sub(self):
# This is mod reduction
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
def test_sum_div_const(self):
@ -404,9 +383,6 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
# *** below are uop_symbolic only
# NOTE: tests are not correct in symbolic
# TODO: simplify the expression
def test_div_neg_all_range(self):
gidx = Variable("gidx", 0, 124)
@ -483,26 +459,22 @@ class TestSymbolic(unittest.TestCase):
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
@unittest.skip("not supported on uops yet")
class TestSymbolicNumeric(unittest.TestCase):
def helper_test_numeric(self, f):
# TODO: why are the negative tests broken? (even if we did support negative variables)
#MIN, MAX = -10, 10
MIN, MAX = 0, 10
# one number
for i in range(MIN, MAX):
v = f(NumNode(i))
#print(i, f(i), v.min, v.max)
self.assertEqual(v.min, v.max)
self.assertEqual(v.min, f(i))
v = graph_rewrite(f(NumNode(i)), sym)
self.assertEqual(v.vmin, v.vmax)
self.assertEqual(v.vmin, f(i))
for kmin in range(MIN, MAX):
for kmax in range(MIN, MAX):
if kmin > kmax: continue
v = f(Variable("tmp", kmin, kmax))
values = [f(rv) for rv in range(kmin, kmax+1)]
# the min and max may not be exact
self.assertLessEqual(v.min, min(values))
self.assertGreaterEqual(v.max, max(values))
self.assertLessEqual(v.vmin, min(values))
self.assertGreaterEqual(v.vmax, max(values))
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
@ -543,13 +515,6 @@ class TestSymbolicVars(unittest.TestCase):
assert (a * a).vars() == {a}
assert (a//4 + a//6).vars() == {a}
@unittest.skip("not supported on uops yet")
class TestSymbolicMinMax(unittest.TestCase):
def test_min_max_known(self):
a = Variable("a", 1, 8)
assert max(1, a) == max(a, 1) == a
assert min(1, a) == min(a, 1) == 1
"""
@unittest.skip("not supported on uops yet")
class TestSymRender(unittest.TestCase):