diff --git a/test/test_uops.py b/test/test_uops.py index 16f8937c..101fa4fe 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -25,7 +25,7 @@ def uop(uops:List[UOp], uop:UOps, dtype:Optional[DType], src:Tuple[UOp, ...], ar def _test_single_value(vals, op, dts): uops = [] - output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] + output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts)) @@ -41,7 +41,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] - output_dtype = dts[-1] if op is TernaryOps.WHERE else dtypes.bool if op is BinaryOps.CMPLT else dts[0] + output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), 0) loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, UOps.ALU, output_dtype, loads, op) @@ -112,6 +112,7 @@ class TestFloatUOps(TestUOps): def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a