From 85a45164fb89fcba9032b95682184efa09a3fd4e Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sat, 12 Oct 2024 22:36:24 +0800 Subject: [PATCH] remove pyint [pr] (#7016) * remove pyint * bump time on tp [pr] * dont truncate in const fold * remove dead code * Revert "dont truncate in const fold" This reverts commit 29c81db0f7880848b001c2728aa555a1ef17e7d3. * remove define_var --- test/helpers.py | 1 - test/test_dtype.py | 4 +-- test/test_profiler.py | 2 +- test/test_uop_graph.py | 26 +++++++++--------- test/unit/test_graph_rewrite.py | 16 +++++------ test/unit/test_shapetracker.py | 2 +- test/unit/test_simplify_valid_idx.py | 2 +- test/unit/test_uop_resolve.py | 30 ++++++++++----------- test/unit/test_uop_vmin_vmax.py | 40 ++++++++++++++-------------- tinygrad/codegen/lowerer.py | 10 +++---- tinygrad/codegen/uopgraph.py | 10 ++----- tinygrad/dtype.py | 3 +-- tinygrad/ops.py | 13 +++------ tinygrad/shape/view.py | 4 +-- 14 files changed, 74 insertions(+), 89 deletions(-) diff --git a/test/helpers.py b/test/helpers.py index 9fbed0c1..ac566493 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -32,7 +32,6 @@ def assert_jit_cache_len(fxn, expected_len): assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT): - if dtype == dtypes.pyint and device != "PYTHON": return False if dtype == dtypes.bfloat16: # NOTE: this requires bf16 buffer support return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX")) diff --git a/test/test_dtype.py b/test/test_dtype.py index 49506a62..67082a0e 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -12,7 +12,7 @@ from test.helpers import is_dtype_supported, rand_for_dtype settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False)) settings.load_profile("my_profile") -core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint']) +core_dtypes = list(DTYPES_DICT.values()) if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)] dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)] @@ -20,7 +20,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_sup def get_available_cast_dtypes(dtype: DType) -> List[DType]: if not is_dtype_supported(dtype): return [] # dont cast internal dtypes - return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint'] + return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")] def _test_to_np(a:Tensor, np_dtype, target): if DEBUG >= 2: print(a) diff --git a/test/test_profiler.py b/test/test_profiler.py index f06a5008..bd2deffa 100644 --- a/test/test_profiler.py +++ b/test/test_profiler.py @@ -147,7 +147,7 @@ class TestProfiler(unittest.TestCase): transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0] helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA") - assert 80 < transfer_node_1['dur'] < (16000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}" + assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}" @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts") def test_profile_deps(self): diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index f3bdac5b..741ca204 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -172,10 +172,10 @@ class TestGraphRewrite(unittest.TestCase): self.assertEqual(nout.src[1].arg, 3.0) def test_consts_go_last(self): - a = UOp.define_var('a', dtypes.int, 0, 1) - b = UOp.define_var('b', dtypes.int, 0, 1) - c = UOp.define_var('c', dtypes.int, 0, 1) - d = UOp.define_var('d', dtypes.int, 0, 1) + a = UOp.variable('a', 0, 1) + b = UOp.variable('b', 0, 1) + c = UOp.variable('c', 0, 1) + d = UOp.variable('d', 0, 1) outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b] for out in outs: sink = graph_rewrite(out, sym) @@ -196,7 +196,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(out.arg, 3.0) def test_where_same_fold(self): - v = UOp.define_var('tmp', dtypes.int, 0, 1) + v = UOp.variable('tmp', 0, 1) c0 = UOp(UOps.CONST, dtypes.int, arg=0) vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE) c1 = UOp(UOps.CONST, dtypes.float, arg=1.0) @@ -290,7 +290,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) - acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1) + acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -299,7 +299,7 @@ class TestUOpGraph(unittest.TestCase): for i in [2, 4, 8]: var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i)) vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i))) - acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1) + acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i)) wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc)) uops = to_uops_list([wmma]) assert_equiv_uops(uops[0], acc) @@ -366,7 +366,7 @@ class TestUOpGraph(unittest.TestCase): self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1) def test_depth_2_const_fold(self): - v = UOp.define_var("tmp", dtypes.int, 0, 1) + v = UOp.variable("tmp", 0, 1) c2 = UOp(UOps.CONST, dtypes.int, arg=2) c4 = UOp(UOps.CONST, dtypes.int, arg=4) vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD) @@ -620,8 +620,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_load_dont_fold_different_gated(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp.define_var("g1", dtypes.bool, False, True) - gate2 = UOp.define_var("g2", dtypes.bool, False, True) + gate = UOp.variable("g1", False, True, dtypes.bool) + gate2 = UOp.variable("g2", False, True, dtypes.bool) load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load)) sink = float4_rewrite(sink) @@ -636,7 +636,7 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_fold_gate(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp.define_var("g1", dtypes.bool, False, True) + gate = UOp.variable("g1", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) @@ -647,8 +647,8 @@ class TestLoadStoreFolder(unittest.TestCase): def test_simple_store_dont_fold(self): buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float)) - gate = UOp.define_var("g1", dtypes.bool, False, True) - gate2 = UOp.define_var("g2", dtypes.bool, False, True) + gate = UOp.variable("g1", False, True, dtypes.bool) + gate2 = UOp.variable("g2", False, True, dtypes.bool) load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)] sink = UOp(UOps.SINK, dtypes.void, tuple(load)) sink = float4_rewrite(sink) diff --git a/test/unit/test_graph_rewrite.py b/test/unit/test_graph_rewrite.py index b5221f7e..4a92572f 100644 --- a/test/unit/test_graph_rewrite.py +++ b/test/unit/test_graph_rewrite.py @@ -94,20 +94,20 @@ class TestFoldingAndReduction(unittest.TestCase): class TestModuloAndDivisionFolding(unittest.TestCase): def test_full_graph_rewrite_modulo_folding_with_define_var(self): - x_var_uop = UOp.define_var('x', dtypes.int32, 0, 100) + x_var_uop = UOp.variable('x', 0, 100) optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4) self.assertEqual(optimized_mod_uop.op, UOps.CONST) self.assertEqual(optimized_mod_uop.arg, 2) def test_full_graph_rewrite_division_folding_with_define_var(self): - n_var_uop = UOp.define_var('n', dtypes.int32, 1, 1000) + n_var_uop = UOp.variable('n', 1, 1000) optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3) self.assertEqual(optimized_div_uop.op, UOps.ALU) self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL) self.assertEqual(optimized_div_uop.src[1].arg, 2) def test_full_graph_rewrite_complex_mod_div_folding(self): - k_var_uop = UOp.define_var('k', dtypes.int32, 0, 50) + k_var_uop = UOp.variable('k', 0, 50) optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2) self.assertEqual(optimized_div_uop.op, UOps.CONST) self.assertEqual(optimized_div_uop.arg, 1) @@ -124,17 +124,17 @@ class TestModuloAndDivisionFolding(unittest.TestCase): if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src)) def test_full_graph_rewrite_modulo_large_divisor(self): - x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5) + x_var_uop = UOp.variable('x', 1, 5) self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop) def test_full_graph_rewrite_division_with_remainder(self): - x_var_uop = UOp.define_var('x', dtypes.int32, 7, 9) + x_var_uop = UOp.variable('x', 7, 9) optimized_sink = apply_rewrite(x_var_uop // 2) for x_value in range(7, 10): self.assertEqual(x_value // 2, evaluate_uop(optimized_sink, {'x': x_value})) def test_full_graph_rewrite_complex_mod_div_expression(self): - x_var_uop = UOp.define_var('x', dtypes.int32, 1, 10) + x_var_uop = UOp.variable('x', 1, 10) optimized_sink = apply_rewrite(((x_var_uop * 5) % 3) // 2) for x_value in range(1, 11): original_result = ((x_value * 5) % 3) // 2 @@ -152,14 +152,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase): @unittest.skip("broken") def test_full_graph_rewrite_modulo_negative_dividend(self): - x_var_uop = UOp.define_var('x', dtypes.int32, -5, -1) + x_var_uop = UOp.variable('x', -5, -1) optimized_sink = full_graph_rewrite((x_var_uop % 3).sink()) for x_value in range(-5, 0): self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value})) @unittest.skip("broken") def test_full_graph_rewrite_division_negative_divisor(self): - x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5) + x_var_uop = UOp.variable('x', 1, 5) optimized_sink = full_graph_rewrite((x_var_uop // -2).sink()) for x_value in range(1, 6): self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value})) diff --git a/test/unit/test_shapetracker.py b/test/unit/test_shapetracker.py index 5bd6927b..0da63f62 100644 --- a/test/unit/test_shapetracker.py +++ b/test/unit/test_shapetracker.py @@ -11,7 +11,7 @@ from tinygrad.codegen.uopgraph import sym from itertools import product def shapetracker_getitem(st:ShapeTracker, val:int): - idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)]) + idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)]) idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym) assert idx.op is UOps.CONST and valid.op is UOps.CONST return idx.arg, valid.arg diff --git a/test/unit/test_simplify_valid_idx.py b/test/unit/test_simplify_valid_idx.py index 85c0af81..266d184d 100644 --- a/test/unit/test_simplify_valid_idx.py +++ b/test/unit/test_simplify_valid_idx.py @@ -32,7 +32,7 @@ def render(uop:UOp) -> str: return fxn.split("val0 = ")[1].split(";")[0] def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax)) -def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax) +def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax) def Range(n, nmax): return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),)) diff --git a/test/unit/test_uop_resolve.py b/test/unit/test_uop_resolve.py index 3f84d683..81b63475 100644 --- a/test/unit/test_uop_resolve.py +++ b/test/unit/test_uop_resolve.py @@ -44,7 +44,7 @@ class TestUOpResolve(unittest.TestCase): self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32) def test_ambiguous_less_than(self): - u = UOp.define_var("i", dtypes.pyint, 1, 10) + u = UOp.variable("i", 1, 10) self.assertTrue(resolve(u < 4)) self.assertFalse(resolve(u < 4, False)) self.assertTrue(resolve(u < 11, False)) @@ -56,64 +56,64 @@ class TestUOpResolve(unittest.TestCase): self.assertEqual(float(u), 11.5) def test_var_cmp_t(self): - u = UOp.define_var("i", dtypes.pyint, 1, 10) < 20 + u = UOp.variable("i", 1, 10) < 20 self.assertTrue(u) def test_var_cmp_t2(self): - u = UOp.define_var("i", dtypes.pyint, 1, 10)//2 < 20 + u = UOp.variable("i", 1, 10)//2 < 20 self.assertTrue(u) def test_var_cmp_f(self): - u = UOp.define_var("i", dtypes.pyint, 1, 10) < 1 + u = UOp.variable("i", 1, 10) < 1 self.assertFalse(u) def test_var_cmp_f2(self): - u = UOp.define_var("i", dtypes.pyint, 1, 10) > 11 + u = UOp.variable("i", 1, 10) > 11 self.assertFalse(u) def test_or_true(self): - u = UOp.define_var("b", dtypes.bool, False, True) | True + u = UOp.variable("b", False, True, dtypes.bool) | True self.assertTrue(u) def test_or_false(self): with self.assertRaises(ValueError): - u = UOp.define_var("b", dtypes.bool, False, True) | False + u = UOp.variable("b", False, True, dtypes.bool) | False self.assertTrue(u) def test_and_false(self): - u = UOp.define_var("b", dtypes.bool, False, True) & False + u = UOp.variable("b", False, True, dtypes.bool) & False self.assertFalse(u) def test_max(self): - x = UOp.define_var("x", dtypes.pyint, 1, 10) - y = UOp.define_var("y", dtypes.pyint, 5, 10) + x = UOp.variable("x", 1, 10) + y = UOp.variable("y", 5, 10) u = x.max(y) self.assertTrue(u < 20) self.assertFalse(u < 3) def test_x_lt_x(self): - x = UOp.define_var("i", dtypes.pyint, 1, 10) + x = UOp.variable("i", 1, 10) self.assertFalse(x < x) @unittest.expectedFailure def test_x_lt_xp1(self): - x = UOp.define_var("i", dtypes.pyint, 1, 10) + x = UOp.variable("i", 1, 10) self.assertTrue(x < (x+1)) def test_and_true(self): with self.assertRaises(ValueError): - u = UOp.define_var("b", dtypes.bool, False, True) & True + u = UOp.variable("b", False, True, dtypes.bool) & True self.assertFalse(u) @unittest.expectedFailure def test_var_cmp_range(self): - v = UOp.define_var("i", dtypes.pyint, 1, 10) + v = UOp.variable("i", 1, 10) u = (v > 4) | (v < 6) self.assertTrue(u) def test_var_cmp_assert(self): with self.assertRaises(ValueError): - u = UOp.define_var("i", dtypes.pyint, 1, 10) < 5 + u = UOp.variable("i", 1, 10) < 5 self.assertFalse(u) if __name__ == '__main__': diff --git a/test/unit/test_uop_vmin_vmax.py b/test/unit/test_uop_vmin_vmax.py index 6728e2d9..fbac2895 100644 --- a/test/unit/test_uop_vmin_vmax.py +++ b/test/unit/test_uop_vmin_vmax.py @@ -19,36 +19,36 @@ class TestVminVmaxProperties(unittest.TestCase): def test_vmin_vmax_addition_with_variable(self): # vmin and vmax for addition with a variable - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x + 5 self.assertEqual(uop.vmin, 15) self.assertEqual(uop.vmax, 25) def test_vmin_vmax_multiplication_with_variable(self): # vmin and vmax for multiplication with a variable - x = UOp.define_var('x', dtypes.int32, -3, 4) + x = UOp.variable('x', -3, 4) uop = x * 2 self.assertEqual(uop.vmin, -6) self.assertEqual(uop.vmax, 8) def test_vmin_vmax_with_negative_multiplication(self): # vmin and vmax when multiplying by a negative number - x = UOp.define_var('x', dtypes.int32, 2, 5) + x = UOp.variable('x', 2, 5) uop = x * -3 self.assertEqual(uop.vmin, -15) self.assertEqual(uop.vmax, -6) def test_vmin_vmax_nested_min_max(self): # vmin and vmax with nested min/max operations - x = UOp.define_var('x', dtypes.int32, 0, 10) + x = UOp.variable('x', 0, 10) uop = x.max(5).min(8) self.assertEqual(uop.vmin, 5) self.assertEqual(uop.vmax, 8) def test_vmin_vmax_where(self): - x = UOp.define_var('x', dtypes.int, 0, 10) - y = UOp.define_var('y', dtypes.int, 1, 11) - z = UOp.define_var('z', dtypes.int, 2, 12) + x = UOp.variable('x', 0, 10) + y = UOp.variable('y', 1, 11) + z = UOp.variable('z', 2, 12) uop = x.lt(5).where(y, z) self.assertEqual(uop.vmin, 1) self.assertEqual(uop.vmax, 12) @@ -56,21 +56,21 @@ class TestVminVmaxProperties(unittest.TestCase): class TestVminVmaxDivMod(unittest.TestCase): def test_vmin_vmax_division_positive(self): # vmin and vmax for division of a variable by a positive constant - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x // 2 self.assertEqual(uop.vmin, 5) self.assertEqual(uop.vmax, 10) def test_vmin_vmax_division_negative(self): # vmin and vmax for division of a variable by a negative constant - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x // -2 self.assertEqual(uop.vmin, -10) self.assertEqual(uop.vmax, -5) def test_vmin_vmax_mod_positive(self): # vmin and vmax for modulo of a variable by a positive constant - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x % 3 self.assertEqual(uop.vmin, 0) self.assertEqual(uop.vmax, 2) @@ -78,21 +78,21 @@ class TestVminVmaxDivMod(unittest.TestCase): @unittest.skip("broken") def test_vmin_vmax_mod_negative(self): # vmin and vmax for modulo of a variable by a negative constant - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x % -3 self.assertEqual(uop.vmin, -2) self.assertEqual(uop.vmax, 0) def test_vmin_vmax_division_with_mixed_range(self): # vmin and vmax for division of a variable with a range crossing zero - x = UOp.define_var('x', dtypes.int32, -10, 10) + x = UOp.variable('x', -10, 10) uop = x // 3 self.assertEqual(uop.vmin, -4) # -10//3 = -4 self.assertEqual(uop.vmax, 3) # 10//3 = 3 def test_vmin_vmax_mod_with_mixed_range(self): # vmin and vmax for modulo of a variable with a range crossing zero - x = UOp.define_var('x', dtypes.int32, -10, 10) + x = UOp.variable('x', -10, 10) uop = x % 4 self.assertEqual(uop.vmin, 0) # modulo always positive or zero when divisor is positive self.assertEqual(uop.vmax, 3) # max possible mod is 3 when dividing by 4 @@ -146,26 +146,26 @@ class TestConstFactor(unittest.TestCase): def test_const_factor_with_variable(self): # const_factor for an expression involving a variable - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x * 3 self.assertEqual(uop.const_factor(), 3) def test_const_factor_division(self): # const_factor for an expression with division - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x // 4 self.assertEqual(uop.const_factor(), 1) # Division reduces the const_factor to 1 def test_const_factor_multiplication_of_var_and_const(self): # const_factor for multiplication of a variable and a constant - x = UOp.define_var('x', dtypes.int32, 6, 18) + x = UOp.variable('x', 6, 18) uop = x * 4 self.assertEqual(uop.const_factor(), 4) # Constant factor 4 @unittest.skip("broken") def test_const_factor_multiplication_of_consts_and_vars(self): # Multiplying constants and variables - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = (x * 3) * 5 self.assertEqual(uop.const_factor(), 15) # Constant multipliers are combined (3 * 5 = 15) @@ -186,7 +186,7 @@ class TestDivides(unittest.TestCase): @unittest.skip("broken") def test_divides_variable_and_constant(self): # Multiplying a variable by a constant, then dividing by the same constant - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = x * 6 result = uop.divides(6) self.assertIsNotNone(result) @@ -194,7 +194,7 @@ class TestDivides(unittest.TestCase): def test_divides_complex_expression(self): # Dividing a more complex expression - x = UOp.define_var('x', dtypes.int32, 10, 20) + x = UOp.variable('x', 10, 20) uop = (x * 6) + 18 result = uop.divides(6) self.assertIsNotNone(result) @@ -202,7 +202,7 @@ class TestDivides(unittest.TestCase): def test_divides_with_inexact_factors(self): # Multiplying by a constant but dividing by a non-exact divisor - x = UOp.define_var('x', dtypes.int32, 15, 45) + x = UOp.variable('x', 15, 45) uop = x * 4 result = uop.divides(3) self.assertIsNone(result) # Cannot divide by 3, since 4 is not divisible by 3 diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index 96a1dbb3..59e009b6 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]): def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]: if reverse: dims = dims[::-1] limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims - ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)] + ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)] if limited != dims: ret = [] # cast for mypy, get_contraction won't be None @@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext: get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max) else: # all loops are RANGES - idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False)) + idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False)) for i,g in enumerate(full_shape[:first_reduce])] # reduce loops - idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True)) + idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True)) for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)] # upcast loops for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted): assert isinstance(g, int), "needs to be int to upcast/unroll" - idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),))) + idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),))) # late indexes (group for reduce) ridxs = idxs[:] for a in range(first_reduce, first_reduce+group_for_reduces): - ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True)) + ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True)) return IndexContext(idxs, ridxs) diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 6d6d3ce7..ef9977c5 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -118,9 +118,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]: candidates = [] if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)]) + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)]) # try checking the whole clause - candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))]) + candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))]) for candidate in candidates: newidxs:List[List[UOp]] = [[], []] @@ -538,9 +538,6 @@ reducer = PatternMatcher([ (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load), ]) -no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR), - name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)]) - # *** uop graph *** linearize_cnt = 0 @@ -552,9 +549,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp: acc_number = 0 sink = graph_rewrite(sink, sym) - # rewrite pyint to int32 - sink = graph_rewrite(sink, no_pyint) - # expand linearize_cnt += 1 if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1: diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py index 10707e40..8a2adb18 100644 --- a/tinygrad/dtype.py +++ b/tinygrad/dtype.py @@ -86,7 +86,6 @@ class dtypes: def fields() -> Dict[str, DType]: return DTYPES_DICT # TODO: priority should be higher than bool void: Final[DType] = DType(-1, 0, "void", None, 1) - pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works bool: Final[DType] = DType(0, 1, "bool", '?', 1) int8: Final[DType] = DType(1, 1, "char", 'b', 1) uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1) @@ -118,7 +117,7 @@ class dtypes: floats = (float16, bfloat16, float32, float64) uints = (uint8, uint16, uint32, uint64) - sints = (int8, int16, int32, int64, pyint) + sints = (int8, int16, int32, int64) ints = uints + sints if (env_default_float := getenv("DEFAULT_FLOAT", "")): diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 10c95cae..cbf0cf0c 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -282,11 +282,8 @@ class UOp(MathTrait): out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(UOps.ALU, out_dtype, (self,)+src, arg) @staticmethod - def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b) - @staticmethod - def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): + def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): # TODO: fix dtype of b.max after Variable is just an UOp - #if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max)) if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore @@ -299,9 +296,8 @@ class UOp(MathTrait): # *** uop Variable stuff *** @staticmethod - def variable(name:str, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(name, min_val, max_val)) - @staticmethod - def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) + def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int): + return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val)) @property def expr(self): assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR" @@ -694,9 +690,6 @@ spec = PatternMatcher([ (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype), (UPat(UOps.SPECIAL, src=()), lambda: True), - # no pyint allowed here! - (UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False), - # TODO: confirm the args of both of these are shapetrackers (UPat(UOps.VIEW, src=()), lambda: True), (UPat(UOps.VIEW, src=(UPat(),)), lambda: True), diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index eea20ea4..0fe9a78e 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -82,7 +82,7 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]: offs -= here * stride return result -def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x +def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x @dataclass(frozen=True) class View: @@ -93,7 +93,7 @@ class View: contiguous:bool def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]: - idxs = [UOp.range(dtypes.pyint, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs + idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs iexpr = variable_to_uop(self.offset) for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)): if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st