mirror of https://github.com/commaai/tinygrad.git
match dataclass.replace in UOp.replace [run_process_replay] (#6792)
* UOp replace matching dataclass replace * p2 * replace creates a copy
This commit is contained in:
parent
494b20e886
commit
dab05ff070
|
@ -374,6 +374,11 @@ class TestUOpMethod(unittest.TestCase):
|
|||
self.assertEqual((gidx0*3+6).const_factor(), 3)
|
||||
self.assertEqual((gidx0*3+1).const_factor(), 1)
|
||||
|
||||
def test_replace(self):
|
||||
x = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.void), (), 0)
|
||||
self.assertIs(x.replace(arg=None).arg, None)
|
||||
with self.assertRaises(AssertionError): x.replace(field="a")
|
||||
|
||||
class TestUOpStr(unittest.TestCase):
|
||||
def test_uop_str(self):
|
||||
a = UOp(UOps.CONST, dtypes.float, (), 2.0) + UOp(UOps.CONST, dtypes.float, (), 3.0)
|
||||
|
|
|
@ -144,8 +144,9 @@ class UOp(MathTrait):
|
|||
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])
|
||||
#if op is UOps.CAST: assert dtype.count == src[0].dtype.count, f"cast can't change vectorization {src[0].dtype} --> {dtype}"
|
||||
self.op, self.dtype, self.src, self.arg = op, dtype, src, arg
|
||||
def replace(self, op: Optional[UOps]=None, dtype:Optional[DType]=None, src: Optional[Tuple[UOp,...]]=None, arg:Any=None):
|
||||
return UOp(op or self.op, dtype or self.dtype, self.src if src is None else src, self.arg if arg is None else arg)
|
||||
def replace(self, **kwargs) -> UOp:
|
||||
for k in kwargs: assert k in self.__slots__, f"unkown replace arg, expected one of {self.__slots__}, got {k}"
|
||||
return UOp(kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.CONST, UOps.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
|
|
Loading…
Reference in New Issue