mirror of https://github.com/commaai/tinygrad.git
more uops testing, who isn't passing right now... (#1522)
* more uops * llvm refactor * update test uops * rest of the nodes * ors and ands
This commit is contained in:
parent
f8109b830c
commit
0b5930d406
|
@ -37,36 +37,20 @@ def _test_single_value_const(tc, tt, vals, op):
|
|||
prg([buf])
|
||||
return buf.toCPU()[0]
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
|
||||
class TestUOps(unittest.TestCase):
|
||||
def _equal(self, v1, v2):
|
||||
if not (math.isnan(v1) and math.isnan(v2)): self.assertAlmostEqual(v1, v2, places=5)
|
||||
|
||||
def _test_uop_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 2.0]:
|
||||
for a in [-2.0, 0.0, 1.0, 2.0]:
|
||||
self._equal(f(Token('c', dt), [Token('a', dt)], [a], bop), fxn(a))
|
||||
def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a))
|
||||
def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('nan'))
|
||||
def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a))
|
||||
def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan'))
|
||||
#def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a)
|
||||
|
||||
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
def _test_bop_fxn(self, bop, fxn, dt=dtypes.float32, no_b_zero=False):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [-2.0, 2.0]:
|
||||
for b in [-3.0, 3.0]:
|
||||
for a in [-2.0, 0.0, 1.0, 2.0]:
|
||||
for b in [-3.0, 1.0, 3.0] + ([] if no_b_zero else [0.0]):
|
||||
self._equal(f(Token('c', dt), [Token('a', dt), Token('b', dt)], [a,b], bop), fxn(a,b))
|
||||
def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b)
|
||||
def test_sub(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: a-b)
|
||||
def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b)
|
||||
def test_div(self): self._test_bop_fxn(BinaryOps.DIV, lambda a,b: a/b)
|
||||
def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b))
|
||||
def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b))
|
||||
# MOD isn't tested
|
||||
|
||||
# doesn't work in LLVM
|
||||
#def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b, dtypes.int32)
|
||||
|
||||
def _test_top_fxn(self, bop, fxn, dt=dtypes.float32):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
|
@ -74,8 +58,37 @@ class TestUOps(unittest.TestCase):
|
|||
for b in [-3.0, 3.0]:
|
||||
for c in [-4.0, 4.0]:
|
||||
self._equal(f(Token('d', dt), [Token('a', dt), Token('b', dt), Token('c', dt)], [a,b,c], bop), fxn(a,b,c))
|
||||
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled), "only test for compiled backends")
class TestFloatUOps(TestUOps):
  """Unary, binary and ternary uop checks using the helpers' default dtype (float32).

  Every test passes an op together with a Python reference function to one of
  the inherited `_test_uop_fxn` / `_test_bop_fxn` / `_test_top_fxn` helpers,
  which compare the backend result against the reference for a fixed set of
  input values.
  """

  # ---- unary ops ----
  def test_exp2(self):
    self._test_uop_fxn(UnaryOps.EXP2, np.exp2)

  def test_log2(self):
    def expected(a):
      # log2 of zero is -inf, log2 of a negative number is nan
      if a > 0: return math.log2(a)
      return float('-inf' if a==0 else 'nan')
    self._test_uop_fxn(UnaryOps.LOG2, expected)

  def test_sin(self):
    self._test_uop_fxn(UnaryOps.SIN, math.sin)

  def test_sqrt(self):
    def expected(a):
      # math.sqrt raises on negative input, so the nan is spelled out here
      if a >= 0: return math.sqrt(a)
      return float('nan')
    self._test_uop_fxn(UnaryOps.SQRT, expected)

  # this is not on most backends
  #def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1.0/a if a != 0 else float('inf'))

  # ---- binary ops ----
  def test_add(self):
    self._test_bop_fxn(BinaryOps.ADD, lambda x,y: x+y)

  def test_sub(self):
    self._test_bop_fxn(BinaryOps.SUB, lambda x,y: x-y)

  def test_mul(self):
    self._test_bop_fxn(BinaryOps.MUL, lambda x,y: x*y)

  def test_div(self):
    def expected(a, b):
      # Python raises on division by zero; a*inf yields +/-inf, or nan when a is 0
      if b != 0: return a/b
      return a*float('inf')
    self._test_bop_fxn(BinaryOps.DIV, expected)

  def test_max(self):
    self._test_bop_fxn(BinaryOps.MAX, max)

  def test_cmplt(self):
    self._test_bop_fxn(BinaryOps.CMPLT, lambda x,y: float(x<y))

  # MOD isn't tested on floats

  # ---- ternary ops ----
  def test_mulacc(self):
    self._test_top_fxn(TernaryOps.MULACC, lambda x,y,z: (x*y)+z)

  def test_where(self):
    self._test_top_fxn(TernaryOps.WHERE, lambda cond,y,z: y if cond!=0 else z)
