2024-10-01 13:11:42 +08:00
|
|
|
import unittest
|
|
|
|
from tinygrad.dtype import dtypes
|
2024-10-02 12:57:16 +08:00
|
|
|
from tinygrad.ops import UOp, resolve
|
2024-10-01 13:11:42 +08:00
|
|
|
|
|
|
|
class TestUOpResolve(unittest.TestCase):
  """Resolving UOp expressions to concrete Python values.

  Covers:
  - int()/float() on constant expressions;
  - bool() on comparisons, including comparisons over bounded symbolic
    variables. bool() must raise ValueError when the variable's bounds do not
    decide the result;
  - ssimplify() on constant expressions;
  - resolve(), which returns the caller-supplied default (True unless given)
    when a comparison cannot be decided.
  """

  # ---- constant expressions -------------------------------------------------

  def test_simple_int(self):
    u = UOp.const(dtypes.int, 4)
    self.assertEqual(int(u), 4)

  def test_int_add(self):
    # Adding a plain Python int to a const UOp still resolves to an int.
    u = UOp.const(dtypes.int, 4) + 7
    self.assertEqual(int(u), 11)

  def test_lt(self):
    u = UOp.const(dtypes.int, 4) < 7
    self.assertTrue(u)

  def test_rfloordiv(self):
    # Reflected operator: Python int on the left-hand side.
    u = 8 // UOp.const(dtypes.int, 4)
    self.assertEqual(int(u), 2)

  def test_rtruediv(self):
    # Reflected true division; 9 / 4 == 2.25 is exactly representable.
    u = 9 / UOp.const(dtypes.float, 4)
    self.assertEqual(float(u), 2.25)

  def test_leq(self):
    u = UOp.const(dtypes.int, 4) <= 4
    self.assertTrue(u)

  def test_ne(self):
    u = UOp.const(dtypes.int, 4) != 7
    self.assertTrue(u)

  def test_ne_f(self):
    u = UOp.const(dtypes.int, 4) != 4
    self.assertFalse(u)

  def test_ngt(self):
    u = UOp.const(dtypes.int, 4) > 7
    self.assertFalse(u)

  def test_ssimplify(self):
    # Both expressions fold to constants, so ssimplify() compares equal to
    # the folded values.
    self.assertEqual((8 % UOp.const(dtypes.int, 4)).ssimplify(), 0)
    self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32)

  def test_float_direct(self):
    u = UOp.const(dtypes.float, 4.5) + 7
    self.assertEqual(float(u), 11.5)

  # ---- resolve() with an explicit fallback ----------------------------------

  def test_ambiguous_less_than(self):
    u = UOp.variable("i", 1, 10)
    # u < 4 is true for some values in [1, 10] and false for others, so
    # resolve() returns its default (True when omitted).
    self.assertTrue(resolve(u < 4))
    self.assertFalse(resolve(u < 4, False))
    # Decided by the bounds, so the default is ignored.
    self.assertTrue(resolve(u < 11, False))
    self.assertFalse(resolve(u < -1, False))
    self.assertFalse(resolve(u < -1, True))

  # ---- bool() on symbolic variables -----------------------------------------

  def test_var_cmp_t(self):
    # Every value in [1, 10] is < 20.
    u = UOp.variable("i", 1, 10) < 20
    self.assertTrue(u)

  def test_var_cmp_t2(self):
    # i // 2 stays well below 20 for i in [1, 10].
    u = UOp.variable("i", 1, 10)//2 < 20
    self.assertTrue(u)

  def test_var_cmp_f(self):
    # The minimum value is 1, so i < 1 never holds.
    u = UOp.variable("i", 1, 10) < 1
    self.assertFalse(u)

  def test_var_cmp_f2(self):
    # The maximum value is 10, so i > 11 never holds.
    u = UOp.variable("i", 1, 10) > 11
    self.assertFalse(u)

  def test_or_true(self):
    # OR with True is true whatever b is.
    u = UOp.variable("b", False, True, dtypes.bool) | True
    self.assertTrue(u)

  def test_or_false(self):
    # OR with False leaves the result depending on b, so bool() must refuse.
    # Build the expression outside assertRaises so that only the truthiness
    # conversion is allowed to raise.
    u = UOp.variable("b", False, True, dtypes.bool) | False
    with self.assertRaises(ValueError):
      bool(u)

  def test_and_false(self):
    # AND with False is false whatever b is.
    u = UOp.variable("b", False, True, dtypes.bool) & False
    self.assertFalse(u)

  def test_and_true(self):
    # AND with True leaves the result depending on b, so bool() must refuse.
    u = UOp.variable("b", False, True, dtypes.bool) & True
    with self.assertRaises(ValueError):
      bool(u)

  def test_max(self):
    # max(x, y) with x in [1, 10] and y in [5, 10] lies in [5, 10].
    x = UOp.variable("x", 1, 10)
    y = UOp.variable("y", 5, 10)
    u = x.max(y)
    self.assertTrue(u < 20)
    self.assertFalse(u < 3)

  def test_x_lt_x(self):
    x = UOp.variable("i", 1, 10)
    self.assertFalse(x < x)

  # Known limitation: x < x+1 holds for every x, but the comparison is not
  # currently proven always-true, so this is kept as an expected failure.
  @unittest.expectedFailure
  def test_x_lt_xp1(self):
    x = UOp.variable("i", 1, 10)
    self.assertTrue(x < (x+1))

  # Known limitation: (v > 4) | (v < 6) covers every v, but the disjunction
  # is not currently proven always-true, so this is kept as an expected failure.
  @unittest.expectedFailure
  def test_var_cmp_range(self):
    v = UOp.variable("i", 1, 10)
    u = (v > 4) | (v < 6)
    self.assertTrue(u)

  def test_var_cmp_assert(self):
    # i < 5 is true for some of [1, 10] and false for the rest, so bool()
    # must raise.
    u = UOp.variable("i", 1, 10) < 5
    with self.assertRaises(ValueError):
      bool(u)
|
# Entry point when the module is executed directly rather than imported:
# hand control to unittest's command-line runner.
if __name__ == '__main__':
  unittest.main()
|