From 07bd6e070d1236a51da4f63b04d0261a14b6cea7 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Mon, 16 Sep 2024 13:06:31 +0800 Subject: [PATCH] add more uops tests for vmin/vmax/const_factor/divides (#6533) --- test/unit/test_uop_vmin_vmax.py | 194 ++++++++++++++++++++++++++++++++ tinygrad/codegen/uopgraph.py | 2 + tinygrad/ops.py | 2 +- 3 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 test/unit/test_uop_vmin_vmax.py diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py new file mode 100644 index 00000000..f36eb773 --- /dev/null +++ b/test/unit/test_uop_vmin_vmax.py @@ -0,0 +1,194 @@ +import unittest +from tinygrad.ops import UOp, dtypes + +class TestVminVmaxProperties(unittest.TestCase): + def test_vmin_vmax_constant(self): + # vmin and vmax for a constant + uop = UOp.const(dtypes.int32, 42) + self.assertEqual(uop.vmin, 42) + self.assertEqual(uop.vmax, 42) + + def test_vmin_vmax_addition_with_variable(self): + # vmin and vmax for addition with a variable + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x + 5 + self.assertEqual(uop.vmin, 15) + self.assertEqual(uop.vmax, 25) + + def test_vmin_vmax_multiplication_with_variable(self): + # vmin and vmax for multiplication with a variable + x = UOp.define_var('x', dtypes.int32, -3, 4) + uop = x * 2 + self.assertEqual(uop.vmin, -6) + self.assertEqual(uop.vmax, 8) + + def test_vmin_vmax_with_negative_multiplication(self): + # vmin and vmax when multiplying by a negative number + x = UOp.define_var('x', dtypes.int32, 2, 5) + uop = x * -3 + self.assertEqual(uop.vmin, -15) + self.assertEqual(uop.vmax, -6) + + def test_vmin_vmax_nested_min_max(self): + # vmin and vmax with nested min/max operations + x = UOp.define_var('x', dtypes.int32, 0, 10) + uop = x.max(5).min(8) + self.assertEqual(uop.vmin, 5) + self.assertEqual(uop.vmax, 8) + +class TestVminVmaxDivMod(unittest.TestCase): + def test_vmin_vmax_division_positive(self): + # vmin and vmax for division of a variable by a positive constant + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x // 2 + self.assertEqual(uop.vmin, 5) + self.assertEqual(uop.vmax, 10) + + def test_vmin_vmax_division_negative(self): + # vmin and vmax for division of a variable by a negative constant + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x // -2 + self.assertEqual(uop.vmin, -10) + self.assertEqual(uop.vmax, -5) + + def test_vmin_vmax_mod_positive(self): + # vmin and vmax for modulo of a variable by a positive constant + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x % 3 + self.assertEqual(uop.vmin, 0) + self.assertEqual(uop.vmax, 2) + + @unittest.skip("broken") + def test_vmin_vmax_mod_negative(self): + # vmin and vmax for modulo of a variable by a negative constant + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x % -3 + self.assertEqual(uop.vmin, -2) + self.assertEqual(uop.vmax, 0) + + def test_vmin_vmax_division_with_mixed_range(self): + # vmin and vmax for division of a variable with a range crossing zero + x = UOp.define_var('x', dtypes.int32, -10, 10) + uop = x // 3 + self.assertEqual(uop.vmin, -4) # -10//3 = -4 + self.assertEqual(uop.vmax, 3) # 10//3 = 3 + + def test_vmin_vmax_mod_with_mixed_range(self): + # vmin and vmax for modulo of a variable with a range crossing zero + x = UOp.define_var('x', dtypes.int32, -10, 10) + uop = x % 4 + self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive + self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4 + +class TestVminVmaxVConst(unittest.TestCase): + def test_vmin_vmax_vconst_single_element(self): + # vmin and vmax for a single-element vector constant + uop = UOp.const(dtypes.int32.vec(1), (42,)) + self.assertEqual(uop.vmin, 42) + self.assertEqual(uop.vmax, 42) + + def test_vmin_vmax_vconst_multiple_elements(self): + # vmin and vmax for a multi-element vector constant + uop = UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7)) + self.assertEqual(uop.vmin, -5) + self.assertEqual(uop.vmax, 20) + + def test_vmin_vmax_vconst_all_equal(self): + # vmin and vmax for a vector where all elements are equal + uop = UOp.const(dtypes.int32.vec(3), (7, 7, 7)) + self.assertEqual(uop.vmin, 7) + self.assertEqual(uop.vmax, 7) + + def test_vmin_vmax_vconst_with_negative_values(self): + # vmin and vmax for a vector constant containing negative values + uop = UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15)) + self.assertEqual(uop.vmin, -20) + self.assertEqual(uop.vmax, -5) + + def test_vmin_vmax_vconst_with_floats(self): + # vmin and vmax for a vector constant of float values + uop = UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0)) + self.assertEqual(uop.vmin, -3.2) + self.assertEqual(uop.vmax, 1.5) + +class TestConstFactor(unittest.TestCase): + def test_const_factor_constant(self): + # const_factor for a constant + uop = UOp.const(dtypes.int32, 42) + self.assertEqual(uop.const_factor(), 42) + + def test_const_factor_addition(self): + # const_factor for an addition of constants + uop = UOp.const(dtypes.int32, 30) + UOp.const(dtypes.int32, 12) + self.assertEqual(uop.const_factor(), 6) # GCD(30, 12) = 6 + + def test_const_factor_multiplication(self): + # const_factor for a multiplication of constants + uop = UOp.const(dtypes.int32, 5) * UOp.const(dtypes.int32, 7) + self.assertEqual(uop.const_factor(), 5) # For multiplication, it's one of the factors + + def test_const_factor_with_variable(self): + # const_factor for an expression involving a variable + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x * 3 + self.assertEqual(uop.const_factor(), 3) + + def test_const_factor_division(self): + # const_factor for an expression with division + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x // 4 + self.assertEqual(uop.const_factor(), 1) # Division reduces the const_factor to 1 + + def test_const_factor_multiplication_of_var_and_const(self): + # const_factor for multiplication of a variable and a constant + x = UOp.define_var('x', dtypes.int32, 6, 18) + uop = x * 4 + self.assertEqual(uop.const_factor(), 4) # Constant factor 4 + + @unittest.skip("broken") + def test_const_factor_multiplication_of_consts_and_vars(self): + # Multiplying constants and variables + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = (x * 3) * 5 + self.assertEqual(uop.const_factor(), 15) # Constant multipliers are combined (3 * 5 = 15) + +class TestDivides(unittest.TestCase): + def test_divides_constant_exact(self): + # Divides a constant by an exact divisor + uop = UOp.const(dtypes.int32, 42) + result = uop.divides(7) + self.assertIsNotNone(result) + self.assertEqual(result.const_factor(), 6) # 42 / 7 = 6 + + def test_divides_constant_inexact(self): + # Try to divide a constant by a non-exact divisor + uop = UOp.const(dtypes.int32, 42) + result = uop.divides(5) + self.assertIsNone(result) # 42 is not divisible by 5 + + @unittest.skip("broken") + def test_divides_variable_and_constant(self): + # Multiplying a variable by a constant, then dividing by the same constant + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = x * 6 + result = uop.divides(6) + self.assertIsNotNone(result) + self.assertEqual(result, x) # (x * 6) / 6 = x + + def test_divides_complex_expression(self): + # Dividing a more complex expression + x = UOp.define_var('x', dtypes.int32, 10, 20) + uop = (x * 6) + 18 + result = uop.divides(6) + self.assertIsNotNone(result) + self.assertEqual(result.const_factor(), 1) # (x + 3), const_factor is 1 + + def test_divides_with_inexact_factors(self): + # Multiplying by a constant but dividing by a non-exact divisor + x = UOp.define_var('x', dtypes.int32, 15, 45) + uop = x * 4 + result = uop.divides(3) + self.assertIsNone(result) # Cannot divide by 3, since 4 is not divisible by 3 + +if __name__ == '__main__': + unittest.main() diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index e5366fe9..dcd06074 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -232,6 +232,8 @@ constant_folder = PatternMatcher([ # lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)), # VECTORIZE of a single element is just that element (UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x), + # VECTORIZE void is SINK + (UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)), # GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST (UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'), lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))), diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 47a4c8d2..7231c8ef 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -435,7 +435,7 @@ class UOp(MathTrait): def const_factor(self) -> int: """largest known int that divides self""" if self.op is UOps.CONST: return self.arg - if self.op is UOps.VCONST: return math.gcd(*self.arg) + if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg) if self.op is UOps.ALU: if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1