lazy const fold idiv 1 (#6285)

This commit is contained in:
chenyu 2024-08-26 10:29:59 -04:00 committed by GitHub
parent af7c04ff57
commit b76f0c875e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 0 deletions

View File

@ -79,6 +79,11 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
def test_div_tensor_one(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
def test_idiv_literal_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
def test_idiv_tensor_one(self):
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))
def test_pow_literal_zero(self):
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
def test_pow_tensor_zero(self):

View File

@ -159,6 +159,7 @@ class LazyBuffer:
if op is BinaryOps.MUL:
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const(0)
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const(0)
if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))