remove float division from idiv in python_alu (#5777)

* removes float division from idiv in python_alu

* add test

* cleaner logic

* pass clang unsigned literals correctly

* suffix ULL instead of U

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
samm393 2024-07-29 17:14:12 +01:00 committed by GitHub
parent 2c94316bd2
commit 573e0f9a48
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 6 additions and 1 deletions

View File

@ -4,6 +4,7 @@ import torch
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
from tinygrad import Tensor, Device, dtypes
from tinygrad.tensor import _to_np_dtype
import functools
if CI:
import warnings
@ -409,6 +410,10 @@ class TestOps(unittest.TestCase):
def test_div_int(self):
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
helper_test_op(None, lambda x: x/2, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), functools.partial(Tensor.div, upcast=False)
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
x = Tensor(2**64 - 1, dtype=dtypes.uint64).div(1, upcast=False)
np.testing.assert_equal(x.numpy(), 2**64 - 1)
def test_scalar_div(self):
helper_test_op([(45,65)], lambda x: x/255)
helper_test_op([(45,65)], lambda x: x/1)

View File

@ -113,7 +113,7 @@ python_alu: Dict[Op, Callable] = {
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
def truncate_fp16(x):