mirror of https://github.com/commaai/tinygrad.git
clean up test_uop_symbolic.py (#7058)
enable more tests and remove dead tests
This commit is contained in:
parent
8094340221
commit
a99e42cf2f
|
@ -1,7 +1,6 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
import unittest, pickle
|
import unittest, pickle
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
|
|
||||||
|
|
||||||
# TODO: fix all the @unittest.expectedFailure
|
# TODO: fix all the @unittest.expectedFailure
|
||||||
|
|
||||||
|
@ -10,8 +9,8 @@ from typing import Tuple
|
||||||
from tinygrad.helpers import DEBUG
|
from tinygrad.helpers import DEBUG
|
||||||
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
from tinygrad.dtype import dtypes, PtrDType, ConstType
|
||||||
from tinygrad.codegen.linearize import linearize_uop
|
from tinygrad.codegen.linearize import linearize_uop
|
||||||
from tinygrad.codegen.uopgraph import full_graph_rewrite
|
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||||
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
|
from tinygrad.ops import BinaryOps, UOp, UOps, print_uops, graph_rewrite
|
||||||
from tinygrad import Variable
|
from tinygrad import Variable
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
|
@ -41,7 +40,6 @@ def MulNode(x, y): return x*y
|
||||||
|
|
||||||
# *** leave tests the same
|
# *** leave tests the same
|
||||||
|
|
||||||
@unittest.skip("not supported on uops yet")
|
|
||||||
class TestSymbolicPickle(unittest.TestCase):
|
class TestSymbolicPickle(unittest.TestCase):
|
||||||
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
|
def _test_pickle_unpickle(self, x): self.assertEqual(x, pickle.loads(pickle.dumps(x)))
|
||||||
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
|
def test_pickle_variable(self): self._test_pickle_unpickle(Variable("a", 3, 8))
|
||||||
|
@ -82,10 +80,6 @@ class TestSymbolic(unittest.TestCase):
|
||||||
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
||||||
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512)])
|
||||||
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
|
self.helper_test_variable(expr, 0, 1, "((idx1<128)&(idx2<128))")
|
||||||
# # bool divided by int is not allowed
|
|
||||||
# expr = Node.ands([create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 3), 512),
|
|
||||||
# create_lt_node(Variable("idx2", 0, 511)*4 + Variable("FLOAT8_INDEX", 0, 7), 512)])
|
|
||||||
# self.helper_test_variable(expr//4, 0, 0, "0")
|
|
||||||
|
|
||||||
def test_lt_factors(self):
|
def test_lt_factors(self):
|
||||||
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
|
expr = create_lt_node(Variable("idx1", 0, 511)*4 + Variable("FLOAT4_INDEX", 0, 256), 512)
|
||||||
|
@ -94,33 +88,19 @@ class TestSymbolic(unittest.TestCase):
|
||||||
def test_div_reduction(self):
|
def test_div_reduction(self):
|
||||||
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
|
self.helper_test_variable(Variable("a", 2, 3)//2, 1, 1, "1")
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_var_becomes_num(self):
|
|
||||||
self.helper_test_variable(Variable("a", 2, 2), 2, 2, "2")
|
|
||||||
|
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_equality(self):
|
def test_equality(self):
|
||||||
idx1 = Variable("idx1", 0, 3)
|
idx1 = Variable("idx1", 0, 3)
|
||||||
idx2 = Variable("idx2", 0, 3)
|
idx2 = Variable("idx2", 0, 3)
|
||||||
assert idx1 == idx1
|
assert idx1 is idx1
|
||||||
assert idx1 != idx2
|
assert idx1 is not idx2
|
||||||
assert idx1*4 == idx1*4
|
assert idx1*4 is idx1*4
|
||||||
assert idx1*4 != idx1*3
|
assert idx1*4 is not idx1*3
|
||||||
assert idx1*4 != idx1+4
|
assert idx1*4 is not idx1+4
|
||||||
assert idx1*4 != idx2*4
|
assert idx1*4 is not idx2*4
|
||||||
assert idx1+idx2 == idx1+idx2
|
assert idx1+idx2 is idx1+idx2
|
||||||
assert idx1+idx2 == idx2+idx1
|
# assert idx1+idx2 is idx2+idx1
|
||||||
assert idx1+idx2 != idx2
|
assert idx1+idx2 is not idx2
|
||||||
assert idx1*idx2 == idx2*idx1
|
# assert idx1*idx2 is idx2*idx1
|
||||||
|
|
||||||
#def test_numnode_eq_int(self):
|
|
||||||
# n1 = NumNode(1)
|
|
||||||
# n2 = NumNode(2)
|
|
||||||
# assert n1 == 1
|
|
||||||
# assert n2 == 2
|
|
||||||
# assert n1 != n2
|
|
||||||
# assert hash(n1) == hash(1)
|
|
||||||
# assert hash(n2) == hash(2)
|
|
||||||
|
|
||||||
def test_factorize(self):
|
def test_factorize(self):
|
||||||
a = Variable("a", 0, 8)
|
a = Variable("a", 0, 8)
|
||||||
|
@ -215,7 +195,6 @@ class TestSymbolic(unittest.TestCase):
|
||||||
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
self.helper_test_variable(Node.sum([Variable("a", 0, 7)*100, Variable("b", 0, 3)*50]) % 100, 0, 99, "((b*50)%100)")
|
||||||
|
|
||||||
def test_mod_to_sub(self):
|
def test_mod_to_sub(self):
|
||||||
# This is mod reduction
|
|
||||||
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
|
self.helper_test_variable((1+Variable("a",1,2))%2, 0, 1, "(a+-1)")
|
||||||
|
|
||||||
def test_sum_div_const(self):
|
def test_sum_div_const(self):
|
||||||
|
@ -404,9 +383,6 @@ class TestSymbolic(unittest.TestCase):
|
||||||
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
|
self.helper_test_variable((idx0*v+idx1)//v, 0, 2, "(idx0)")
|
||||||
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
|
self.helper_test_variable((idx0*v+idx1)%v, 0, start_pos, "idx1")
|
||||||
|
|
||||||
# *** below are uop_symbolic only
|
|
||||||
|
|
||||||
# NOTE: tests are not correct in symbolic
|
|
||||||
# TODO: simplify the expression
|
# TODO: simplify the expression
|
||||||
def test_div_neg_all_range(self):
|
def test_div_neg_all_range(self):
|
||||||
gidx = Variable("gidx", 0, 124)
|
gidx = Variable("gidx", 0, 124)
|
||||||
|
@ -483,26 +459,22 @@ class TestSymbolic(unittest.TestCase):
|
||||||
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||||
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, "(((a+b+c)<1)!=1)")
|
||||||
|
|
||||||
@unittest.skip("not supported on uops yet")
|
|
||||||
class TestSymbolicNumeric(unittest.TestCase):
|
class TestSymbolicNumeric(unittest.TestCase):
|
||||||
def helper_test_numeric(self, f):
|
def helper_test_numeric(self, f):
|
||||||
# TODO: why are the negative tests broken? (even if we did support negative variables)
|
|
||||||
#MIN, MAX = -10, 10
|
|
||||||
MIN, MAX = 0, 10
|
MIN, MAX = 0, 10
|
||||||
# one number
|
# one number
|
||||||
for i in range(MIN, MAX):
|
for i in range(MIN, MAX):
|
||||||
v = f(NumNode(i))
|
v = graph_rewrite(f(NumNode(i)), sym)
|
||||||
#print(i, f(i), v.min, v.max)
|
self.assertEqual(v.vmin, v.vmax)
|
||||||
self.assertEqual(v.min, v.max)
|
self.assertEqual(v.vmin, f(i))
|
||||||
self.assertEqual(v.min, f(i))
|
|
||||||
for kmin in range(MIN, MAX):
|
for kmin in range(MIN, MAX):
|
||||||
for kmax in range(MIN, MAX):
|
for kmax in range(MIN, MAX):
|
||||||
if kmin > kmax: continue
|
if kmin > kmax: continue
|
||||||
v = f(Variable("tmp", kmin, kmax))
|
v = f(Variable("tmp", kmin, kmax))
|
||||||
values = [f(rv) for rv in range(kmin, kmax+1)]
|
values = [f(rv) for rv in range(kmin, kmax+1)]
|
||||||
# the min and max may not be exact
|
# the min and max may not be exact
|
||||||
self.assertLessEqual(v.min, min(values))
|
self.assertLessEqual(v.vmin, min(values))
|
||||||
self.assertGreaterEqual(v.max, max(values))
|
self.assertGreaterEqual(v.vmax, max(values))
|
||||||
|
|
||||||
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
|
def test_mod_4(self): self.helper_test_numeric(lambda x: (x%4))
|
||||||
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
|
def test_div_4(self): self.helper_test_numeric(lambda x: (x//4))
|
||||||
|
@ -543,13 +515,6 @@ class TestSymbolicVars(unittest.TestCase):
|
||||||
assert (a * a).vars() == {a}
|
assert (a * a).vars() == {a}
|
||||||
assert (a//4 + a//6).vars() == {a}
|
assert (a//4 + a//6).vars() == {a}
|
||||||
|
|
||||||
@unittest.skip("not supported on uops yet")
|
|
||||||
class TestSymbolicMinMax(unittest.TestCase):
|
|
||||||
def test_min_max_known(self):
|
|
||||||
a = Variable("a", 1, 8)
|
|
||||||
assert max(1, a) == max(a, 1) == a
|
|
||||||
assert min(1, a) == min(a, 1) == 1
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@unittest.skip("not supported on uops yet")
|
@unittest.skip("not supported on uops yet")
|
||||||
class TestSymRender(unittest.TestCase):
|
class TestSymRender(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in New Issue