From b918e3c255851c805e730407dc89c37ef839ff7e Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Sun, 11 Aug 2024 17:17:05 +0800 Subject: [PATCH] cache assert_equiv_uops (#6033) --- test/helpers.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/test/helpers.py b/test/helpers.py index f972dc20..92a826b6 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -1,4 +1,5 @@ import sys, unittest +from typing import Optional, Set, Tuple import numpy as np from tinygrad import Tensor, Device, dtypes from tinygrad.codegen.uops import UOp @@ -54,7 +55,10 @@ def rand_for_dtype(dt:DType, size:int): return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt)) class TestUOps(unittest.TestCase): - def assert_equiv_uops(self, uop1:UOp, uop2:UOp): + def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None): + if cache is None: cache = set() + if (uop1, uop2) in cache: return + cache.add((uop1, uop2)) # NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops try: self.assertIs(uop1.op, uop2.op) @@ -65,4 +69,4 @@ class TestUOps(unittest.TestCase): except AssertionError as e: print(f"{uop1=}") print(f"{uop2=}") - raise e \ No newline at end of file + raise e