add CMPNE tests in test_uops (#6196)

fixed the output_dtype for CMPNE and match the tests for CMPLT
This commit is contained in:
chenyu 2024-08-19 19:41:21 -04:00 committed by GitHub
parent 21d6739237
commit 10330a41c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 11 additions and 3 deletions

View File

@ -25,7 +25,7 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar
def _test_single_value(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)]
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
@ -41,7 +41,7 @@ def _test_single_value(vals, op, dts):
def _test_single_value_const(vals, op, dts):
uops = []
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
@ -112,6 +112,7 @@ class TestFloatUOps(TestUOps):
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
def test_cmpne(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: a!=b)
# MOD isn't tested on floats
def test_where(self):
@ -136,7 +137,8 @@ class TestNonFloatUOps(TestUOps):
def test_mod_int32(self):
self._test_bop_fxn(BinaryOps.MOD,
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (dtypes.int32, dtypes.int32))
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
def test_cmpne_int32(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
@ -205,6 +207,12 @@ class TestExecALU(TestUOps):
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
def test_bool_cmpne(self):
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, False)), False)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, True)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, False)), True)
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, True)), False)
def test_bool_where(self):
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)