mirror of https://github.com/commaai/tinygrad.git
204 lines
11 KiB
Python
204 lines
11 KiB
Python
import yaml
|
|
from typing import Tuple, Set, Dict
|
|
from tinygrad import dtypes
|
|
from tinygrad.codegen.assembly import AssemblyCodegen, Register
|
|
from tinygrad.codegen.kernel import UOps
|
|
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
|
from tinygrad.runtime.ops_gpu import ROCM_LLVM_PATH
|
|
|
|
# ugh, is this really needed?
|
|
from extra.helpers import enable_early_exec
|
|
early_exec = enable_early_exec()
|
|
|
|
boilerplate_start = """
|
|
.global _start
|
|
_start:
|
|
.rodata
|
|
.align 0x10
|
|
.global code.kd
|
|
.type code.kd,STT_OBJECT
|
|
.amdhsa_kernel code"""
|
|
|
|
code_start = """.end_amdhsa_kernel
|
|
.text
|
|
code:
|
|
"""
|
|
|
|
# https://github.com/RadeonOpenCompute/ROCm_Documentation/blob/master/ROCm_Compiler_SDK/ROCm-Codeobj-format.rst
|
|
# https://github.com/ROCm-Developer-Tools/ROCm-ComputeABI-Doc/blob/master/AMDGPU-ABI.md#initial-kernel-register-state
|
|
# RDNA3 is actually a SIMD machine!
|
|
class RDNACodegen(AssemblyCodegen):
|
|
supports_float4: bool = True
|
|
supports_float4_alu: bool = True
|
|
supports_load3: bool = True
|
|
sin_is_sin2pi: bool = True
|
|
no_div: bool = True
|
|
|
|
def specialize(self, asm) -> Tuple[str, str]:
|
|
args = []
|
|
for i,b in enumerate(self.bufs): args.append({'.address_space': 'global', '.name': f'buf_{i}', '.offset': i*8, '.size': 8, '.type_name': b.dtype.name+"*", '.value_kind': 'global_buffer'})
|
|
ins = []
|
|
|
|
v_cnt = 3 # v[0:2] is local_xyz
|
|
s_cnt = 5 # s[0:1] is the address, s[2:4] is global_xyz
|
|
|
|
dtype_to_rdnatype = {dtypes.float32: "f32", dtypes.int64: "i64", dtypes.int32: "i32", dtypes.uint64: "u64", dtypes.bool: "i32"}
|
|
alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", TernaryOps.MULACC: "fma",
|
|
BinaryOps.MAX: "max", UnaryOps.RECIP: "rcp",
|
|
UnaryOps.NOOP: "mov", UnaryOps.SIN: "sin", UnaryOps.LOG2: "log", UnaryOps.EXP2: "exp",
|
|
BinaryOps.CMPLT: "cmp_lt"}
|
|
|
|
pend_regs:Set[Register] = set()
|
|
rtor:Dict[Register, str] = {}
|
|
def reg_in(x):
|
|
nonlocal pend_regs
|
|
#print("reg_in", x, rtor[x], pend_regs)
|
|
if x in pend_regs:
|
|
#print("clear")
|
|
ins.append('s_waitcnt lgkmcnt(0), vmcnt(0)')
|
|
pend_regs.clear()
|
|
return rtor[x]
|
|
def reg_out(x):
|
|
return rtor[x]
|
|
for uop, out, vin, arg in asm:
|
|
if uop == UOps.DEFINE_REGISTER:
|
|
if arg[0][0] in [dtypes.uint32, dtypes.uint64, dtypes.int64, dtypes.int32, dtypes.float32, dtypes.float.vec(4)]:
|
|
for i in range(arg[2]):
|
|
# TODO: Re-use gaps created by this to avoid wasting registers
|
|
align = int(arg[0][0].itemsize / 4)
|
|
if arg[0][1]:
|
|
s_cnt += s_cnt % align
|
|
reg_name = f"s[{s_cnt}:{s_cnt + align - 1}]" if align > 1 else f"s{s_cnt}"
|
|
s_cnt += align
|
|
else:
|
|
v_cnt += v_cnt % align
|
|
reg_name = f"v[{v_cnt}:{v_cnt + align - 1}]" if align > 1 else f"v{v_cnt}"
|
|
v_cnt += align
|
|
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
|
|
|
if arg[0][0] == dtypes.float.vec(4):
|
|
for off in range(4):
|
|
reg_name = f"s{s_cnt-align+off}" if arg[0][1] else f"v{v_cnt-align+off}"
|
|
rtor[Register(f"%{arg[1]}{i}", dtypes.float, False, off=off)] = reg_name
|
|
elif arg[0][0] == dtypes.bool:
|
|
for i in range(arg[2]):
|
|
reg_name = "scc" if arg[0][1] else "vcc_lo" # `_lo` suffix since we're running wavefront_size=32
|
|
rtor[Register(f"%{arg[1]}{i}", *arg[0])] = reg_name
|
|
else:
|
|
raise NotImplementedError("DEFINE_REGISTER not implemented for arg: ", arg)
|
|
elif uop == UOps.SPECIAL:
|
|
if arg.startswith('buf'):
|
|
i = int(arg[3:])
|
|
ins.append(f's_load_b64 {reg_out(out)}, s[0:1], {i*8}')
|
|
pend_regs.add(out)
|
|
for r in out.subregs(): pend_regs.add(r)
|
|
elif arg.startswith('gid'):
|
|
ins.append(f'v_mov_b32 {reg_out(out)}, s{2+int(arg[3])}')
|
|
# the docs lied, this is actually y
|
|
if int(arg[3]) == 2: ins.append("v_bfe_u32 v2, v0, 20, 10") # untested
|
|
if int(arg[3]) == 1: ins.append("v_bfe_u32 v1, v0, 10, 10")
|
|
elif int(arg[3]) == 0: ins.append("v_and_b32_e32 v0, 0x3ff, v0")
|
|
# get local size
|
|
offset = len(args)*8
|
|
args.append({".offset": offset, ".value_kind": f"hidden_group_size_{'xyz'[int(arg[3])]}", ".size": 8})
|
|
ins.append(f's_load_b32 s{2+int(arg[3])}, s[0:1], {offset}')
|
|
ins.append('s_waitcnt vmcnt(0) lgkmcnt(0)')
|
|
pend_regs.clear()
|
|
ins.append(f'v_mul_i32_i24 {reg_out(out)}, {reg_out(out)}, s{2+int(arg[3])}')
|
|
ins.append(f'v_add_nc_u32 {reg_out(out)}, v{int(arg[3])}, {reg_out(out)}')
|
|
elif uop == UOps.CONST:
|
|
if arg == float('inf'): arg = "0x7f800000"
|
|
elif arg == float('-inf'): arg = "0xff800000"
|
|
if out.dtype == dtypes.float.vec(4):
|
|
for off in range(4):
|
|
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(Register(out.nm, dtypes.float, False, off=off))}, {arg}")
|
|
else:
|
|
ins.append(f"{'s_' if out.scalar else 'v_'}mov_b32 {reg_out(out)}, {arg}")
|
|
elif uop == UOps.ALU:
|
|
if arg in [BinaryOps.CMPLT]:
|
|
ins.append(f"{'s' if out.scalar else 'v'}_{alu[arg]}_{dtype_to_rdnatype[out.dtype]} {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
|
else:
|
|
alu_arg = alu[arg]
|
|
if arg == TernaryOps.MULACC and out == vin[2]:
|
|
alu_arg = "fmac"
|
|
vin = vin[0:2]
|
|
if out.dtype == dtypes.float.vec(4):
|
|
for rr in zip(*[x.subregs() if x.dtype == dtypes.float.vec(4) else [x,x,x,x] for x in [out]+vin]):
|
|
ins.append(f"{'s_' if rr[0].scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[rr[0].dtype]} {reg_out(rr[0])}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in rr[1:])}")
|
|
else:
|
|
ins.append(f"{'s_' if out.scalar else 'v_'}{alu_arg}_{dtype_to_rdnatype[out.dtype] if arg != UnaryOps.NOOP else 'b32'}{'_i24' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 and not out.scalar else ''} {reg_out(out)}, {', '.join(reg_in(x) if x.__class__ is Register else str(x) for x in vin)}")
|
|
elif uop == UOps.LOAD:
|
|
if out.scalar:
|
|
# swap arg order
|
|
ins.append(f's_load_b32 {reg_out(out)}, {reg_in(vin[0])}, {reg_in(vin[1])} offset:{arg[0]}')
|
|
else:
|
|
ins.append(f'global_load_{"b128" if out.dtype == dtypes.float.vec(4) else "b32"} {reg_out(out)}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
|
pend_regs.add(out)
|
|
for r in out.subregs(): pend_regs.add(r)
|
|
elif uop == UOps.STORE:
|
|
ins.append(f'global_store_{"b128" if vin[1].dtype == dtypes.float.vec(4) else "b32"} {reg_in(vin[2])}, {reg_in(vin[1])}, {reg_in(vin[0])} offset:{arg[0]}')
|
|
elif uop == UOps.LABEL:
|
|
ins.append(f"{arg}:")
|
|
elif uop == UOps.COND_BRANCH:
|
|
ins.append(f"s_cbranch_scc{'1' if arg[1] else '0'} {arg[0]}")
|
|
elif uop == UOps.CAST:
|
|
if vin[0].dtype == dtypes.bool:
|
|
if out.dtype == dtypes.float32:
|
|
ins.append(f"v_cndmask_b32 {reg_out(out)}, 0.0, 1.0, {reg_in(vin[0])}")
|
|
else:
|
|
raise NotImplementedError(f"cast {vin[0].dtype} -> {out.dtype}")
|
|
else:
|
|
raise NotImplementedError(uop)
|
|
|
|
ins += ['s_sendmsg sendmsg(MSG_DEALLOC_VGPRS)', 's_endpgm', 's_code_end']
|
|
|
|
# dual alu group
|
|
seen = set()
|
|
new_ins = []
|
|
for i,tins in enumerate(ins):
|
|
if tins in seen: continue
|
|
if tins.startswith("v_fmac_f32"):
|
|
for gins in reversed(ins[i+1:]):
|
|
if gins in seen: continue
|
|
if gins.startswith("v_fmac_f32"):
|
|
r0 = [int(x[1:].strip(',')) for x in tins.split(" ")[1:]]
|
|
r1 = [int(x[1:].strip(',')) for x in gins.split(" ")[1:]]
|
|
if r0[0]%2 == r1[0]%2: continue
|
|
if r0[1]%2 == r1[1]%2: continue
|
|
if r0[2]%2 == r1[2]%2: continue
|
|
new_ins.append(tins.replace("v_", "v_dual_")+" :: " + gins.replace("v_", "v_dual_"))
|
|
seen.add(tins)
|
|
seen.add(gins)
|
|
break
|
|
if tins not in seen:
|
|
new_ins.append(tins)
|
|
ins = new_ins
|
|
|
|
return 'code', self.assemble(args, ins, v_cnt, s_cnt)
|
|
|
|
def assemble(self, args, ins, v_cnt, s_cnt):
|
|
kernel_desc = {'.amdhsa_group_segment_fixed_size': 0, '.amdhsa_private_segment_fixed_size': 0, '.amdhsa_kernarg_size': 0,
|
|
'.amdhsa_next_free_vgpr': v_cnt, # this matters!
|
|
'.amdhsa_reserve_vcc': 0, '.amdhsa_reserve_xnack_mask': 0,
|
|
'.amdhsa_next_free_sgpr': s_cnt,
|
|
'.amdhsa_float_round_mode_32': 0, '.amdhsa_float_round_mode_16_64': 0, '.amdhsa_float_denorm_mode_32': 3, '.amdhsa_float_denorm_mode_16_64': 3, '.amdhsa_dx10_clamp': 1, '.amdhsa_ieee_mode': 1,
|
|
'.amdhsa_fp16_overflow': 0, '.amdhsa_workgroup_processor_mode': 1, '.amdhsa_memory_ordered': 1, '.amdhsa_forward_progress': 0, '.amdhsa_enable_private_segment': 0,
|
|
'.amdhsa_system_sgpr_workgroup_id_x': 1, '.amdhsa_system_sgpr_workgroup_id_y': 1, '.amdhsa_system_sgpr_workgroup_id_z': 1,
|
|
'.amdhsa_system_sgpr_workgroup_info': 0, '.amdhsa_system_vgpr_workitem_id': 2, # is amdhsa_system_vgpr_workitem_id real?
|
|
'.amdhsa_exception_fp_ieee_invalid_op': 0, '.amdhsa_exception_fp_denorm_src': 0, '.amdhsa_exception_fp_ieee_div_zero': 0, '.amdhsa_exception_fp_ieee_overflow': 0, '.amdhsa_exception_fp_ieee_underflow': 0,
|
|
'.amdhsa_exception_fp_ieee_inexact': 0, '.amdhsa_exception_int_div_zero': 0, '.amdhsa_user_sgpr_dispatch_ptr': 0, '.amdhsa_user_sgpr_queue_ptr': 0, '.amdhsa_user_sgpr_kernarg_segment_ptr': 1,
|
|
'.amdhsa_user_sgpr_dispatch_id': 0, '.amdhsa_user_sgpr_private_segment_size': 0, '.amdhsa_wavefront_size32': 1, '.amdhsa_uses_dynamic_stack': 0}
|
|
|
|
metadata = {'amdhsa.kernels': [{'.args': args,
|
|
'.group_segment_fixed_size': 0, '.kernarg_segment_align': 8, '.kernarg_segment_size': args[-1][".offset"] + args[-1][".size"],
|
|
'.language': 'OpenCL C', '.language_version': [1, 2], '.max_flat_workgroup_size': 256,
|
|
'.name': 'code', '.private_segment_fixed_size': 0, '.sgpr_count': s_cnt, '.sgpr_spill_count': 0,
|
|
'.symbol': 'code.kd', '.uses_dynamic_stack': False, '.vgpr_count': v_cnt, '.vgpr_spill_count': 0,
|
|
'.wavefront_size': 32}],
|
|
'amdhsa.target': 'amdgcn-amd-amdhsa--gfx1100', 'amdhsa.version': [1, 2]}
|
|
|
|
code = boilerplate_start + "\n" + '\n'.join("%s %d" % x for x in kernel_desc.items()) + "\n" + code_start + '\n'.join(ins) + "\n.amdgpu_metadata\n" + yaml.dump(metadata) + ".end_amdgpu_metadata"
|
|
obj = early_exec(([ROCM_LLVM_PATH / "llvm-mc", '--arch=amdgcn', '--mcpu=gfx1100', '--triple=amdgcn-amd-amdhsa', '--filetype=obj', '-'], code.encode("utf-8")))
|
|
asm = early_exec(([ROCM_LLVM_PATH / "ld.lld", "/dev/stdin", "-o", "/dev/stdout", "--pie"], obj))
|
|
return asm
|