|
||||
|
||||
# TODO: fix this on all the backends
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or Device.DEFAULT == "LLVM" or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
  """Binary uop checks on int32 and bool, run through the inherited `_test_bop_fxn` helper.

  The helper supplies float inputs, so each reference function converts them
  to the tested dtype itself before computing the expected value.
  """

  def test_add_int32(self):
    self._test_bop_fxn(BinaryOps.ADD, lambda x,y: int(x)+int(y), dtypes.int32)

  def test_mul_int32(self):
    self._test_bop_fxn(BinaryOps.MUL, lambda x,y: int(x)*int(y), dtypes.int32)

  def test_div_int32(self):
    # b == 0 is left out of the inputs; int() truncates the quotient toward zero
    self._test_bop_fxn(BinaryOps.DIV, lambda x,y: int(x/y), dtypes.int32, no_b_zero=True)

  def test_mod_int32(self):
    def expected(a, b):
      # remainder of the magnitudes, carrying the sign of the dividend
      sign = -1 if a<0 else 1
      return abs(int(a))%abs(int(b))*sign
    # b == 0 is left out of the inputs
    self._test_bop_fxn(BinaryOps.MOD, expected, dtypes.int32, no_b_zero=True)

  def test_cmplt_int32(self):
    self._test_bop_fxn(BinaryOps.CMPLT, lambda x,y: float(x<y), dtypes.int32)

  @unittest.skipIf(Device.DEFAULT == "CLANG", "broken in CLANG")
  def test_mul_bool(self):
    # multiplying two bools is expected to behave like a logical AND
    self._test_bop_fxn(BinaryOps.MUL, lambda x,y: bool(x) and bool(y), dtypes.bool)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
|
|
@ -124,7 +124,7 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
|
|||
|
||||
class ASTRunner:
|
||||
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
|
||||
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args): print(prg)
|
||||
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
|
||||
self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
|
||||
|
||||
def build(self, runtime):
|
||||
|
|
|
@ -35,7 +35,7 @@ class CStyleLanguage(NamedTuple):
|
|||
UnaryOps.SQRT: lambda x: f"sqrt({x})",
|
||||
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
|
||||
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
|
||||
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
|
||||
BinaryOps.MAX: lambda a,b: f"max({a},{b})", BinaryOps.MOD: lambda a,b: f"({a}%{b})",
|
||||
BinaryOps.CMPLT: lambda a,b: f"({a}<{b})", TernaryOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})",
|
||||
TernaryOps.WHERE: lambda a,b,c: f"({a}!=0?{b}:{c})"
|
||||
}
|
||||
|
|
|
@ -32,6 +32,35 @@ code_for_op: Final[Dict[Op, Callable]] = {
|
|||
TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)),
|
||||
}
|
||||
|
||||
# LLVM IR type used to represent each supported dtype.
dtype_to_llvm_dtype = {
  dtypes.float16: ir.HalfType(),
  # bfloat16 is carried as a raw 16-bit integer; cast() converts it with shifts and bitcasts
  dtypes.bfloat16: ir.IntType(16),
  dtypes.float32: ir.FloatType(),
  dtypes.int8: ir.IntType(8),
  dtypes.uint8: ir.IntType(8),
  dtypes.bool: ir.IntType(1),
  dtypes.int64: ir.IntType(64),
  dtypes.int32: ir.IntType(32),
}
|
||||
|
||||
def cast(bb, val, input_type, output_type):
  """Emit the LLVM instructions that convert `val` from `input_type` to `output_type`.

  Only conversions with float32 on one side are handled:
    * integer / bool -> float32 via uitofp (unsigned and bool) or sitofp
    * bfloat16 -> float32 by moving the 16 stored bits into the upper half of
      a 32-bit integer and bitcasting it to float
    * any other type -> float32 via fpext
    * float32 -> integer / bool via fptoui (unsigned and bool) or fptosi
    * float32 -> bfloat16 by keeping only the upper 16 bits (no rounding step)
    * float32 -> any other type via fptrunc

  Args:
    bb: sequence whose last element is used to emit the instructions
        (presumably the stack of IR builders from the caller; confirm there).
    val: the LLVM value to convert.
    input_type: dtype `val` currently has.
    output_type: dtype the result should have.

  Returns:
    `val` itself when both dtypes are equal, otherwise the converted value.

  Raises:
    NotImplementedError: if the dtypes differ and neither one is float32.
  """
  if input_type == output_type: return val

  def _is_integral(dt): return dtypes.is_int(dt) or dt == dtypes.bool
  # bool is converted with the unsigned instructions
  def _is_unsigned_like(dt): return dtypes.is_unsigned(dt) or dt == dtypes.bool

  if output_type == dtypes.float32:
    builder = bb[-1]
    if _is_integral(input_type):
      to_float = builder.uitofp if _is_unsigned_like(input_type) else builder.sitofp
      return to_float(val, ir.FloatType())
    if input_type == dtypes.bfloat16:
      # the sign-extended upper bits are discarded by the shift
      widened = builder.sext(val, ir.IntType(32))
      shifted = builder.shl(widened, ir.Constant(ir.IntType(32), 16))
      return builder.bitcast(shifted, ir.FloatType())
    return builder.fpext(val, ir.FloatType())

  if input_type == dtypes.float32:
    builder = bb[-1]
    if _is_integral(output_type):
      to_int = builder.fptoui if _is_unsigned_like(output_type) else builder.fptosi
      return to_int(val, dtype_to_llvm_dtype[output_type])
    if output_type == dtypes.bfloat16:
      raw_bits = builder.bitcast(val, ir.IntType(32))
      upper_half = builder.lshr(raw_bits, ir.Constant(ir.IntType(32), 16))
      return builder.trunc(upper_half, ir.IntType(16))
    return builder.fptrunc(val, dtype_to_llvm_dtype[output_type])

  raise NotImplementedError(f"cast from {input_type} -> {output_type} not implemented")
