add more uops tests for vmin/vmax/const_factor/divides (#6533)

This commit is contained in:
George Hotz 2024-09-16 13:06:31 +08:00 committed by GitHub
parent c447ec2190
commit 07bd6e070d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 197 additions and 1 deletions

View File

@ -0,0 +1,194 @@
import unittest
from tinygrad.ops import UOp, dtypes
class TestVminVmaxProperties(unittest.TestCase):
  """Checks the vmin/vmax values reported for a constant and for simple expressions over a bounded variable."""

  def _check_bounds(self, expr, lo, hi):
    """Assert that expr.vmin equals lo and then that expr.vmax equals hi."""
    self.assertEqual(expr.vmin, lo)
    self.assertEqual(expr.vmax, hi)

  def test_vmin_vmax_constant(self):
    # the constant 42 is expected to report 42 for both bounds
    self._check_bounds(UOp.const(dtypes.int32, 42), 42, 42)

  def test_vmin_vmax_addition_with_variable(self):
    # variable declared with bounds 10..20, shifted by +5 -> expected 15..25
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    self._check_bounds(var + 5, 15, 25)

  def test_vmin_vmax_multiplication_with_variable(self):
    # variable declared with bounds -3..4, scaled by 2 -> expected -6..8
    var = UOp.define_var('x', dtypes.int32, -3, 4)
    self._check_bounds(var * 2, -6, 8)

  def test_vmin_vmax_with_negative_multiplication(self):
    # scaling 2..5 by -3 flips the ends: expected -15..-6
    var = UOp.define_var('x', dtypes.int32, 2, 5)
    self._check_bounds(var * -3, -15, -6)

  def test_vmin_vmax_nested_min_max(self):
    # max(5) followed by min(8) on a 0..10 variable -> expected 5..8
    var = UOp.define_var('x', dtypes.int32, 0, 10)
    floored = var.max(5)
    clamped = floored.min(8)
    self._check_bounds(clamped, 5, 8)
class TestVminVmaxDivMod(unittest.TestCase):
  """Checks the vmin/vmax values reported for floor-division and modulo of a bounded variable by a constant."""

  @staticmethod
  def _var(lo, hi):
    """Build the int32 variable 'x' declared with the given lower and upper bound."""
    return UOp.define_var('x', dtypes.int32, lo, hi)

  def test_vmin_vmax_division_positive(self):
    # 10..20 divided by a positive constant
    quotient = self._var(10, 20) // 2
    self.assertEqual(quotient.vmin, 5)
    self.assertEqual(quotient.vmax, 10)

  def test_vmin_vmax_division_negative(self):
    # 10..20 divided by a negative constant
    quotient = self._var(10, 20) // -2
    self.assertEqual(quotient.vmin, -10)
    self.assertEqual(quotient.vmax, -5)

  def test_vmin_vmax_mod_positive(self):
    # 10..20 modulo a positive constant
    remainder = self._var(10, 20) % 3
    self.assertEqual(remainder.vmin, 0)
    self.assertEqual(remainder.vmax, 2)

  @unittest.skip("broken")
  def test_vmin_vmax_mod_negative(self):
    # 10..20 modulo a negative constant (currently skipped as broken)
    remainder = self._var(10, 20) % -3
    self.assertEqual(remainder.vmin, -2)
    self.assertEqual(remainder.vmax, 0)

  def test_vmin_vmax_division_with_mixed_range(self):
    # the variable's range straddles zero here
    quotient = self._var(-10, 10) // 3
    self.assertEqual(quotient.vmin, -4)  # -10//3 = -4
    self.assertEqual(quotient.vmax, 3)  # 10//3 = 3

  def test_vmin_vmax_mod_with_mixed_range(self):
    # the variable's range straddles zero here
    remainder = self._var(-10, 10) % 4
    self.assertEqual(remainder.vmin, 0)  # test expects a non-negative result for a positive divisor
    self.assertEqual(remainder.vmax, 3)  # largest remainder when dividing by 4 is 3
class TestVminVmaxVConst(unittest.TestCase):
  """Checks that vmin/vmax of a vector constant are the smallest and largest of its elements."""

  @staticmethod
  def _vec_const(scalar_dtype, values):
    """Build a vector constant whose width is the number of values supplied."""
    return UOp.const(scalar_dtype.vec(len(values)), values)

  def test_vmin_vmax_vconst_single_element(self):
    # one-element vector: both bounds are that element
    vec = self._vec_const(dtypes.int32, (42,))
    self.assertEqual(vec.vmin, 42)
    self.assertEqual(vec.vmax, 42)

  def test_vmin_vmax_vconst_multiple_elements(self):
    # mixed-sign elements: bounds are the extreme elements
    vec = self._vec_const(dtypes.int32, (10, 20, -5, 7))
    self.assertEqual(vec.vmin, -5)
    self.assertEqual(vec.vmax, 20)

  def test_vmin_vmax_vconst_all_equal(self):
    # identical elements: both bounds collapse to the shared value
    vec = self._vec_const(dtypes.int32, (7, 7, 7))
    self.assertEqual(vec.vmin, 7)
    self.assertEqual(vec.vmax, 7)

  def test_vmin_vmax_vconst_with_negative_values(self):
    # every element is negative
    vec = self._vec_const(dtypes.int32, (-10, -20, -5, -15))
    self.assertEqual(vec.vmin, -20)
    self.assertEqual(vec.vmax, -5)

  def test_vmin_vmax_vconst_with_floats(self):
    # float32 elements instead of ints
    vec = self._vec_const(dtypes.float32, (1.5, -3.2, 0.0))
    self.assertEqual(vec.vmin, -3.2)
    self.assertEqual(vec.vmax, 1.5)
class TestConstFactor(unittest.TestCase):
  """Checks the value returned by const_factor() for constants, sums, products and a division."""

  def test_const_factor_constant(self):
    # a lone constant is expected to be its own factor
    self.assertEqual(UOp.const(dtypes.int32, 42).const_factor(), 42)

  def test_const_factor_addition(self):
    # sum of two constants
    lhs = UOp.const(dtypes.int32, 30)
    rhs = UOp.const(dtypes.int32, 12)
    self.assertEqual((lhs + rhs).const_factor(), 6)  # GCD(30, 12) = 6

  def test_const_factor_multiplication(self):
    # product of two constants
    lhs = UOp.const(dtypes.int32, 5)
    rhs = UOp.const(dtypes.int32, 7)
    self.assertEqual((lhs * rhs).const_factor(), 5)  # for a product, one of the factors is reported

  def test_const_factor_with_variable(self):
    # variable times a constant
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    self.assertEqual((var * 3).const_factor(), 3)

  def test_const_factor_division(self):
    # variable floor-divided by a constant
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    self.assertEqual((var // 4).const_factor(), 1)  # a division is expected to yield factor 1

  def test_const_factor_multiplication_of_var_and_const(self):
    # variable times a constant, with different variable bounds
    var = UOp.define_var('x', dtypes.int32, 6, 18)
    self.assertEqual((var * 4).const_factor(), 4)  # constant factor 4

  @unittest.skip("broken")
  def test_const_factor_multiplication_of_consts_and_vars(self):
    # two constant multipliers chained onto a variable (currently skipped as broken)
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    tripled = var * 3
    product = tripled * 5
    self.assertEqual(product.const_factor(), 15)  # the multipliers should combine (3 * 5 = 15)
class TestDivides(unittest.TestCase):
  """Checks divides(): a result when the divisor divides exactly, None when it does not."""

  def test_divides_constant_exact(self):
    # constant with an exact divisor
    quotient = UOp.const(dtypes.int32, 42).divides(7)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 6)  # 42 / 7 = 6

  def test_divides_constant_inexact(self):
    # constant with a divisor that leaves a remainder
    quotient = UOp.const(dtypes.int32, 42).divides(5)
    self.assertIsNone(quotient)  # 42 is not divisible by 5

  @unittest.skip("broken")
  def test_divides_variable_and_constant(self):
    # scale a variable by 6, then divide by 6 again (currently skipped as broken)
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    scaled = var * 6
    quotient = scaled.divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient, var)  # (x * 6) / 6 = x

  def test_divides_complex_expression(self):
    # a sum in which both terms carry the divisor
    var = UOp.define_var('x', dtypes.int32, 10, 20)
    scaled = var * 6
    total = scaled + 18
    quotient = total.divides(6)
    self.assertIsNotNone(quotient)
    self.assertEqual(quotient.const_factor(), 1)  # (x + 3), const_factor is 1

  def test_divides_with_inexact_factors(self):
    # the constant multiplier is not a multiple of the divisor
    var = UOp.define_var('x', dtypes.int32, 15, 45)
    quotient = (var * 4).divides(3)
    self.assertIsNone(quotient)  # cannot divide by 3, since 4 is not divisible by 3
# run the test suite when this file is executed as a script
if __name__ == '__main__':
  unittest.main()

View File

@ -232,6 +232,8 @@ constant_folder = PatternMatcher([
# lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
# VECTORIZE of a single element is just that element
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
# VECTORIZE void is SINK
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),

View File

@ -435,7 +435,7 @@ class UOp(MathTrait):
def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg
if self.op is UOps.VCONST: return math.gcd(*self.arg)
if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg)
if self.op is UOps.ALU:
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1