mirror of https://github.com/commaai/tinygrad.git
add more uops tests for vmin/vmax/const_factor/divides (#6533)
This commit is contained in:
parent
c447ec2190
commit
07bd6e070d
|
@ -0,0 +1,194 @@
|
||||||
|
import unittest
|
||||||
|
from tinygrad.ops import UOp, dtypes
|
||||||
|
|
||||||
|
class TestVminVmaxProperties(unittest.TestCase):
|
||||||
|
def test_vmin_vmax_constant(self):
|
||||||
|
# vmin and vmax for a constant
|
||||||
|
uop = UOp.const(dtypes.int32, 42)
|
||||||
|
self.assertEqual(uop.vmin, 42)
|
||||||
|
self.assertEqual(uop.vmax, 42)
|
||||||
|
|
||||||
|
def test_vmin_vmax_addition_with_variable(self):
|
||||||
|
# vmin and vmax for addition with a variable
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x + 5
|
||||||
|
self.assertEqual(uop.vmin, 15)
|
||||||
|
self.assertEqual(uop.vmax, 25)
|
||||||
|
|
||||||
|
def test_vmin_vmax_multiplication_with_variable(self):
|
||||||
|
# vmin and vmax for multiplication with a variable
|
||||||
|
x = UOp.define_var('x', dtypes.int32, -3, 4)
|
||||||
|
uop = x * 2
|
||||||
|
self.assertEqual(uop.vmin, -6)
|
||||||
|
self.assertEqual(uop.vmax, 8)
|
||||||
|
|
||||||
|
def test_vmin_vmax_with_negative_multiplication(self):
|
||||||
|
# vmin and vmax when multiplying by a negative number
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 2, 5)
|
||||||
|
uop = x * -3
|
||||||
|
self.assertEqual(uop.vmin, -15)
|
||||||
|
self.assertEqual(uop.vmax, -6)
|
||||||
|
|
||||||
|
def test_vmin_vmax_nested_min_max(self):
|
||||||
|
# vmin and vmax with nested min/max operations
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 0, 10)
|
||||||
|
uop = x.max(5).min(8)
|
||||||
|
self.assertEqual(uop.vmin, 5)
|
||||||
|
self.assertEqual(uop.vmax, 8)
|
||||||
|
|
||||||
|
class TestVminVmaxDivMod(unittest.TestCase):
|
||||||
|
def test_vmin_vmax_division_positive(self):
|
||||||
|
# vmin and vmax for division of a variable by a positive constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x // 2
|
||||||
|
self.assertEqual(uop.vmin, 5)
|
||||||
|
self.assertEqual(uop.vmax, 10)
|
||||||
|
|
||||||
|
def test_vmin_vmax_division_negative(self):
|
||||||
|
# vmin and vmax for division of a variable by a negative constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x // -2
|
||||||
|
self.assertEqual(uop.vmin, -10)
|
||||||
|
self.assertEqual(uop.vmax, -5)
|
||||||
|
|
||||||
|
def test_vmin_vmax_mod_positive(self):
|
||||||
|
# vmin and vmax for modulo of a variable by a positive constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x % 3
|
||||||
|
self.assertEqual(uop.vmin, 0)
|
||||||
|
self.assertEqual(uop.vmax, 2)
|
||||||
|
|
||||||
|
@unittest.skip("broken")
|
||||||
|
def test_vmin_vmax_mod_negative(self):
|
||||||
|
# vmin and vmax for modulo of a variable by a negative constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x % -3
|
||||||
|
self.assertEqual(uop.vmin, -2)
|
||||||
|
self.assertEqual(uop.vmax, 0)
|
||||||
|
|
||||||
|
def test_vmin_vmax_division_with_mixed_range(self):
|
||||||
|
# vmin and vmax for division of a variable with a range crossing zero
|
||||||
|
x = UOp.define_var('x', dtypes.int32, -10, 10)
|
||||||
|
uop = x // 3
|
||||||
|
self.assertEqual(uop.vmin, -4) # -10//3 = -4
|
||||||
|
self.assertEqual(uop.vmax, 3) # 10//3 = 3
|
||||||
|
|
||||||
|
def test_vmin_vmax_mod_with_mixed_range(self):
|
||||||
|
# vmin and vmax for modulo of a variable with a range crossing zero
|
||||||
|
x = UOp.define_var('x', dtypes.int32, -10, 10)
|
||||||
|
uop = x % 4
|
||||||
|
self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive
|
||||||
|
self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4
|
||||||
|
|
||||||
|
class TestVminVmaxVConst(unittest.TestCase):
|
||||||
|
def test_vmin_vmax_vconst_single_element(self):
|
||||||
|
# vmin and vmax for a single-element vector constant
|
||||||
|
uop = UOp.const(dtypes.int32.vec(1), (42,))
|
||||||
|
self.assertEqual(uop.vmin, 42)
|
||||||
|
self.assertEqual(uop.vmax, 42)
|
||||||
|
|
||||||
|
def test_vmin_vmax_vconst_multiple_elements(self):
|
||||||
|
# vmin and vmax for a multi-element vector constant
|
||||||
|
uop = UOp.const(dtypes.int32.vec(4), (10, 20, -5, 7))
|
||||||
|
self.assertEqual(uop.vmin, -5)
|
||||||
|
self.assertEqual(uop.vmax, 20)
|
||||||
|
|
||||||
|
def test_vmin_vmax_vconst_all_equal(self):
|
||||||
|
# vmin and vmax for a vector where all elements are equal
|
||||||
|
uop = UOp.const(dtypes.int32.vec(3), (7, 7, 7))
|
||||||
|
self.assertEqual(uop.vmin, 7)
|
||||||
|
self.assertEqual(uop.vmax, 7)
|
||||||
|
|
||||||
|
def test_vmin_vmax_vconst_with_negative_values(self):
|
||||||
|
# vmin and vmax for a vector constant containing negative values
|
||||||
|
uop = UOp.const(dtypes.int32.vec(4), (-10, -20, -5, -15))
|
||||||
|
self.assertEqual(uop.vmin, -20)
|
||||||
|
self.assertEqual(uop.vmax, -5)
|
||||||
|
|
||||||
|
def test_vmin_vmax_vconst_with_floats(self):
|
||||||
|
# vmin and vmax for a vector constant of float values
|
||||||
|
uop = UOp.const(dtypes.float32.vec(3), (1.5, -3.2, 0.0))
|
||||||
|
self.assertEqual(uop.vmin, -3.2)
|
||||||
|
self.assertEqual(uop.vmax, 1.5)
|
||||||
|
|
||||||
|
class TestConstFactor(unittest.TestCase):
|
||||||
|
def test_const_factor_constant(self):
|
||||||
|
# const_factor for a constant
|
||||||
|
uop = UOp.const(dtypes.int32, 42)
|
||||||
|
self.assertEqual(uop.const_factor(), 42)
|
||||||
|
|
||||||
|
def test_const_factor_addition(self):
|
||||||
|
# const_factor for an addition of constants
|
||||||
|
uop = UOp.const(dtypes.int32, 30) + UOp.const(dtypes.int32, 12)
|
||||||
|
self.assertEqual(uop.const_factor(), 6) # GCD(30, 12) = 6
|
||||||
|
|
||||||
|
def test_const_factor_multiplication(self):
|
||||||
|
# const_factor for a multiplication of constants
|
||||||
|
uop = UOp.const(dtypes.int32, 5) * UOp.const(dtypes.int32, 7)
|
||||||
|
self.assertEqual(uop.const_factor(), 5) # For multiplication, it's one of the factors
|
||||||
|
|
||||||
|
def test_const_factor_with_variable(self):
|
||||||
|
# const_factor for an expression involving a variable
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x * 3
|
||||||
|
self.assertEqual(uop.const_factor(), 3)
|
||||||
|
|
||||||
|
def test_const_factor_division(self):
|
||||||
|
# const_factor for an expression with division
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x // 4
|
||||||
|
self.assertEqual(uop.const_factor(), 1) # Division reduces the const_factor to 1
|
||||||
|
|
||||||
|
def test_const_factor_multiplication_of_var_and_const(self):
|
||||||
|
# const_factor for multiplication of a variable and a constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 6, 18)
|
||||||
|
uop = x * 4
|
||||||
|
self.assertEqual(uop.const_factor(), 4) # Constant factor 4
|
||||||
|
|
||||||
|
@unittest.skip("broken")
|
||||||
|
def test_const_factor_multiplication_of_consts_and_vars(self):
|
||||||
|
# Multiplying constants and variables
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = (x * 3) * 5
|
||||||
|
self.assertEqual(uop.const_factor(), 15) # Constant multipliers are combined (3 * 5 = 15)
|
||||||
|
|
||||||
|
class TestDivides(unittest.TestCase):
|
||||||
|
def test_divides_constant_exact(self):
|
||||||
|
# Divides a constant by an exact divisor
|
||||||
|
uop = UOp.const(dtypes.int32, 42)
|
||||||
|
result = uop.divides(7)
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(result.const_factor(), 6) # 42 / 7 = 6
|
||||||
|
|
||||||
|
def test_divides_constant_inexact(self):
|
||||||
|
# Try to divide a constant by a non-exact divisor
|
||||||
|
uop = UOp.const(dtypes.int32, 42)
|
||||||
|
result = uop.divides(5)
|
||||||
|
self.assertIsNone(result) # 42 is not divisible by 5
|
||||||
|
|
||||||
|
@unittest.skip("broken")
|
||||||
|
def test_divides_variable_and_constant(self):
|
||||||
|
# Multiplying a variable by a constant, then dividing by the same constant
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = x * 6
|
||||||
|
result = uop.divides(6)
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(result, x) # (x * 6) / 6 = x
|
||||||
|
|
||||||
|
def test_divides_complex_expression(self):
|
||||||
|
# Dividing a more complex expression
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 10, 20)
|
||||||
|
uop = (x * 6) + 18
|
||||||
|
result = uop.divides(6)
|
||||||
|
self.assertIsNotNone(result)
|
||||||
|
self.assertEqual(result.const_factor(), 1) # (x + 3), const_factor is 1
|
||||||
|
|
||||||
|
def test_divides_with_inexact_factors(self):
|
||||||
|
# Multiplying by a constant but dividing by a non-exact divisor
|
||||||
|
x = UOp.define_var('x', dtypes.int32, 15, 45)
|
||||||
|
uop = x * 4
|
||||||
|
result = uop.divides(3)
|
||||||
|
self.assertIsNone(result) # Cannot divide by 3, since 4 is not divisible by 3
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
|
@ -232,6 +232,8 @@ constant_folder = PatternMatcher([
|
||||||
# lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
|
# lambda x,y,alu: UOp(UOps.VECTORIZE, alu.dtype, (UOp(UOps.ALU, alu.dtype.scalar(), (x,y), alu.arg),)*alu.dtype.count)),
|
||||||
# VECTORIZE of a single element is just that element
|
# VECTORIZE of a single element is just that element
|
||||||
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
(UPat(UOps.VECTORIZE, src=(UPat(name='x'),)), lambda x: x),
|
||||||
|
# VECTORIZE void is SINK
|
||||||
|
(UPat(UOps.VECTORIZE, dtype=dtypes.void, name='x'), lambda x: UOp(UOps.SINK, dtypes.void, x.src)),
|
||||||
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
# GEP/VECTORIZE, GEP/GEP, GEP/CONST, GEP/VCONST
|
||||||
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
|
(UPat(UOps.GEP, src=(UPat(UOps.GEP, name='g2'),), name='g1'),
|
||||||
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
lambda g1, g2: g2.src[0].gep(tuple(g2.arg[g1.arg[i]] for i in range(g1.dtype.count)))),
|
||||||
|
|
|
@ -435,7 +435,7 @@ class UOp(MathTrait):
|
||||||
def const_factor(self) -> int:
|
def const_factor(self) -> int:
|
||||||
"""largest known int that divides self"""
|
"""largest known int that divides self"""
|
||||||
if self.op is UOps.CONST: return self.arg
|
if self.op is UOps.CONST: return self.arg
|
||||||
if self.op is UOps.VCONST: return math.gcd(*self.arg)
|
if self.op is UOps.VCONST: return functools.reduce(math.gcd, self.arg)
|
||||||
if self.op is UOps.ALU:
|
if self.op is UOps.ALU:
|
||||||
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
if self.arg is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor())
|
||||||
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
if self.arg is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is UOps.CONST else self.src[1].arg if self.src[1].op is UOps.CONST else 1
|
||||||
|
|
Loading…
Reference in New Issue