mirror of https://github.com/commaai/tinygrad.git
add CMPNE tests in test_uops (#6196)
fixed the output_dtype for CMPNE and match the tests for CMPLT
This commit is contained in:
parent
21d6739237
commit
10330a41c7
|
@ -25,7 +25,7 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar
|
|||
|
||||
def _test_single_value(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
|
@ -41,7 +41,7 @@ def _test_single_value(vals, op, dts):
|
|||
|
||||
def _test_single_value_const(vals, op, dts):
|
||||
uops = []
|
||||
output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0]
|
||||
output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1]
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0)
|
||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
|
@ -112,6 +112,7 @@ class TestFloatUOps(TestUOps):
|
|||
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
|
||||
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
|
||||
def test_cmpne(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: a!=b)
|
||||
# MOD isn't tested on floats
|
||||
|
||||
def test_where(self):
|
||||
|
@ -136,7 +137,8 @@ class TestNonFloatUOps(TestUOps):
|
|||
def test_mod_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.MOD,
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (dtypes.int32, dtypes.int32))
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: int(a)<int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_cmpne_int32(self): self._test_bop_fxn(BinaryOps.CMPNE, lambda a,b: int(a)!=int(b), (dtypes.int32, dtypes.int32))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
|
||||
|
@ -205,6 +207,12 @@ class TestExecALU(TestUOps):
|
|||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, False)), False)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
|
||||
|
||||
def test_bool_cmpne(self):
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, False)), False)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (False, True)), True)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, False)), True)
|
||||
self.assertEqual(exec_alu(BinaryOps.CMPNE, dtypes.bool, (True, True)), False)
|
||||
|
||||
def test_bool_where(self):
|
||||
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
|
||||
self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
|
||||
|
|
Loading…
Reference in New Issue