reuse UOp `__repr__` for NOp (#5738)

This commit is contained in:
chenyu 2024-07-26 16:59:55 -04:00 committed by GitHub
parent b0c1dba299
commit 671259417f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 4 deletions

View File

@ -9,7 +9,7 @@ from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, exec_alu #
from tinygrad.renderer import Program from tinygrad.renderer import Program
from tinygrad.engine.schedule import create_schedule from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
from tinygrad.codegen.uops import UOps, UOp from tinygrad.codegen.uops import UOps, NOp, UOp
from tinygrad.codegen.uopgraph import UOpGraph from tinygrad.codegen.uopgraph import UOpGraph
from test.helpers import is_dtype_supported, TestUOps as TestEqUOps from test.helpers import is_dtype_supported, TestUOps as TestEqUOps
@ -363,5 +363,9 @@ class TestUOpStr(TestEqUOps):
sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink sink = get_kernel(Device[Device.DEFAULT].renderer, t.schedule()[-1].ast).linearize().uops.sink
self.assert_equiv_uops(sink, eval(str(sink))) self.assert_equiv_uops(sink, eval(str(sink)))
def test_nop_str(self):
a = NOp(UOps.CONST, dtypes.float, (), 2.0, varname="c0") + NOp(UOps.CONST, dtypes.float, (), 3.0, varname="c1")
assert str(eval(str(a))) == str(a)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main(verbosity=2) unittest.main(verbosity=2)

View File

@ -43,7 +43,7 @@ class UOp:
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \ return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src) self.arg.value, self.dtype, self.src)
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))") def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
# *** uop syntactic sugar # *** uop syntactic sugar
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,)) def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
@ -111,11 +111,10 @@ class UOp:
return self.const(self.src[0].vmin.arg+self.src[1].vmin.arg), self.const(self.src[0].vmax.arg+self.src[1].vmax.arg) return self.const(self.src[0].vmin.arg+self.src[1].vmin.arg), self.const(self.src[0].vmax.arg+self.src[1].vmax.arg)
return None, None return None, None
@dataclass(frozen=True) @dataclass(frozen=True, repr=False) # reuse repr from UOp
class NOp(UOp): class NOp(UOp):
varname:Optional[str] = None varname:Optional[str] = None
src:Tuple[NOp, ...] = tuple() src:Tuple[NOp, ...] = tuple()
def __repr__(self): return pretty_print(self, lambda x: f"NOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
def name(self, name:Optional[str]=None): return NOp(self.op, self.dtype, self.src, self.arg, varname=name) def name(self, name:Optional[str]=None): return NOp(self.op, self.dtype, self.src, self.arg, varname=name)
@staticmethod @staticmethod
def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, varname=name) def var(name:Optional[str]=None, dtype:Optional[DType]=None): return NOp(UOps.NOOP, dtype=dtype, varname=name)