mirror of https://github.com/commaai/tinygrad.git
remove float division from idiv in python_alu (#5777)
* removes float division from idiv in python_alu * add test * cleaner logic * pass clang unsigned literals correctly * suffix ULL instead of U --------- Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
parent
2c94316bd2
commit
573e0f9a48
|
@ -4,6 +4,7 @@ import torch
|
|||
from tinygrad.helpers import getenv, IMAGE, DEBUG, CI
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
import functools
|
||||
|
||||
if CI:
|
||||
import warnings
|
||||
|
@ -409,6 +410,10 @@ class TestOps(unittest.TestCase):
|
|||
def test_div_int(self):
|
||||
helper_test_op(None, lambda x,y: x/y, Tensor.div, forward_only=True, vals=np.array([[5, 6, 7],[1, 2, 3]], dtype=np.int32))
|
||||
helper_test_op(None, lambda x: x/2, lambda x: x/2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
|
||||
torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), functools.partial(Tensor.div, upcast=False)
|
||||
helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
|
||||
x = Tensor(2**64 - 1, dtype=dtypes.uint64).div(1, upcast=False)
|
||||
np.testing.assert_equal(x.numpy(), 2**64 - 1)
|
||||
def test_scalar_div(self):
|
||||
helper_test_op([(45,65)], lambda x: x/255)
|
||||
helper_test_op([(45,65)], lambda x: x/1)
|
||||
|
|
|
@ -113,7 +113,7 @@ python_alu: Dict[Op, Callable] = {
|
|||
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add,
|
||||
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
|
||||
BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
|
||||
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf,
|
||||
TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
def truncate_fp16(x):
|
||||
|
|
Loading…
Reference in New Issue