|
||||
|
||||
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[List[int]], Optional[List[int]]]:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
@ -41,7 +70,6 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|||
buf_index = {x:i for i,x in enumerate(buf_to_dtype.keys())}
|
||||
|
||||
# create llvm function
|
||||
dtype_to_llvm_dtype = {dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32)}
|
||||
func_dtypes = [dtype_to_llvm_dtype[dtype] for dtype in buf_to_dtype.values()]
|
||||
func = ir.Function(module, ir.FunctionType(ir.VoidType(), [x.as_pointer() for x in func_dtypes]), name=function_name)
|
||||
|
||||
|
@ -84,9 +112,9 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|||
bb[-2].cbranch(bb[-2].icmp_unsigned("==", idx_p1, int_const(var.max+1)), bb[-1]._block, block._block)
|
||||
if uop == UOps.LOAD:
|
||||
assert newvar is not None and isinstance(args, (MemOp, ConstOp))
|
||||
assert newvar.dtype == dtypes.float, "newvar must be float"
|
||||
valid = args.valid.render(render_llvm, bb[-1])
|
||||
if isinstance(args, ConstOp):
|
||||
assert newvar.dtype == dtypes.float, "newvar must be float"
|
||||
if args.valid.min == 0 and args.valid.max == 1:
|
||||
val = bb[-1].select(valid, ir.Constant(ir.FloatType(), args.value), ir.Constant(ir.FloatType(), args.invalid_value))
|
||||
else:
|
||||
|
@ -100,30 +128,12 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li
|
|||
val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value))
|
||||
else:
|
||||
val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
|
||||
|
||||
if args.memory_dtype != newvar.dtype:
|
||||
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
|
||||
val = bb[-1].uitofp(val, ir.FloatType()) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].sitofp(val, ir.FloatType())
|
||||
elif args.memory_dtype == dtypes.bfloat16:
|
||||
val = bb[-1].sext(val, ir.IntType(32))
|
||||
val = bb[-1].shl(val, ir.Constant(ir.IntType(32), 16))
|
||||
val = bb[-1].bitcast(val, ir.FloatType())
|
||||
else:
|
||||
val = bb[-1].fpext(val, ir.FloatType())
|
||||
val = cast(bb, val, args.memory_dtype, newvar.dtype)
|
||||
lvars[newvar] = val
|
||||
if uop == UOps.STORE:
|
||||
assert args.valid.min == 1 and isinstance(args, MemOp), "store must be valid and to memory"
|
||||
idx = args.idx.render(render_llvm, bb[-1])
|
||||
element = lvars[vin[0]]
|
||||
if args.memory_dtype != vin[0].dtype:
|
||||
if dtypes.is_int(args.memory_dtype) or args.memory_dtype == dtypes.bool:
|
||||
element = bb[-1].fptoui(element, dtype_to_llvm_dtype[args.memory_dtype]) if dtypes.is_unsigned(args.memory_dtype) or args.memory_dtype == dtypes.bool else bb[-1].fptosi(element, dtype_to_llvm_dtype[args.memory_dtype])
|
||||
elif args.memory_dtype == dtypes.bfloat16:
|
||||
element = bb[-1].bitcast(element, ir.IntType(32))
|
||||
element = bb[-1].lshr(element, ir.Constant(ir.IntType(32), 16))
|
||||
element = bb[-1].trunc(element, ir.IntType(16))
|
||||
else:
|
||||
element = bb[-1].fptrunc(element, dtype_to_llvm_dtype[args.memory_dtype])
|
||||
element = cast(bb, lvars[vin[0]], vin[0].dtype, args.memory_dtype)
|
||||
bb[-1].store(element, bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True))
|
||||
if uop == UOps.ALU:
|
||||
lvars[newvar] = code_for_op[args](bb[-1], *[lvars[x] for x in vin])
|
||||
|
|
|
@ -55,7 +55,8 @@ class LLVMProgram:
|
|||
LLVM.engine.finalize_object()
|
||||
self.fxn = LLVM.engine.get_function_address(name)
|
||||
|
||||
def __del__(self): LLVM.engine.remove_module(self.mod)
|
||||
def __del__(self):
|
||||
if hasattr(self, 'mod'): LLVM.engine.remove_module(self.mod)
|
||||
|
||||
def __call__(self, unused_global_size, unused_local_size, *bufs, wait=False):
|
||||
cfunc = CFUNCTYPE(ctypes.c_int, *[ctypes.c_void_p for _ in bufs])(self.fxn)
|
||||
|
|
Loading…
Reference in New Issue