mirror of https://github.com/commaai/tinygrad.git
cache assert_equiv_uops (#6033)
This commit is contained in:
parent
8d108f65a4
commit
b918e3c255
|
@ -1,4 +1,5 @@
|
|||
import sys, unittest
|
||||
from typing import Optional, Set, Tuple
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.codegen.uops import UOp
|
||||
|
@ -54,7 +55,10 @@ def rand_for_dtype(dt:DType, size:int):
|
|||
return np.random.uniform(-10, 10, size=size).astype(_to_np_dtype(dt))
|
||||
|
||||
class TestUOps(unittest.TestCase):
|
||||
def assert_equiv_uops(self, uop1:UOp, uop2:UOp):
|
||||
def assert_equiv_uops(self, uop1:UOp, uop2:UOp, cache:Optional[Set[Tuple[UOp, UOp]]]=None):
|
||||
if cache is None: cache = set()
|
||||
if (uop1, uop2) in cache: return
|
||||
cache.add((uop1, uop2))
|
||||
# NOTE: direct UOps __eq__ is comparing object reference, use this function to compare two uops
|
||||
try:
|
||||
self.assertIs(uop1.op, uop2.op)
|
||||
|
@ -65,4 +69,4 @@ class TestUOps(unittest.TestCase):
|
|||
except AssertionError as e:
|
||||
print(f"{uop1=}")
|
||||
print(f"{uop2=}")
|
||||
raise e
|
||||
raise e
|
||||
|
|
Loading…
Reference in New Issue