diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 3b9a2979..4e28c853 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -19,6 +19,14 @@ simple_pm = PatternMatcher([ def to_uops_list(u:List[UOp]) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u))) class TestGraphRewriteEfficiency(unittest.TestCase): + def test_create_many_uops(self): + c1 = UOp.const(dtypes.int, 1) + c2 = UOp.const(dtypes.int, 2) + st = time.perf_counter() + uops = [UOp(UOps.ALU, dtypes.int, (c1, c2), BinaryOps.ADD) for _ in range(10000)] + et = time.perf_counter() - st + print(f"created {len(uops)} uops in {et*1000:.2f} ms") + def test_expand_rewrite(self): sink = UOp(UOps.SINK, None, arg=KernelInfo(local_dims=2, upcasted=4, dont_use_locals=False), src=( UOp(UOps.STORE, None, arg=None, src=(