remove UnaryOps.NEG from lazy.py (#6193)

* remove UnaryOps.NEG from lazy.py

* neg is no longer unary
This commit is contained in:
chenyu 2024-08-19 18:41:28 -04:00 committed by GitHub
parent 4d1b5781b5
commit 21d6739237
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 7 deletions

View File

@ -23,6 +23,7 @@ class TestUnaryOpsConstFolding(unittest.TestCase):
_check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
_check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
@unittest.expectedFailure # no two level fold at lazybuffer
def test_neg_folding(self):
_check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
_check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))

View File

@ -27,7 +27,7 @@ if Device.DEFAULT == "LLVM":
integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
(Tensor.bitwise_or, np.bitwise_or)]
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin),
(Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
# TODO: enable this (this is a dtype issue)

View File

@ -145,24 +145,20 @@ class LazyBuffer:
raise AssertionError(f"all dtypes must match {dts} on {op}")
assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
# const folding
if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
if op in BinaryOps:
x, y = self, in_srcs[0]
if op is BinaryOps.ADD:
if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
if op is BinaryOps.MUL:
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const(0)
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const(0)
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))