match dataclass.replace in UOp.replace [run_process_replay] (#6792)

* UOp replace matching dataclass replace

* p2

* replace creates a copy
This commit is contained in:
qazal 2024-09-28 16:28:49 +08:00 committed by GitHub
parent 494b20e886
commit dab05ff070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 2 deletions

View File

@ -374,6 +374,11 @@ class TestUOpMethod(unittest.TestCase):
self.assertEqual((gidx0*3+6).const_factor(), 3)
self.assertEqual((gidx0*3+1).const_factor(), 1)
def test_replace(self):
x = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.void), (), 0)
self.assertIs(x.replace(arg=None).arg, None)
with self.assertRaises(AssertionError): x.replace(field="a")
class TestUOpStr(unittest.TestCase):
def test_uop_str(self):
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)

View File

@ -144,8 +144,9 @@ class UOp(MathTrait):
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
def replace(self, **kwargs) -> UOp:
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
return UOp(kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}
@functools.cached_property