mirror of https://github.com/commaai/tinygrad.git
lazy const fold idiv 1 (#6285)
This commit is contained in:
parent
af7c04ff57
commit
b76f0c875e
|
@ -79,6 +79,11 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
|
|||
def test_div_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) / Tensor.ones(4))
|
||||
|
||||
def test_idiv_literal_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // 1)
|
||||
def test_idiv_tensor_one(self):
|
||||
_check_ast_count(0, Tensor([1, 2, 3, 4]) // Tensor.ones(4, dtype=dtypes.int32))
|
||||
|
||||
def test_pow_literal_zero(self):
|
||||
_check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 0)
|
||||
def test_pow_tensor_zero(self):
|
||||
|
|
|
@ -159,6 +159,7 @@ class LazyBuffer:
|
|||
if op is BinaryOps.MUL:
|
||||
if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const(0)
|
||||
if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const(0)
|
||||
if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x
|
||||
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
||||
|
||||
|
|
Loading…
Reference in New Issue