2024-03-09 02:17:49 +08:00
|
|
|
import itertools
|
|
|
|
from tinygrad import Device
|
2024-05-14 22:47:03 +08:00
|
|
|
from tinygrad.engine.realize import CompiledRunner
|
2024-07-16 02:29:03 +08:00
|
|
|
from tinygrad.helpers import getenv, colorize_float
|
2024-03-09 02:17:49 +08:00
|
|
|
from extra.optimization.helpers import load_worlds, ast_str_to_lin
|
2024-05-15 14:12:59 +08:00
|
|
|
from tinygrad.engine.search import bufs_from_lin
|
2024-05-14 22:47:03 +08:00
|
|
|
from tinygrad.runtime.ops_cuda import PTXCompiler, PTXRenderer, CUDACompiler
|
2024-03-09 02:17:49 +08:00
|
|
|
|
|
|
|
if __name__ == "__main__":
  # Benchmark every kernel AST from the recorded "worlds" dataset twice:
  #   1. through the default CUDA renderer (dev.renderer) + CUDACompiler, and
  #   2. through PTXRenderer + PTXCompiler,
  # then report how the PTX build's best runtime compares to the CUDA build's.
  #
  # Env vars (read with getenv):
  #   NUM=<i>  benchmark only AST number i; when PTX is slow, also dump both
  #            compiled listings side by side.
  #            e.g. NUM=112 python3 test/external/speed_compare_cuda_ptx.py
  #   CNT=<n>  timed runs per backend per kernel (default 5, minimum 1); the
  #            fastest run is kept.
  ast_strs = load_worlds(filter_reduce=False, filter_novariable=True)
  # no bfloat16 for ptx at the moment
  ast_strs = [x for x in ast_strs if "dtypes.bfloat16" not in x]
  dev = Device["CUDA"]
  ptx = PTXRenderer(dev.arch)

  single = getenv("NUM", -1)
  if single != -1: ast_strs = ast_strs[single:single+1]
  # at least one timed run is required, otherwise min() below raises on an empty list
  cnt = max(1, getenv("CNT", 5))

  def fix(x):
    """Normalize one listing line for the side-by-side dump: tabs -> spaces, trimmed."""
    return x.replace('\t', ' ').strip()

  # Running SUMS (despite the names) of each kernel's best time; their quotient is the
  # aggregate PTX/CUDA ratio printed at the start of every result line.
  average_tm_cuda, average_tm_ptx = 0, 0
  for num, ast in enumerate(ast_strs):
    # cuda compile -- dev.compiler is swapped before each build, presumably because
    # CompiledRunner compiles with the device's current compiler
    dev.compiler = CUDACompiler(dev.arch)
    lin = ast_str_to_lin(ast, opts=dev.renderer)
    lin.hand_coded_optimizations()
    cuda_prg = CompiledRunner(lin.to_program())

    # buffers are allocated once from the CUDA linearizer and shared by both programs
    bufs = bufs_from_lin(lin)

    # ptx compile
    dev.compiler = PTXCompiler(dev.arch)
    lin = ast_str_to_lin(ast, opts=ptx)
    lin.hand_coded_optimizations()
    lin.linearize()
    ptx_prg = CompiledRunner(lin.to_program())

    # warmup -- a kernel that fails on either backend is reported and skipped
    try:
      cuda_prg(bufs, {}, wait=True)
    except RuntimeError:
      print("cuda failed ast:", num)
      continue
    try:
      ptx_prg(bufs, {}, wait=True)
    except RuntimeError:
      print("ptx failed ast:", num)
      # when investigating one kernel (NUM set), keep the full traceback visible
      if single != -1: raise
      continue

    # interleave the timed runs so both backends see similar device conditions;
    # the call's return value is presumably elapsed seconds (it is printed *1e6 as "us")
    tm_cuda, tm_ptx = [], []
    for _ in range(cnt):
      tm_cuda.append(cuda_prg(bufs, {}, wait=True))
      tm_ptx.append(ptx_prg(bufs, {}, wait=True))
    best_cuda, best_ptx = min(tm_cuda), min(tm_ptx)
    average_tm_cuda += best_cuda
    average_tm_ptx += best_ptx
    # > 1 means the PTX build is slower than the CUDA build for this kernel
    ratio = best_ptx/best_cuda
    print(f"{average_tm_ptx/average_tm_cuda:5.2f}x -- {num:4d} {colorize_float(ratio)} {best_ptx*1e6:7.2f} us", lin.name)
    if ratio > 1.5:
      # `lib` holds the compiled program bytes; decoded here as text for comparison
      ll1, ll2 = cuda_prg.lib.decode().split('\n'), ptx_prg.lib.decode().split('\n')
      if single != -1:
        for ln, (l1, l2) in enumerate(itertools.zip_longest(ll1, ll2, fillvalue='')):
          print(f"{ln:5d} | {fix(l1):80s} | {fix(l2):80s}")
      print(len(ll1), len(ll2), "RATIO", ratio, "us", best_ptx*1e6)
|