diff --git a/extra/backends/ptx.py b/extra/backends/ptx.py
new file mode 100644
index 00000000..2684853a
--- /dev/null
+++ b/extra/backends/ptx.py
@@ -0,0 +1,240 @@
+from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
+import functools, struct
+from collections import defaultdict
+from tinygrad.codegen.linearizer import UOps, UOp
+from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
+from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
+
+def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
+def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
+def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]
+
+def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
+
+class AssemblyLanguage(NamedTuple):
+  kernel_prefix: str = ""
+  barrier: str = ""
+  load_global: bool = False
+  label_prefix: str = ""
+  gid: List[str] = []
+  gdim: List[str] = []
+  lid: List[str] = []
+  const_requires_mov: List[DType] = []  # list of dtypes for which creating a const requires a move
+  asm_for_op: Dict[Op, Callable[...,str]] = {}
+  types: Dict[DType, str] = INVERSE_DTYPES_DICT
+
+  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
+  def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()
+
+  def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
+  def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
+  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
+  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
+  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
+  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()
+
+  def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
+
+def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
+  local_size: List[int] = []
+  kernel:List[str] = []
+  bufs = []
+
+  def kk(*s: str): kernel.append("\n".join(s))
+
+  c: DefaultDict[str, int] = defaultdict(int)
+  r: Dict[UOp, str] = {}
+  def ssa(u, prefix="t", dtype=None) -> str:
+    nonlocal c, r
+    prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
+    c[prefix] += 1
+    if u: r[u] = f"%{prefix}{c[prefix]-1}"
+    return f"%{prefix}{c[prefix]-1}"
+
+  c_label: DefaultDict[str, int] = defaultdict(int)
+  r_label: Dict[UOp, str] = {}
+  def ssa_label(u, prefix):
+    nonlocal c_label, r_label
+    c_label[prefix] += 1
+    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
+    return r_label[u]
+
+  def const(x:Union[float,int,bool], dtype, mov=False):
+    if mov or dtype in lang.const_requires_mov:
+      kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
+      return out
+    return lang.render_const(x, dtype)
+
+  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
+    if atype == dtype:
+      if u: r[u] = a
+      return a
+    kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
+    return ret
+
+  for u in uops:
+    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
+    if uop == UOps.IF:
+      assert vin[0].dtype is not None
+      kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
+    elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
+    elif uop == UOps.END:
+      if vin[0].uop == UOps.LOOP:
+        kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
+           lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
+        kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
+      else: kk(f"{r_label[vin[0]]}:")
+    elif uop == UOps.STORE:
+      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
+      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
+      if len(vin) > 3:
+        assert vin[3].dtype is not None
+        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
+      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+    else:
+      assert dtype is not None, f"None dtype for uop {uop}"
+      if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
+      elif uop == UOps.ALU:
+        assert vin[0].dtype is not None
+        if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
+          regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
+          dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
+          kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
+        elif args == TernaryOps.MULACC:
+          assert vin[1].dtype is not None
+          kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype]))
+        else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
+      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
+      elif uop == UOps.SPECIAL:
+        if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
+                                 f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
+                                 f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
+        else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
+        kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
+        if args[1][0] == "l": local_size.append(args[2])
+        r[u] = "%" + args[1]
+        kernel = [f".reg .u32 %{args[1]};"] + kernel
+      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
+      elif uop == UOps.LOAD:
+        assert vin[1].dtype is not None
+        val = ssa(u, 'val')
+        if len(vin) > 3:
+          assert vin[2].dtype is not None
+          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
+          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
+        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]],
+                            dtypes.uint, vin[1].dtype), dtype),
+           *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
+                             alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
+      elif uop == UOps.PHI:
+        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
+        r[u] = r[vin[0]]
+      elif uop == UOps.CAST:
+        assert vin[0].dtype is not None
+        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
+      elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
+      elif uop == UOps.DEFINE_GLOBAL:
+        bufs.append((args, dtype))
+        r[u] = f"%{args}"
+        if lang.load_global:
+          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
+          kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
+      else: raise NotImplementedError(f"no code for {uop}")
+
+  return lang.render_kernel(kernel, function_name, bufs, c.items())
+
+class PTXLanguage(AssemblyLanguage):
+  kernel_prefix = """.version 7.8
+.target TARGET
+.address_size 64
+.visible .entry"""
+  barrier = "bar.sync\t0;"
+  has_pred = True
+  load_global = True
+  label_prefix = "$"
+  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
+  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
+  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
+  asm_for_op = {
+    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if dt == dtypes.bool else f"neg.{name} {d}, {a};",
+    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
+    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
+    UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
+    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
+    BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
+    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
+    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
+    BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
+    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
+    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
+    BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
+    TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') +
+                                                f".{name} {d}, {a}, {b}, {c};"),
+    TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};"
+  }
+  supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
+                   TernaryOps.MULACC, TernaryOps.WHERE]
+  types = {
+    dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
+    dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
+    dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64",
+    dtypes.bool: "pred"
+  }
+
+  const_requires_mov = [dtypes.half, dtypes.bool]
+
+  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
+    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
+    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
+    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
+    if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
+    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val
+
+  def render_local(self, dest, name, size, dtype) -> List[str]:
+    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]
+
[f"mov.u32 {idx}, {start};", f"{label}:"] + + def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"] + + def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: + # this cast is only required because of ocelot + if "s32" in offset: + return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"] + else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"] + + def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype] + + def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: + ret = [] + if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;") + if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;") + if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];", + f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"]) + else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];") + if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;") + if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};") + return ret + + def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: + if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype), + (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"] + return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"] + + def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: + if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"] + if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"] + if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"] + rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else + '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '') + return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"] + + def render_kernel(self, kernel, function_name, bufs, regs) -> str: + kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"] + def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1) + return (f"{self.kernel_prefix} {function_name}(\n\t" + + ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" + + '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) + + "\n}") + +PTXRenderer = functools.partial(uops_to_asm, PTXLanguage()) diff --git a/test/test_dtype.py b/test/test_dtype.py index a80b94de..85dc7c4c 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -40,7 +40,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target): if DEBUG >= 2: print(tensor.numpy()) try: assert tensor.dtype == target_dtype - np.testing.assert_allclose(tensor.numpy(), target) + 
+    np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
diff --git a/test/test_dtype_alu.py b/test/test_dtype_alu.py
index b7433f9a..db4abf44 100644
--- a/test/test_dtype_alu.py
+++ b/test/test_dtype_alu.py
@@ -63,7 +63,7 @@ def universal_test_unary(a, dtype, op):
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
   if dtype in dtypes_float:
-    atol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3
+    atol = 2 if (Device.DEFAULT == "METAL" or getenv("PTX")) and op[0] == Tensor.sin else 1e-3
     rtol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2
     # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise)
     np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
@@ -84,7 +84,7 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
   an, bn, cn = np.array([a]).astype(d1.np), np.array([b]).astype(d1.np), np.array([c]).astype(d2.np)
   tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy()
   numpy_value = op2[1](op1[1](an, bn).astype(d2.np), cn)
-  np.testing.assert_almost_equal(tensor_value, numpy_value)
+  np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)
 
 class TestDTypeALU(unittest.TestCase):
   @unittest.skipIf(OSX and Device.DEFAULT in {"GPU", "METAL"}, "no float64 on OSX GPU")
diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py
index e26ddb43..0562e608 100644
--- a/test/test_fusion_op.py
+++ b/test/test_fusion_op.py
@@ -30,7 +30,7 @@ class TestFusionOp(unittest.TestCase):
     sched = create_schedule([a.lazydata], None)
     ji = lower_schedule_item(sched[-1])
     self.assertLess(time.perf_counter()-st, 1.0)
-    assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
+    assert isinstance(ji, InterpretedASTRunner) or len(ji.prg.splitlines()) < 250
 
   def test_recursive_add_cmp(self):
     st = time.perf_counter()
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index 669349b1..12f19ee4 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -34,7 +34,7 @@ def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architec
 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
     self.device, self.name, self.lib = device, name, lib
-    if DEBUG >= 5: print(pretty_ptx(lib.decode('utf-8')))
+    if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6:
       try:
         fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
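
For reviewers who want to exercise this end to end, here is a minimal smoke test, not part of the patch. It assumes a CUDA-capable machine and that setting `PTX=1` routes compilation through `PTXRenderer`; the tests above gate on `getenv("PTX")`, but the backend-selection wiring itself is outside this diff. With the `ops_cuda.py` change, `DEBUG=5` then prints the generated PTX with line numbers.

```python
# Hypothetical smoke test (assumption: PTX=1 selects the assembly renderer;
# this diff only shows the tests reading getenv("PTX"), not the wiring).
import os
os.environ["PTX"] = "1"    # select the PTX renderer instead of the default CUDA C codegen
os.environ["DEBUG"] = "5"  # CUDAProgram now prints line-numbered PTX at DEBUG >= 5

from tinygrad.tensor import Tensor  # import after setting env vars so tinygrad picks them up

a, b = Tensor([1.0, 2.0, 3.0]), Tensor([4.0, 5.0, 6.0])
print((a + b).numpy())  # compiling this kernel dumps the rendered .visible .entry
```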