From 671259417f4394393ca2bcd62d1e20f4b4732d15 Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 26 Jul 2024 16:59:55 -0400 Subject: [PATCH] reuse UOp `__repr__` for NOp (#5738) --- test/test_uops.py | 6 +++++- tinygrad/codegen/uops.py | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/test/test_uops.py b/test/test_uops.py index dd0a9bca..6bd24844 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -9,7 +9,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu # from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel -from tinygrad.codegen.uops import UOps, UOp +from tinygrad.codegen.uops import UOps, NOp, UOp from tinygrad.codegen.uopgraph import UOpGraph from test.helpers import is_dtype_supported, TestUOps as TestEqUOps @@ -363,5 +363,9 @@ class TestUOpStr(TestEqUOps): sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink self.assert_equiv_uops(sink, eval(str(sink))) + def test_nop_str(self): + a = NOp(UOps.CONST, dtypes.float, (), 2.0, varname="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, varname="c1") + assert str(eval(str(a))) == str(a) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index b9a42581..fea7b9ff 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -43,7 +43,7 @@ class UOp: return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ self.arg.value, self.dtype, self.src) def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple - def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") + def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") # *** uop syntactic sugar def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) @@ -111,11 +111,10 @@ class UOp: return self.const(self.src[0].vmin.arg+self.src[1].vmin.arg), self.const(self.src[0].vmax.arg+self.src[1].vmax.arg) return None, None -@dataclass(frozen=True) +@dataclass(frozen=True, repr=False) # reuse repr from UOp class NOp(UOp): varname:Optional[str] = None src:Tuple[NOp, ...] = tuple() - def __repr__(self): return pretty_print(self, lambda x: f"NOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") def name(self, name:Optional[str]=None): return NOp(self.op, self.dtype, self.src, self.arg, varname=name) @staticmethod def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, varname=name)