cache assert_equiv_uops (#6033)

This commit is contained in:
qazal 2024-08-11 17:17:05 +08:00 committed by GitHub
parent 8d108f65a4
commit b918e3c255
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 2 deletions

View File

@ -1,4 +1,5 @@
import sys, unittest
from typing import Optional, Set, Tuple
import numpy as np
from tinygrad import Tensor, Device, dtypes
from tinygrad.codegen.uops import UOp
@ -54,7 +55,10 @@ def rand_for_dtype(dt:DType, size:int):
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
class TestUOps(unittest.TestCase):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
if cache is None: cache = set()
if (uop1, uop2) in cache: return
cache.add((uop1, uop2))
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
try:
self.assertIs(uop1.op, uop2.op)
@ -65,4 +69,4 @@ class TestUOps(unittest.TestCase):
except AssertionError as e:
print(f"{uop1=}")
print(f"{uop2=}")
raise e
raise e