Python div alu behavior differs slightly from others (#3596)

* Divide op rounding for negatives

* extra space

---------

Co-authored-by: Patrick Tsai <patosai@users.noreply.github.com>
This commit is contained in:
Patrick Tsai 2024-03-03 13:48:25 -05:00 committed by GitHub
parent 56d21d77b3
commit bc562c4747
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 11 additions and 1 deletions

View File

@ -113,6 +113,16 @@ class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.int, (0,)), 0)
def test_div(self):
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (8, 2)), 4)
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (7, 3)), 2)
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (7, -3)), -2)
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (-50, 6)), -8)
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (8.0, 2.0)), 4.0)
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
@unittest.skip("not enabled because it's slow")
def test_overflow(self):
self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)

View File

@ -37,7 +37,7 @@ python_alu = {
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin, UnaryOps.NEG: operator.neg,
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.XOR: operator.xor,
BinaryOps.MAX: max, BinaryOps.CMPEQ: operator.eq, BinaryOps.CMPLT: operator.lt, BinaryOps.MOD: operator.mod,
BinaryOps.DIV: lambda x,y: x//y if isinstance(x, int) else (x/y if y != 0 else math.nan),
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else math.nan),
TernaryOps.WHERE: lambda x,y,z: y if x else z}
def exec_alu(arg, dtype, p):