mirror of https://github.com/commaai/tinygrad.git
remove pyint [pr] (#7016)
* remove pyint
* bump time on tp [pr]
* dont truncate in const fold
* remove dead code
* Revert "dont truncate in const fold"
  This reverts commit 29c81db0f7880848b001c2728aa555a1ef17e7d3.
* remove define_var
This commit is contained in:
parent 38d45dfba5
commit 85a45164fb
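The core API change is that `UOp.define_var(name, dtype, min, max)` is removed and `UOp.variable(name, min_val, max_val, dtype=dtypes.int)` takes its place, with `dtypes.pyint` dropped entirely. Below is a minimal sketch of the migration, built only from call sites that appear in the diff; the import paths are assumptions for this era of the codebase.

from tinygrad.dtype import dtypes
from tinygrad.ops import UOp  # assumed import path

# before: dtype was a required second positional argument, and symbolic ints used dtypes.pyint
# x = UOp.define_var('x', dtypes.pyint, 0, 100)

# after: (name, min, max); dtype defaults to dtypes.int and moves to the last position
x = UOp.variable('x', 0, 100)
gate = UOp.variable('g1', False, True, dtypes.bool)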
@@ -32,7 +32,6 @@ def assert_jit_cache_len(fxn, expected_len):
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len
 
 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-  if dtype == dtypes.pyint and device != "PYTHON": return False
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
     return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))

@@ -12,7 +12,7 @@ from test.helpers import is_dtype_supported, rand_for_dtype
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")
 
-core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
+core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
 dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
 dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]

@@ -20,7 +20,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
-  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
+  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]
 
 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)

@@ -147,7 +147,7 @@ class TestProfiler(unittest.TestCase):
 
     transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0]
     helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
-    assert 80 < transfer_node_1['dur'] < (16000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"
+    assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"
 
   @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
   def test_profile_deps(self):

@@ -172,10 +172,10 @@ class TestGraphRewrite(unittest.TestCase):
     self.assertEqual(nout.src[1].arg, 3.0)
 
   def test_consts_go_last(self):
-    a = UOp.define_var('a', dtypes.int, 0, 1)
-    b = UOp.define_var('b', dtypes.int, 0, 1)
-    c = UOp.define_var('c', dtypes.int, 0, 1)
-    d = UOp.define_var('d', dtypes.int, 0, 1)
+    a = UOp.variable('a', 0, 1)
+    b = UOp.variable('b', 0, 1)
+    c = UOp.variable('c', 0, 1)
+    d = UOp.variable('d', 0, 1)
     outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
     for out in outs:
       sink = graph_rewrite(out, sym)

@@ -196,7 +196,7 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(out.arg, 3.0)
 
   def test_where_same_fold(self):
-    v = UOp.define_var('tmp', dtypes.int, 0, 1)
+    v = UOp.variable('tmp', 0, 1)
     c0 = UOp(UOps.CONST, dtypes.int, arg=0)
     vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
     c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)

@@ -290,7 +290,7 @@ class TestUOpGraph(unittest.TestCase):
     for i in [2, 4, 8]:
       vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
       var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
-      acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
+      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
       wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
       uops = to_uops_list([wmma])
       assert_equiv_uops(uops[0], acc)

@@ -299,7 +299,7 @@ class TestUOpGraph(unittest.TestCase):
     for i in [2, 4, 8]:
       var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
       vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
-      acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
+      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
       wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
       uops = to_uops_list([wmma])
       assert_equiv_uops(uops[0], acc)

@@ -366,7 +366,7 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
 
   def test_depth_2_const_fold(self):
-    v = UOp.define_var("tmp", dtypes.int, 0, 1)
+    v = UOp.variable("tmp", 0, 1)
     c2 = UOp(UOps.CONST, dtypes.int, arg=2)
     c4 = UOp(UOps.CONST, dtypes.int, arg=4)
     vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)

@@ -620,8 +620,8 @@ class TestLoadStoreFolder(unittest.TestCase):
 
   def test_simple_load_dont_fold_different_gated(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
-    gate2 = UOp.define_var("g2", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
+    gate2 = UOp.variable("g2", False, True, dtypes.bool)
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
     sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)

@@ -636,7 +636,7 @@ class TestLoadStoreFolder(unittest.TestCase):
 
   def test_simple_store_fold_gate(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
     load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
     sink = UOp(UOps.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)

@@ -647,8 +647,8 @@ class TestLoadStoreFolder(unittest.TestCase):
 
   def test_simple_store_dont_fold(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
-    gate2 = UOp.define_var("g2", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
+    gate2 = UOp.variable("g2", False, True, dtypes.bool)
     load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
     sink = UOp(UOps.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)

@@ -94,20 +94,20 @@ class TestFoldingAndReduction(unittest.TestCase):
 
 class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_folding_with_define_var(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 0, 100)
+    x_var_uop = UOp.variable('x', 0, 100)
     optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
     self.assertEqual(optimized_mod_uop.op, UOps.CONST)
     self.assertEqual(optimized_mod_uop.arg, 2)
 
   def test_full_graph_rewrite_division_folding_with_define_var(self):
-    n_var_uop = UOp.define_var('n', dtypes.int32, 1, 1000)
+    n_var_uop = UOp.variable('n', 1, 1000)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
     self.assertEqual(optimized_div_uop.op, UOps.ALU)
     self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)
 
   def test_full_graph_rewrite_complex_mod_div_folding(self):
-    k_var_uop = UOp.define_var('k', dtypes.int32, 0, 50)
+    k_var_uop = UOp.variable('k', 0, 50)
     optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
     self.assertEqual(optimized_div_uop.op, UOps.CONST)
     self.assertEqual(optimized_div_uop.arg, 1)

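The folds these tests expect are ordinary integer arithmetic; a quick brute-force check in plain Python (independent of the UOp rewriter) confirms each expected result:

# ((x*4)+2) % 4 is 2 for every x, and (n*6)//3 is exactly n*2
assert all((x * 4 + 2) % 4 == 2 for x in range(0, 101))
assert all((n * 6) // 3 == n * 2 for n in range(1, 1001))
# (k*12 + 8) % 6: the k*12 term vanishes mod 6, leaving 8 % 6 == 2, and 2 // 2 == 1
assert all(((k * 12 + 8) % 6) // 2 == 1 for k in range(0, 51))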
@@ -124,17 +124,17 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
       if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))
 
   def test_full_graph_rewrite_modulo_large_divisor(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
+    x_var_uop = UOp.variable('x', 1, 5)
     self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop)
 
   def test_full_graph_rewrite_division_with_remainder(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 7, 9)
+    x_var_uop = UOp.variable('x', 7, 9)
     optimized_sink = apply_rewrite(x_var_uop // 2)
     for x_value in range(7, 10):
       self.assertEqual(x_value // 2, evaluate_uop(optimized_sink, {'x': x_value}))
 
   def test_full_graph_rewrite_complex_mod_div_expression(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 10)
+    x_var_uop = UOp.variable('x', 1, 10)
     optimized_sink = apply_rewrite(((x_var_uop * 5) % 3) // 2)
     for x_value in range(1, 11):
       original_result = ((x_value * 5) % 3) // 2

@@ -152,14 +152,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase):
 
   @unittest.skip("broken")
   def test_full_graph_rewrite_modulo_negative_dividend(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, -5, -1)
+    x_var_uop = UOp.variable('x', -5, -1)
     optimized_sink = full_graph_rewrite((x_var_uop % 3).sink())
     for x_value in range(-5, 0):
       self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value}))
 
   @unittest.skip("broken")
   def test_full_graph_rewrite_division_negative_divisor(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
+    x_var_uop = UOp.variable('x', 1, 5)
     optimized_sink = full_graph_rewrite((x_var_uop // -2).sink())
     for x_value in range(1, 6):
       self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value}))

@@ -11,7 +11,7 @@ from tinygrad.codegen.uopgraph import sym
 from itertools import product
 
 def shapetracker_getitem(st:ShapeTracker, val:int):
-  idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)])
+  idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
   idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
   assert idx.op is UOps.CONST and valid.op is UOps.CONST
   return idx.arg, valid.arg

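A hedged usage sketch of the shapetracker_getitem helper above: it flattens the view, plugs a constant index UOp into to_indexed_uops, const-folds, and returns the buffer index and valid flag. The shape and permutation here are illustrative, not taken from the commit.

from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((2, 3)).permute((1, 0))
print(shapetracker_getitem(st, 4))  # -> (index into the underlying buffer, valid flag)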
@@ -32,7 +32,7 @@ def render(uop:UOp) -> str:
   return fxn.split("val0 = ")[1].split(";")[0]
 
 def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
-def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
+def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
 def Range(n, nmax):
   return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
 

@@ -44,7 +44,7 @@ class TestUOpResolve(unittest.TestCase):
     self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32)
 
   def test_ambiguous_less_than(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10)
+    u = UOp.variable("i", 1, 10)
     self.assertTrue(resolve(u < 4))
     self.assertFalse(resolve(u < 4, False))
     self.assertTrue(resolve(u < 11, False))

@@ -56,64 +56,64 @@ class TestUOpResolve(unittest.TestCase):
     self.assertEqual(float(u), 11.5)
 
   def test_var_cmp_t(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) < 20
+    u = UOp.variable("i", 1, 10) < 20
     self.assertTrue(u)
 
   def test_var_cmp_t2(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10)//2 < 20
+    u = UOp.variable("i", 1, 10)//2 < 20
     self.assertTrue(u)
 
   def test_var_cmp_f(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) < 1
+    u = UOp.variable("i", 1, 10) < 1
     self.assertFalse(u)
 
   def test_var_cmp_f2(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) > 11
+    u = UOp.variable("i", 1, 10) > 11
     self.assertFalse(u)
 
   def test_or_true(self):
-    u = UOp.define_var("b", dtypes.bool, False, True) | True
+    u = UOp.variable("b", False, True, dtypes.bool) | True
     self.assertTrue(u)
 
   def test_or_false(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("b", dtypes.bool, False, True) | False
+      u = UOp.variable("b", False, True, dtypes.bool) | False
       self.assertTrue(u)
 
   def test_and_false(self):
-    u = UOp.define_var("b", dtypes.bool, False, True) & False
+    u = UOp.variable("b", False, True, dtypes.bool) & False
     self.assertFalse(u)
 
   def test_max(self):
-    x = UOp.define_var("x", dtypes.pyint, 1, 10)
-    y = UOp.define_var("y", dtypes.pyint, 5, 10)
+    x = UOp.variable("x", 1, 10)
+    y = UOp.variable("y", 5, 10)
     u = x.max(y)
     self.assertTrue(u < 20)
     self.assertFalse(u < 3)
 
   def test_x_lt_x(self):
-    x = UOp.define_var("i", dtypes.pyint, 1, 10)
+    x = UOp.variable("i", 1, 10)
     self.assertFalse(x < x)
 
   @unittest.expectedFailure
   def test_x_lt_xp1(self):
-    x = UOp.define_var("i", dtypes.pyint, 1, 10)
+    x = UOp.variable("i", 1, 10)
     self.assertTrue(x < (x+1))
 
   def test_and_true(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("b", dtypes.bool, False, True) & True
+      u = UOp.variable("b", False, True, dtypes.bool) & True
       self.assertFalse(u)
 
   @unittest.expectedFailure
   def test_var_cmp_range(self):
-    v = UOp.define_var("i", dtypes.pyint, 1, 10)
+    v = UOp.variable("i", 1, 10)
     u = (v > 4) | (v < 6)
     self.assertTrue(u)
 
   def test_var_cmp_assert(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("i", dtypes.pyint, 1, 10) < 5
+      u = UOp.variable("i", 1, 10) < 5
       self.assertFalse(u)
 
 if __name__ == '__main__':

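These tests pin down the resolve semantics for bounded variables: a comparison coerces to a plain bool only when the variable's whole range agrees (otherwise bool() raises ValueError), while resolve() falls back to a caller-supplied default in the ambiguous case. A plain-Python sketch of that behavior, not tinygrad's implementation:

def resolve_cmp(vmin: int, vmax: int, rhs: int, default: bool = True) -> bool:
  # "var < rhs" for a variable bounded to [vmin, vmax]
  if vmax < rhs: return True      # holds for every value in the range
  if vmin >= rhs: return False    # holds for no value in the range
  return default                  # ambiguous: fall back to the default

assert resolve_cmp(1, 10, 20) is True         # like test_var_cmp_t: always true
assert resolve_cmp(1, 10, 11, False) is True  # like resolve(u < 11, False): unambiguous
assert resolve_cmp(1, 10, 4, False) is False  # ambiguous, the default decides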
@@ -19,36 +19,36 @@ class TestVminVmaxProperties(unittest.TestCase):
 
   def test_vmin_vmax_addition_with_variable(self):
     # vmin and vmax for addition with a variable
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x + 5
     self.assertEqual(uop.vmin, 15)
     self.assertEqual(uop.vmax, 25)
 
   def test_vmin_vmax_multiplication_with_variable(self):
     # vmin and vmax for multiplication with a variable
-    x = UOp.define_var('x', dtypes.int32, -3, 4)
+    x = UOp.variable('x', -3, 4)
     uop = x * 2
     self.assertEqual(uop.vmin, -6)
     self.assertEqual(uop.vmax, 8)
 
   def test_vmin_vmax_with_negative_multiplication(self):
     # vmin and vmax when multiplying by a negative number
-    x = UOp.define_var('x', dtypes.int32, 2, 5)
+    x = UOp.variable('x', 2, 5)
     uop = x * -3
     self.assertEqual(uop.vmin, -15)
     self.assertEqual(uop.vmax, -6)
 
   def test_vmin_vmax_nested_min_max(self):
     # vmin and vmax with nested min/max operations
-    x = UOp.define_var('x', dtypes.int32, 0, 10)
+    x = UOp.variable('x', 0, 10)
     uop = x.max(5).min(8)
     self.assertEqual(uop.vmin, 5)
     self.assertEqual(uop.vmax, 8)
 
   def test_vmin_vmax_where(self):
-    x = UOp.define_var('x', dtypes.int, 0, 10)
-    y = UOp.define_var('y', dtypes.int, 1, 11)
-    z = UOp.define_var('z', dtypes.int, 2, 12)
+    x = UOp.variable('x', 0, 10)
+    y = UOp.variable('y', 1, 11)
+    z = UOp.variable('z', 2, 12)
     uop = x.lt(5).where(y, z)
     self.assertEqual(uop.vmin, 1)
     self.assertEqual(uop.vmax, 12)

@@ -56,21 +56,21 @@ class TestVminVmaxProperties(unittest.TestCase):
 class TestVminVmaxDivMod(unittest.TestCase):
   def test_vmin_vmax_division_positive(self):
     # vmin and vmax for division of a variable by a positive constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // 2
     self.assertEqual(uop.vmin, 5)
     self.assertEqual(uop.vmax, 10)
 
   def test_vmin_vmax_division_negative(self):
     # vmin and vmax for division of a variable by a negative constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // -2
     self.assertEqual(uop.vmin, -10)
     self.assertEqual(uop.vmax, -5)
 
   def test_vmin_vmax_mod_positive(self):
     # vmin and vmax for modulo of a variable by a positive constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x % 3
     self.assertEqual(uop.vmin, 0)
     self.assertEqual(uop.vmax, 2)

@@ -78,21 +78,21 @@ class TestVminVmaxDivMod(unittest.TestCase):
   @unittest.skip("broken")
   def test_vmin_vmax_mod_negative(self):
     # vmin and vmax for modulo of a variable by a negative constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x % -3
     self.assertEqual(uop.vmin, -2)
     self.assertEqual(uop.vmax, 0)
 
   def test_vmin_vmax_division_with_mixed_range(self):
     # vmin and vmax for division of a variable with a range crossing zero
-    x = UOp.define_var('x', dtypes.int32, -10, 10)
+    x = UOp.variable('x', -10, 10)
     uop = x // 3
     self.assertEqual(uop.vmin, -4) # -10//3 = -4
     self.assertEqual(uop.vmax, 3) # 10//3 = 3
 
   def test_vmin_vmax_mod_with_mixed_range(self):
     # vmin and vmax for modulo of a variable with a range crossing zero
-    x = UOp.define_var('x', dtypes.int32, -10, 10)
+    x = UOp.variable('x', -10, 10)
     uop = x % 4
     self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive
     self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4

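The expected vmin/vmax values in the mixed-range cases follow from Python's floor-division and modulo conventions; brute force over the variable's range confirms them (plain Python, no UOp machinery involved):

quotients = [x // 3 for x in range(-10, 11)]
assert min(quotients) == -4 and max(quotients) == 3   # floor division: -10//3 == -4, 10//3 == 3
remainders = [x % 4 for x in range(-10, 11)]
assert min(remainders) == 0 and max(remainders) == 3  # % with a positive divisor stays in [0, 3]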
@@ -146,26 +146,26 @@ class TestConstFactor(unittest.TestCase):
 
   def test_const_factor_with_variable(self):
     # const_factor for an expression involving a variable
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x * 3
     self.assertEqual(uop.const_factor(), 3)
 
   def test_const_factor_division(self):
     # const_factor for an expression with division
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // 4
     self.assertEqual(uop.const_factor(), 1) # Division reduces the const_factor to 1
 
   def test_const_factor_multiplication_of_var_and_const(self):
     # const_factor for multiplication of a variable and a constant
-    x = UOp.define_var('x', dtypes.int32, 6, 18)
+    x = UOp.variable('x', 6, 18)
     uop = x * 4
     self.assertEqual(uop.const_factor(), 4) # Constant factor 4
 
   @unittest.skip("broken")
   def test_const_factor_multiplication_of_consts_and_vars(self):
     # Multiplying constants and variables
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = (x * 3) * 5
     self.assertEqual(uop.const_factor(), 15) # Constant multipliers are combined (3 * 5 = 15)

@@ -186,7 +186,7 @@ class TestDivides(unittest.TestCase):
   @unittest.skip("broken")
   def test_divides_variable_and_constant(self):
     # Multiplying a variable by a constant, then dividing by the same constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x * 6
     result = uop.divides(6)
     self.assertIsNotNone(result)

@@ -194,7 +194,7 @@ class TestDivides(unittest.TestCase):
 
   def test_divides_complex_expression(self):
     # Dividing a more complex expression
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = (x * 6) + 18
     result = uop.divides(6)
     self.assertIsNotNone(result)

@@ -202,7 +202,7 @@ class TestDivides(unittest.TestCase):
 
   def test_divides_with_inexact_factors(self):
     # Multiplying by a constant but dividing by a non-exact divisor
-    x = UOp.define_var('x', dtypes.int32, 15, 45)
+    x = UOp.variable('x', 15, 45)
     uop = x * 4
     result = uop.divides(3)
     self.assertIsNone(result) # Cannot divide by 3, since 4 is not divisible by 3

@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
 def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
     # cast for mypy, get_contraction won't be None

@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
+    idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
       for i,g in enumerate(full_shape[:first_reduce])]
 
   # reduce loops
-  idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]
 
   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))
 
   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))
 
   return IndexContext(idxs, ridxs)

@@ -118,9 +118,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
     candidates = []
     if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
-      candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)])
+      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)])
     # try checking the whole clause
-    candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))])
+    candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])
 
     for candidate in candidates:
      newidxs:List[List[UOp]] = [[], []]

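The "simplex" comment in this hunk carries the key reasoning: if the valid is X0 + X1 + ... > 0 with every Xi >= 0, then at least one Xi must be >= 1, so a simplification of the index is safe if it holds under each separate "Xi >= 1" assumption. A plain-Python illustration with a hypothetical index expression (not tinygrad code):

def idx(x: int, y: int) -> int: return min(x + y, 1)  # hypothetical index expression

for x in range(4):
  for y in range(4):
    if x + y > 0:            # only the points where the simplex-style valid holds
      assert idx(x, y) == 1  # the candidate simplification agrees everywhere the valid is true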
@@ -538,9 +538,6 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
 ])
 
-no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
-  name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
-
 # *** uop graph ***
 
 linearize_cnt = 0

@@ -552,9 +549,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   acc_number = 0
   sink = graph_rewrite(sink, sym)
 
-  # rewrite pyint to int32
-  sink = graph_rewrite(sink, no_pyint)
-
   # expand
   linearize_cnt += 1
   if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:

@@ -86,7 +86,6 @@ class dtypes:
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   # TODO: priority should be higher than bool
   void: Final[DType] = DType(-1, 0, "void", None, 1)
-  pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)

@@ -118,7 +117,7 @@ class dtypes:
 
   floats = (float16, bfloat16, float32, float64)
   uints = (uint8, uint16, uint32, uint64)
-  sints = (int8, int16, int32, int64, pyint)
+  sints = (int8, int16, int32, int64)
   ints = uints + sints
 
 if (env_default_float := getenv("DEFAULT_FLOAT", "")):

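With pyint gone from the dtypes namespace and from sints, generic dtype enumeration no longer needs special-casing, which is why the test helpers earlier in this diff drop their k != 'pyint' filters. A small hedged check (the DTYPES_DICT import path is an assumption):

from tinygrad.dtype import DTYPES_DICT, dtypes  # assumed import path

core_dtypes = list(DTYPES_DICT.values())  # no pyint entry left to exclude
assert dtypes.int32 in dtypes.sints
assert not any(d.name == "pyint" for d in core_dtypes)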
@@ -282,11 +282,8 @@ class UOp(MathTrait):
     out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
     return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
   @staticmethod
-  def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b)
-  @staticmethod
-  def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
+  def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
     # TODO: fix dtype of b.max after Variable is just an UOp
-    #if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
     if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
     if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
     return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore

@@ -299,9 +296,8 @@ class UOp(MathTrait):
   # *** uop Variable stuff ***
 
   @staticmethod
-  def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val))
-  @staticmethod
-  def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
+  def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
+    return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
   @property
   def expr(self):
     assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"

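A short usage sketch of the new UOp.variable signature shown above; the variable names are illustrative:

i = UOp.variable('i', 0, 63)                        # dtype defaults to dtypes.int
g = UOp.variable('gate', False, True, dtypes.bool)  # a non-default dtype goes last
print(i.op, i.arg)                                  # -> UOps.DEFINE_VAR ('i', 0, 63)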
@@ -694,9 +690,6 @@ spec = PatternMatcher([
   (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
   (UPat(UOps.SPECIAL, src=()), lambda: True),
 
-  # no pyint allowed here!
-  (UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),
-
   # TODO: confirm the args of both of these are shapetrackers
   (UPat(UOps.VIEW, src=()), lambda: True),
   (UPat(UOps.VIEW, src=(UPat(),)), lambda: True),

@@ -82,7 +82,7 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
     offs -= here * stride
   return result
 
-def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
+def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x
 
 @dataclass(frozen=True)
 class View:

@@ -93,7 +93,7 @@ class View:
   contiguous:bool
 
   def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
-    idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
+    idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
     iexpr = variable_to_uop(self.offset)
     for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
       if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st