diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 5cf0043c..d24e64d8 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import unittest, pickle from typing import Tuple -#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node # TODO: fix all the @unittest.expectedFailure @@ -10,8 +9,8 @@ from typing import Tuple from tinygrad.helpers import DEBUG from tinygrad.dtype import dtypes, PtrDType, ConstType from tinygrad.codegen.linearize import linearize_uop -from tinygrad.codegen.uopgraph import full_graph_rewrite -from tinygrad.ops import BinaryOps, UOp, UOps, print_uops +from tinygrad.codegen.uopgraph import full_graph_rewrite, sym +from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite from tinygrad import Variable import functools @@ -41,7 +40,6 @@ def MulNode(x, y): return x*y # *** leave tests the same -@unittest.skip("not supported on uops yet") class TestSymbolicPickle(unittest.TestCase): def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x))) def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8)) @@ -82,10 +80,6 @@ class TestSymbolic(unittest.TestCase): expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)]) self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))") - # # bool divided by int is not allowed - # expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512), - # create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)]) - # self.helper_test_variable(expr//4, 0, 0, "0") def test_lt_factors(self): expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512) @@ -94,33 +88,19 @@ class TestSymbolic(unittest.TestCase): def test_div_reduction(self): self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1") - @unittest.expectedFailure - def test_var_becomes_num(self): - self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2") - - @unittest.expectedFailure def test_equality(self): idx1 = Variable("idx1", 0, 3) idx2 = Variable("idx2", 0, 3) - assert idx1 == idx1 - assert idx1 != idx2 - assert idx1*4 == idx1*4 - assert idx1*4 != idx1*3 - assert idx1*4 != idx1+4 - assert idx1*4 != idx2*4 - assert idx1+idx2 == idx1+idx2 - assert idx1+idx2 == idx2+idx1 - assert idx1+idx2 != idx2 - assert idx1*idx2 == idx2*idx1 - - #def test_numnode_eq_int(self): - # n1 = NumNode(1) - # n2 = NumNode(2) - # assert n1 == 1 - # assert n2 == 2 - # assert n1 != n2 - # assert hash(n1) == hash(1) - # assert hash(n2) == hash(2) + assert idx1 is idx1 + assert idx1 is not idx2 + assert idx1*4 is idx1*4 + assert idx1*4 is not idx1*3 + assert idx1*4 is not idx1+4 + assert idx1*4 is not idx2*4 + assert idx1+idx2 is idx1+idx2 + # assert idx1+idx2 is idx2+idx1 + assert idx1+idx2 is not idx2 + # assert idx1*idx2 is idx2*idx1 def test_factorize(self): a = Variable("a", 0, 8) @@ -215,7 +195,6 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)") def test_mod_to_sub(self): - # This is mod reduction self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)") def test_sum_div_const(self): @@ -404,9 +383,6 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)") self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1") - # *** below are uop_symbolic only - - # NOTE: tests are not correct in symbolic # TODO: simplify the expression def test_div_neg_all_range(self): gidx = Variable("gidx", 0, 124) @@ -483,26 +459,22 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)") -@unittest.skip("not supported on uops yet") class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): - # TODO: why are the negative tests broken? (even if we did support negative variables) - #MIN, MAX = -10, 10 MIN, MAX = 0, 10 # one number for i in range(MIN, MAX): - v = f(NumNode(i)) - #print(i, f(i), v.min, v.max) - self.assertEqual(v.min, v.max) - self.assertEqual(v.min, f(i)) + v = graph_rewrite(f(NumNode(i)), sym) + self.assertEqual(v.vmin, v.vmax) + self.assertEqual(v.vmin, f(i)) for kmin in range(MIN, MAX): for kmax in range(MIN, MAX): if kmin > kmax: continue v = f(Variable("tmp", kmin, kmax)) values = [f(rv) for rv in range(kmin, kmax+1)] # the min and max may not be exact - self.assertLessEqual(v.min, min(values)) - self.assertGreaterEqual(v.max, max(values)) + self.assertLessEqual(v.vmin, min(values)) + self.assertGreaterEqual(v.vmax, max(values)) def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4)) def test_div_4(self): self.helper_test_numeric(lambda x: (x//4)) @@ -543,13 +515,6 @@ class TestSymbolicVars(unittest.TestCase): assert (a * a).vars() == {a} assert (a//4 + a//6).vars() == {a} -@unittest.skip("not supported on uops yet") -class TestSymbolicMinMax(unittest.TestCase): - def test_min_max_known(self): - a = Variable("a", 1, 8) - assert max(1, a) == max(a, 1) == a - assert min(1, a) == min(a, 1) == 1 - """ @unittest.skip("not supported on uops yet") class TestSymRender(unittest.TestCase):