mirror of https://github.com/commaai/tinygrad.git
clean up test_uop_symbolic.py (#7058)
enable more tests and remove dead tests
This commit is contained in:
parent
8094340221
commit
a99e42cf2f
|
@ -1,7 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
import unittest, pickle
|
||||
from typing import Tuple
|
||||
#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
|
||||
|
||||
# TODO: fix all the @unittest.expectedFailure
|
||||
|
||||
|
@ -10,8 +9,8 @@ from typing import Tuple
|
|||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
|
||||
from tinygrad import Variable
|
||||
import functools
|
||||
|
||||
|
@ -41,7 +40,6 @@ def MulNode(x, y): return x*y
|
|||
|
||||
# *** leave tests the same
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicPickle(unittest.TestCase):
|
||||
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
|
||||
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
|
||||
|
@ -82,10 +80,6 @@ class TestSymbolic(unittest.TestCase):
|
|||
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
||||
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
|
||||
# # bool divided by int is not allowed
|
||||
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
|
||||
# self.helper_test_variable(expr//4, 0, 0, "0")
|
||||
|
||||
def test_lt_factors(self):
|
||||
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
|
||||
|
@ -94,33 +88,19 @@ class TestSymbolic(unittest.TestCase):
|
|||
def test_div_reduction(self):
|
||||
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_var_becomes_num(self):
|
||||
self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2")
|
||||
|
||||
@unittest.expectedFailure
|
||||
def test_equality(self):
|
||||
idx1 = Variable("idx1", 0, 3)
|
||||
idx2 = Variable("idx2", 0, 3)
|
||||
assert idx1 == idx1
|
||||
assert idx1 != idx2
|
||||
assert idx1*4 == idx1*4
|
||||
assert idx1*4 != idx1*3
|
||||
assert idx1*4 != idx1+4
|
||||
assert idx1*4 != idx2*4
|
||||
assert idx1+idx2 == idx1+idx2
|
||||
assert idx1+idx2 == idx2+idx1
|
||||
assert idx1+idx2 != idx2
|
||||
assert idx1*idx2 == idx2*idx1
|
||||
|
||||
#def test_numnode_eq_int(self):
|
||||
# n1 = NumNode(1)
|
||||
# n2 = NumNode(2)
|
||||
# assert n1 == 1
|
||||
# assert n2 == 2
|
||||
# assert n1 != n2
|
||||
# assert hash(n1) == hash(1)
|
||||
# assert hash(n2) == hash(2)
|
||||
assert idx1 is idx1
|
||||
assert idx1 is not idx2
|
||||
assert idx1*4 is idx1*4
|
||||
assert idx1*4 is not idx1*3
|
||||
assert idx1*4 is not idx1+4
|
||||
assert idx1*4 is not idx2*4
|
||||
assert idx1+idx2 is idx1+idx2
|
||||
# assert idx1+idx2 is idx2+idx1
|
||||
assert idx1+idx2 is not idx2
|
||||
# assert idx1*idx2 is idx2*idx1
|
||||
|
||||
def test_factorize(self):
|
||||
a = Variable("a", 0, 8)
|
||||
|
@ -215,7 +195,6 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||
|
||||
def test_mod_to_sub(self):
|
||||
# This is mod reduction
|
||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
|
||||
|
||||
def test_sum_div_const(self):
|
||||
|
@ -404,9 +383,6 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
|
||||
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
|
||||
|
||||
# *** below are uop_symbolic only
|
||||
|
||||
# NOTE: tests are not correct in symbolic
|
||||
# TODO: simplify the expression
|
||||
def test_div_neg_all_range(self):
|
||||
gidx = Variable("gidx", 0, 124)
|
||||
|
@ -483,26 +459,22 @@ class TestSymbolic(unittest.TestCase):
|
|||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicNumeric(unittest.TestCase):
|
||||
def helper_test_numeric(self, f):
|
||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
||||
#MIN, MAX = -10, 10
|
||||
MIN, MAX = 0, 10
|
||||
# one number
|
||||
for i in range(MIN, MAX):
|
||||
v = f(NumNode(i))
|
||||
#print(i, f(i), v.min, v.max)
|
||||
self.assertEqual(v.min, v.max)
|
||||
self.assertEqual(v.min, f(i))
|
||||
v = graph_rewrite(f(NumNode(i)), sym)
|
||||
self.assertEqual(v.vmin, v.vmax)
|
||||
self.assertEqual(v.vmin, f(i))
|
||||
for kmin in range(MIN, MAX):
|
||||
for kmax in range(MIN, MAX):
|
||||
if kmin > kmax: continue
|
||||
v = f(Variable("tmp", kmin, kmax))
|
||||
values = [f(rv) for rv in range(kmin, kmax+1)]
|
||||
# the min and max may not be exact
|
||||
self.assertLessEqual(v.min, min(values))
|
||||
self.assertGreaterEqual(v.max, max(values))
|
||||
self.assertLessEqual(v.vmin, min(values))
|
||||
self.assertGreaterEqual(v.vmax, max(values))
|
||||
|
||||
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
|
||||
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
|
||||
|
@ -543,13 +515,6 @@ class TestSymbolicVars(unittest.TestCase):
|
|||
assert (a * a).vars() == {a}
|
||||
assert (a//4 + a//6).vars() == {a}
|
||||
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymbolicMinMax(unittest.TestCase):
|
||||
def test_min_max_known(self):
|
||||
a = Variable("a", 1, 8)
|
||||
assert max(1, a) == max(a, 1) == a
|
||||
assert min(1, a) == min(a, 1) == 1
|
||||
|
||||
"""
|
||||
@unittest.skip("not supported on uops yet")
|
||||
class TestSymRender(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue