more uops testing, who isn't passing right now... (#1522)

* more uops

* llvm refactor

* update test uops

* rest of the nodes

* ors and ands
George Hotz 2023-08-15 09:07:26 -07:00 committed by GitHub
parent f8109b830c
commit 0b5930d406
5 changed files with 69 additions and 45 deletions

@@ -37,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op):
prg([buf])
return buf.toCPU()[0]
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestUOps(unittest.TestCase):
def _equal(self, v1, v2):
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 2.0]:
for a in [-2.0, 0.0, 1.0, 2.0]:
self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a))
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan'))
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
#def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a)
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32):
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
for f in [_test_single_value, _test_single_value_const]:
for a in [-2.0, 2.0]:
for b in [-3.0, 3.0]:
for a in [-2.0, 0.0, 1.0, 2.0]:
for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), fxn(a,b))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
# MOD isn't tested
# doesn't work in LLVM
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
for f in [_test_single_value, _test_single_value_const]:
@@ -74,8 +58,37 @@ class TestUOps(unittest.TestCase):
for b in [-3.0, 3.0]:
for c in [-4.0, 4.0]:
self._equal(f(Token('d', dt), [Token('a', dt), Token('b', dt), Token('c', dt)], [a,b,c], bop), fxn(a,b,c))
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan'))
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
# this is not on most backends
#def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b if b != 0 else a*float('inf'))
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
# MOD isn't tested on floats
def test_mulacc(self): self._test_top_fxn(TernaryOps.MULACC, lambda a,b,c: (a*b)+c)
def test_where(self): self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c)
# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), dtypes.int32)
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), dtypes.int32)
def test_div_int32(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), dtypes.int32, no_b_zero=True)
def test_mod_int32(self): self._test_bop_fxn(BinaryOps.MOD, lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], dtypes.int32, no_b_zero=True)
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), dtypes.int32)
@unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), dtypes.bool)
if __name__ == '__main__':
unittest.main(verbosity=2)
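
Two of the reference lambdas above are worth spelling out: test_mod_int32 checks BinaryOps.MOD against truncated remainder (the C/LLVM convention, where the result takes the sign of the dividend) rather than Python's floored %, and test_div's float reference uses a*float('inf') when b is zero so the sign of the result, and nan for 0/0, match IEEE division. A small standalone sketch in plain Python:

# Truncated remainder (C's % / LLVM srem) vs Python's floored %:
a, b = -2, 3
print(a % b)                              # 1  (Python: floored, result takes the sign of b)
print(abs(a) % abs(b) * (1, -1)[a < 0])   # -2 (truncated, result takes the sign of a)

# Division-by-zero reference: a*float('inf') keeps the sign and gives nan for 0/0:
for x in [-2.0, 0.0, 2.0]:
  print(x * float('inf'))                 # -inf, nan, inf -- same as IEEE x/0.0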

@@ -124,7 +124,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
class ASTRunner:
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
def build(self, runtime):
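
The widened DEBUG guard above changes behaviour when runtime_args carries binary=False: the old check suppressed the source dump whenever the 'binary' key was present, the new one only suppresses it when the value is truthy. A quick sketch of the condition on its own:

def should_print(runtime_args):
  # mirrors the new guard: dump the program source unless runtime_args has a truthy 'binary'
  return runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']

print(should_print(None))               # True
print(should_print({}))                 # True
print(should_print({'binary': False}))  # True  (the old guard returned False here)
print(should_print({'binary': True}))   # False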

@@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
}
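
The new BinaryOps.MOD entry maps straight onto C's % operator, which only applies to integer operands; test_mod_int32 earlier in this commit is what exercises it. Composing a few of the table's formatters by hand (hypothetical value names, just to show the emitted string form):

mul   = lambda a, b: f"({a}*{b})"
mod   = lambda a, b: f"({a}%{b})"
where = lambda a, b, c: f"({a}!=0?{b}:{c})"
print(where(mod("idx0", "4"), mul("val0", "val1"), "0.0"))   # ((idx0%4)!=0?(val0*val1):0.0)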

@@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
}
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
def cast(bb, val, input_type, output_type):
if input_type == output_type: return val
if output_type == dtypes.float32:
if dtypes.is_int(input_type) or input_type == dtypes.bool:
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(input_type) or input_type == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
elif input_type == dtypes.bfloat16:
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
return val
if input_type == dtypes.float32:
if dtypes.is_int(output_type) or output_type == dtypes.bool:
val = bb[-1].fptoui(val, dtype_to_llvm_dtype[output_type]) if dtypes.is_unsigned(output_type) or output_type == dtypes.bool else bb[-1].fptosi(val, dtype_to_llvm_dtype[output_type])
elif output_type == dtypes.bfloat16:
val = bb[-1].bitcast(val, ir.IntType(32))
val = bb[-1].lshr(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].trunc(val, ir.IntType(16))
else:
val = bb[-1].fptrunc(val, dtype_to_llvm_dtype[output_type])
return val
raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
# all llvm stuff goes into a module
module = ir.Module(name=__file__)
@@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
# create llvm function
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
@@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
if uop == UOps.LOAD:
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
assert newvar.dtype == dtypes.float, "newvar must be float"
valid = args.valid.render(render_llvm, bb[-1])
if isinstance(args, ConstOp):
assert newvar.dtype == dtypes.float, "newvar must be float"
if args.valid.min == 0 and args.valid.max == 1:
val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
else:
@@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
else:
val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
if args.memory_dtype != newvar.dtype:
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
elif args.memory_dtype == dtypes.bfloat16:
val = bb[-1].sext(val, ir.IntType(32))
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
val = bb[-1].bitcast(val, ir.FloatType())
else:
val = bb[-1].fpext(val, ir.FloatType())
val = cast(bb, val, args.memory_dtype, newvar.dtype)
lvars[newvar] = val
if uop == UOps.STORE:
assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
idx = args.idx.render(render_llvm, bb[-1])
element = lvars[vin[0]]
if args.memory_dtype != vin[0].dtype:
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
elif args.memory_dtype == dtypes.bfloat16:
element = bb[-1].bitcast(element, ir.IntType(32))
element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
element = bb[-1].trunc(element, ir.IntType(16))
else:
element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
if uop == UOps.ALU:
lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
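
The LOAD and STORE branches above now share the single cast() helper instead of duplicating the conversion logic inline; cast() only converts to or from float32 (uitofp/sitofp for ints and bool, a 16-bit shift plus bitcast for bfloat16, fpext/fptrunc otherwise) and raises NotImplementedError for any other pair. The bfloat16 trick is easiest to see in plain Python; a sketch using the struct module:

import struct

f = 3.14159
bits32 = struct.unpack('<I', struct.pack('<f', f))[0]            # float32 -> raw 32 bits
bf16   = bits32 >> 16                                            # store path: lshr 16, trunc to i16
back   = struct.unpack('<f', struct.pack('<I', bf16 << 16))[0]   # load path: extend to i32, shl 16, bitcast
print(hex(bf16), back)   # 0x4049 3.140625 -- keeps the sign, exponent and top 7 mantissa bits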

@@ -55,7 +55,8 @@ class LLVMProgram:
LLVM.engine.finalize_object()
self.fxn = LLVM.engine.get_function_address(name)
def __del__(self): LLVM.engine.remove_module(self.mod)
def __del__(self):
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
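
The __del__ guard above covers the case where __init__ raises before self.mod is assigned (for example, a failed compile): without the hasattr check the destructor itself would raise AttributeError and Python would print an "Exception ignored" message when the half-built object is collected. A minimal sketch of the pattern with a stand-in resource:

class Holder:
  def __init__(self, src):
    if not src: raise ValueError("compile failed")   # self.mod is never set on this path
    self.mod = src
  def __del__(self):
    # only release what was actually acquired
    if hasattr(self, 'mod'): print("removing module", self.mod)

try:
  Holder("")          # __init__ raises before self.mod exists
except ValueError:
  pass                # no AttributeError from __del__, thanks to the hasattr guard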