mirror of https://github.com/commaai/tinygrad.git
remove UnaryOps.NEG from lazy.py (#6193)
* remove UnaryOps.NEG from lazy.py
* neg is no longer unary
This commit is contained in:
parent 4d1b5781b5
commit 21d6739237
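The commit note "neg is no longer unary" means negation no longer has to reach the lazy layer as a dedicated UnaryOps.NEG node. A minimal sketch of the user-visible equivalence, using only tinygrad's public Tensor API (illustrative, not taken from this commit):

```python
# Minimal sketch: a multiply by -1 is numerically the same as Tensor.neg;
# this commit drops the special NEG handling in lazy.py, so such a multiply
# is no longer rewritten into a UnaryOps.NEG.
from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0])
assert (x * -1).tolist() == x.neg().tolist() == [-1.0, -2.0, -3.0]
```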
```diff
@@ -23,6 +23,7 @@ class TestUnaryOpsConstFolding(unittest.TestCase):
     _check_ast_count(0, Tensor.ones(4).cast(dtypes.int16))
     _check_ast_count(0, Tensor.full(4, fill_value=-1).cast(dtypes.uint16))
 
+  @unittest.expectedFailure # no two level fold at lazybuffer
   def test_neg_folding(self):
     _check_ast_count(0, Tensor([1, 2, 3]).mul(-1).neg())
     _check_ast_count(0, Tensor([1, 2, 3]).neg().mul(-1))
```
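The new expectedFailure marks that `Tensor([1, 2, 3]).mul(-1).neg()` no longer folds to zero kernels: with the `-1` multiply no longer rewritten to NEG at the lazy level, cancelling it against the following `.neg()` would take a fold across two ops, which the LazyBuffer does not attempt. A rough way to observe how much work an expression schedules, assuming tinygrad's `Tensor.schedule()` (exact counts vary by version and include host-to-device copies):

```python
# Sketch only: inspect the schedule length to see whether an expression folded.
from tinygrad import Tensor

folded   = Tensor([1, 2, 3]) * 1            # *1 still folds away lazily
unfolded = Tensor([1, 2, 3]).mul(-1).neg()  # would need a two-level fold now

print(len(folded.schedule()), len(unfolded.schedule()))
```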
```diff
@@ -27,7 +27,7 @@ if Device.DEFAULT == "LLVM":
 integer_binary_operations = binary_operations + [(Tensor.xor, np.bitwise_xor), (Tensor.bitwise_and, np.bitwise_and),
                                                  (Tensor.bitwise_or, np.bitwise_or)]
-unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), operator.neg, (Tensor.sin, np.sin),
+unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.sin),
                     (Tensor.sqrt, np.sqrt), (Tensor.reciprocal, np.reciprocal)]
 
 # TODO: enable this (this is a dtype issue)
```
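Each entry in these tables pairs a tinygrad op with its NumPy reference; dropping the bare `operator.neg` from `unary_operations` simply stops exercising negation through the unary path. A sketch of how such a pair can be checked, with a hypothetical helper name and tolerances (not from the test file):

```python
import numpy as np
from tinygrad import Tensor

def check_pair(tensor_op, numpy_op, data):
  # run the tinygrad op and its NumPy reference on the same data and compare
  got = tensor_op(Tensor(data)).numpy()
  want = numpy_op(np.array(data, dtype=np.float32))
  np.testing.assert_allclose(got, want, rtol=1e-5, atol=1e-6)

check_pair(Tensor.sqrt, np.sqrt, [1.0, 4.0, 9.0])
check_pair(Tensor.reciprocal, np.reciprocal, [1.0, 2.0, 4.0])
```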
```diff
@@ -145,24 +145,20 @@ class LazyBuffer:
       raise AssertionError(f"all dtypes must match {dts} on {op}")
     assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}"
     if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
-    if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
 
     out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype
 
     # const folding
     if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs):
       return self.cast(out_dtype).const(exec_alu(op, out_dtype, [s.base.arg for s in srcs]))
-    if op is UnaryOps.NEG and self.base.op is UnaryOps.NEG and self.base.realized is None: return self.base.srcs[0]
     if op in BinaryOps:
       x, y = self, in_srcs[0]
       if op is BinaryOps.ADD:
         if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x
         if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y
       if op is BinaryOps.MUL:
-        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0, -1):
-          return y if val == 1 else y.const(0) if val == 0 else y.e(UnaryOps.NEG)
-        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0, -1):
-          return x if val == 1 else x.const(0) if val == 0 else x.e(UnaryOps.NEG)
+        if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const(0)
+        if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const(0)
 
     return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
```
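After this hunk the MUL branch only folds multiplication by 1 and by 0; a `-1` constant, which previously rewrote the other operand to `y.e(UnaryOps.NEG)`, now falls through and is emitted as an ordinary multiply. A standalone sketch of the surviving rule (illustrative only, not the LazyBuffer implementation):

```python
# Toy version of the remaining MUL const fold: identity and annihilator only.
def fold_mul(other, const):
  if const == 1: return other   # x * 1 -> x
  if const == 0: return 0       # x * 0 -> const 0
  return None                   # no fold; a real BinaryOps.MUL is created
                                # (the old -1 -> UnaryOps.NEG rewrite is gone)

assert fold_mul("buf", 1) == "buf"
assert fold_mul("buf", 0) == 0
assert fold_mul("buf", -1) is None
```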