mirror of https://github.com/commaai/tinygrad.git — 177 lines, 9.2 KiB, Python
import struct
|
|
from platform import system
|
|
from typing import Tuple, Dict, List, Optional
|
|
from tinygrad import dtypes
|
|
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps
|
|
from tinygrad.codegen.kernel import UOps, UOp
|
|
from tinygrad.helpers import CI
|
|
from tinygrad.codegen.assembly import uops_to_asmstyle, AssemblyLanguage
|
|
|
|
def float_to_hex(x):
  """Return the IEEE-754 single-precision bits of x as 8 uppercase hex digits, most-significant byte first."""
  swapped = struct.pack("f", x)[::-1]  # native pack, then reverse bytes
  return "".join("%02X" % b for b in swapped)
|
|
def compute_offsets(total):
  """Split `total` into a list of stack-adjustment chunks of at most 4096.

  ARM64 immediate add/sub encodes at most #4095-ish offsets, so a large
  stack reservation is emitted as repeated 4096-byte adjustments plus one
  final remainder chunk (omitted when total is an exact multiple).
  """
  full_chunks, leftover = divmod(total, 4096)
  offsets = [4096] * full_chunks
  if leftover:
    offsets.append(leftover)
  return offsets
|
|
|
|
#NOTE: Darwin needs names to start with a "_"
def get_name(name):
  """Return `name` decorated for the host platform's symbol convention (leading '_' on Darwin)."""
  prefix = '_' if system() == 'Darwin' else ''
  return prefix + name
|
|
|
|
class ARM64Language(AssemblyLanguage): pass  # no ARM64-specific overrides; the generic AssemblyLanguage behavior is used as-is
|
|
|
|
def specialize_to_arm64(fn_nm, asm):
  """Lower the assembly-style uop list `asm` into ARM64 (AArch64) assembly source.

  asm: iterable of (uop, out, vin, arg) tuples (as produced by uops_to_asmstyle).
  fn_nm: kernel name; decorated with get_name() for the platform symbol convention.
  Returns the complete assembly text for one function: prologue (stack
  reservation in <=4096-byte steps), the translated instructions, epilogue, ret.
  Performs a simple linear-scan register allocation with spill-to-stack.
  """
  var_size = 16  # bytes of stack reserved; starts at 16 because mov_imm uses [sp, 16] as a scratch slot for float materialization
  prev_uop:Optional[UOps] = None  # tracked so COND_BRANCH can see whether a LOAD preceded it
  ins = []  # accumulated instruction strings, joined at the end
  # Integer register pool x0-x11 (reversed so .pop() hands out x0 first).
  # x12-x17 are deliberately excluded: used below as temporaries/scratch.
  x_regs = ['x' + str(i) for i in reversed(range(12))]
  # Float pool: s3-s7 and s16-s31. s0-s2 are temp regs; s8-s15 are skipped
  # (presumably to avoid AAPCS64 callee-saved v8-v15 — TODO confirm).
  s_regs = ['s' + str(i) for i in reversed(range(3,32)) if i <= 7 or i >= 16]
  # dtype -> register-name prefix for typed loads/stores (w = 32-bit GPR view, x = 64-bit, s/h/d = FP views)
  type_to_reg = {dtypes.double: "d", dtypes.half: 'h', dtypes.float32: 's', dtypes.bool: 'w', dtypes.int8:'w', dtypes.int32: 'w', dtypes.int64: 'x', dtypes.uint8:'w', dtypes.uint32: 'w', dtypes.uint64: 'x'}
  # ALU op -> mnemonic. MOD maps to "" because it is expanded into udiv+msub below;
  # CMPLT uses subs to set flags; the transcendental unary ops become libm calls.
  alu = {BinaryOps.ADD: "add", BinaryOps.SUB: "sub", BinaryOps.MUL: "mul", BinaryOps.DIV: "div", BinaryOps.MAX: "max",
         BinaryOps.MOD: "", BinaryOps.CMPLT: "subs",
         UnaryOps.NOOP: "mov", UnaryOps.NEG: "neg",
         UnaryOps.SIN:'bl ' + get_name('sinf'), UnaryOps.LOG2: 'bl ' + get_name("log2f"), UnaryOps.EXP2: 'bl ' + get_name("exp2f"), UnaryOps.SQRT: 'bl ' + get_name("sqrtf"),
         TernaryOps.MULACC: "madd", TernaryOps.WHERE: "fcsel"}

  def mov_imm(value, reg):
    # Manually move value into reg if value can't fit
    # Ints > 16 bits: build in w15 with movz/movk, then sign-extend into reg.
    if value.__class__ is not float and abs(value) > abs(65535):
      ins.append(f"movz w15, #{value & 0xffff}")
      ins.append(f"movk w15, #{(value >> 16) & 0xffff}, lsl #16")
      ins.append(f"sxtw {reg}, w15")
    elif reg[0] == 's':
      # Floats: materialize the IEEE-754 bit pattern in x15, bounce it
      # through the scratch stack slot at [sp, 16] into the FP register.
      ins.append(f"movz x15, 0x{float_to_hex(value)[4:]}")
      ins.append(f"movk x15, 0x{float_to_hex(value)[:4]}, lsl #16")
      ins.append("str x15, [sp, 16]")
      ins.append(f"ldr {reg}, [sp, 16]")
    else:
      # Small ints fit in a plain mov immediate.
      ins.append(f"mov {reg}, #{value}")

  # Get variables intervals
  # live_range[name] = [first_use_index, last_use_index] over the asm list.
  live_range:Dict[str, List[int]] = {}
  for i, (uop, out, vin, arg) in enumerate(asm):
    for var in ([v for v in [out] + vin if v is not None and v.__class__ is not int]):
      live_range[var.nm] = [i,i] if var.nm not in live_range else [live_range[var.nm][0], i]

  mem_vars:Dict[str, int] = {}  # spilled variable name -> stack offset
  rtor:Dict[str, str] = {}      # variable name -> currently assigned register
  def allocate_regs(mvars):
    """Assign a register to each unassigned non-int variable in mvars, spilling to stack when the pool is empty."""
    nonlocal var_size
    for v in [v for v in mvars if v is not None and v.__class__ is not int and v.nm not in rtor]:
      available_regs = s_regs if dtypes.is_float(v[1]) else x_regs
      #NOTE: Very simple spill, everything that don't fit in regs goes to mem
      if not available_regs:
        # ARM needs the stack 16-byte aligned
        var_size += 16
        # NOTE(review): this keys the temp-register choice on the enclosing
        # loop's `out[1]` rather than `v[1]` — looks like it should be v; a
        # float vin spilled while out is int would get an x register. Left
        # as-is to preserve behavior; confirm upstream.
        available_regs.append('s0' if dtypes.is_float(out[1]) else 'x12')
        mem_vars[v.nm] = var_size
      rtor[v.nm] = available_regs.pop()

  temp_floats = ['s0', 's1', 's2']
  temp_ints = ['x12', 'x13', 'x16']
  for i, (uop, out, vin, arg) in enumerate(asm):
    # Clear regs out of interval
    for var, reg in list(rtor.items()):
      available_regs = s_regs if reg[0] == 's' else x_regs
      # var[1] not in 'B' presumably keeps buffer-argument registers pinned
      # for the whole function (name tag convention from the asm layer) — TODO confirm.
      if var[1] not in 'B' and var not in mem_vars and i > live_range[var][1]:
        available_regs.append(rtor.pop(var))
    # Assign a registers to the variables using live ranges.
    allocate_regs([out] + vin)
    # Assign temp regs to vin and load them before direct use
    # NOTE(review): this inner loop shadows the outer `i`; harmless today
    # because `i` is not read again before the outer loop reassigns it.
    for i, v in enumerate([v for v in vin if v.__class__ is not int and v.nm in mem_vars]):
      rtor[v.nm] = temp_floats[i] if dtypes.is_float(v[1]) else temp_ints[i]
      # ARM64 addressing constraints https://devblogs.microsoft.com/oldnewthing/20220728-00/?p=106912
      ins.append(f"mov x15, {mem_vars[v.nm]}")
      ins.append(f"ldr {rtor[v.nm]}, [sp, x15]")

    if uop == UOps.SPECIAL:
      if arg.startswith('data'):
        # data 8 to n into the stack
        # First 8 buffer args arrive in x0-x7; the rest were spilled by the
        # caller and are reloaded relative to x17 (saved entry sp, see prologue).
        if int(arg[4:]) >= 8:
          ins.append(f"ldr x15, [x17, #{(int(arg[4:]) - 8) * 8}]")
          ins.append(f"mov {rtor[out.nm]}, x15")
      else:
        # Non-data specials are loop counters: zero-init and emit the loop label.
        ins.append(f"mov {rtor[out.nm]}, #0")
        ins.append(f"loop_{arg}:")
    elif uop == UOps.CAST:
      if arg == BinaryOps.CMPLT:
        # Materialize a comparison result as 0.0/1.0 (float) or 0/1 (int)
        # using a conditional select on the flags set by the preceding compare.
        if rtor[out.nm][0] == 's':
          mov_imm(0.0, 's0')
          mov_imm(1.0, 's1')
          ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt")
        if rtor[out.nm][0] == 'x':
          mov_imm(0, 'x14')
          mov_imm(1, 'x15')
          ins.append(f"csel {rtor[out.nm]}, x15, x14, lt")
      else:
        # Generic cast path: sign-extend the 32-bit view of the source register.
        ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}")
    elif uop == UOps.ALU:
      # Integer immediates in the second operand go through x15.
      if len(vin)==2 and vin[1].__class__ is int: mov_imm(vin[1], 'x15')
      if arg == BinaryOps.MUL and out.dtype == dtypes.bool:
        # bool*bool is logical AND (with flags set).
        ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
      elif arg == TernaryOps.WHERE:
        # Compare the condition against zero, then conditional-select.
        ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0")
        ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne")
      elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]:
        #NOTE: Not a real instruction, use to emulate a ext call in unicorn
        if CI: ins.append(f"{alu[arg]} {rtor[out.nm]} {rtor[vin[0].nm]}")
        else:
          # Real libm call: save all live registers around the bl, pass the
          # argument in s0 per the calling convention, read the result from s0.
          save_regs = [k for k in rtor.keys() if k != out.nm and k not in mem_vars]
          ins.append(f"sub sp, sp, #{(len(save_regs))*16}")
          # Save the registers before they are cleared by func call
          for i,k in enumerate(save_regs,1):
            ins.append(f"str {rtor[k]}, [sp, #{16*i}]")
          ins.append("stp x29, x30, [sp, #0]!")
          ins.append("mov x29, sp")
          ins.append(f"fmov s0, {rtor[vin[0].nm]}")
          ins.append(alu[arg])  # the "bl <func>" string from the alu table
          ins.append(f"fmov {rtor[out.nm]}, s0")
          ins.append("mov sp, x29")
          ins.append("ldp x29, x30, [sp], #0")
          for i,k in enumerate(save_regs,1):
            ins.append(f"ldr {rtor[k]}, [sp, #{16*i}]")
          ins.append(f"add sp, sp, #{len(save_regs)*16}")
      elif arg == BinaryOps.CMPLT:
        # Set flags only; a following CAST(CMPLT) turns them into a value.
        ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}")
      elif arg == BinaryOps.MOD:
        # No hardware modulo: out = vin0 - (vin0 / rhs) * rhs via udiv+msub.
        rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm]
        ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}")
        ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}")
      else:
        # Generic 2/3-operand ALU op; prefix 'f' for float ops, 's' for
        # signed integer div (sdiv).
        ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}")
    elif uop == UOps.LOAD:
      if arg.__class__ in (int, float):
        # Constant load.
        mov_imm(arg, rtor[out.nm])
      else:
        #NOTE: if need casting load var in s/h0 or x/w12 temp regs
        # arg is (offset, ?, source_dtype-or-None) — TODO confirm arg[1]'s meaning from the asm layer.
        reg_in = type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[out.nm]
        mov_imm(arg[0], "x15")
        ins.append(f"add x15, {rtor[vin[0].nm]}, x15")
        # ldrsb for sign-extending byte-sized loads.
        ins.append(f"ldr{'sb' if arg[2] is not None and arg[2] in (dtypes.int8, dtypes.uint8, dtypes.bool) else ''} {reg_in}, [x15]")
        # Convert to the output type: fcvt between float widths, scvtf int->float.
        if arg[2] is not None: ins.append(f"{'fcvt' if arg[2] in [dtypes.half, dtypes.double] else 'scvtf'} {rtor[out.nm]}, {reg_in}")
    elif uop == UOps.STORE:
      #NOTE: if need casting load var in s/h0 or x/w12 temp regs
      reg_out = (type_to_reg[arg[2]] + ('0' if dtypes.is_float(arg[2]) else '12') if arg[2] is not None else rtor[vin[1].nm])
      # fcvtzs float->int, plain fcvt between float widths.
      if arg[2] is not None: ins.append(f"fcvt{'zs' if arg[2] not in [dtypes.half, dtypes.double] else '' } {reg_out}, {rtor[vin[1].nm]}")
      ins.append(f"mov x15, #{arg[0]}")
      ins.append(f"str {reg_out}, [{rtor[vin[0].nm]}, x15, lsl #0]")
    elif uop == UOps.COND_BRANCH:
      #TODO: this is a hack it shouldn't always be a cmp before a cond branch?
      if prev_uop == UOps.LOAD:
        ins.append(f"cmp {rtor[vin[0].nm]}, #0")
      # arg = (label, branch_if_true); label's leading char is stripped.
      ins.append(f"b.{'lt' if arg[1] else 'ge'} {arg[0][1:]}")
    elif uop == UOps.LABEL:
      ins.append(f"{arg[1:]}:")
    elif uop == UOps.ENDLOOP:
      # Increment the loop counter, compare against the bound in x15, loop back.
      mov_imm(arg[0], "x15")
      ins.append(f"add {rtor[vin[0].nm]}, {rtor[vin[0].nm]}, #1")
      ins.append(f"cmp {rtor[vin[0].nm]}, x15")
      ins.append(f"b.lt loop_{arg[1]}")
    prev_uop = uop
    # store regs into memory if needed
    if out is not None and out.nm in mem_vars:
      ins.append(f"mov x15, {mem_vars[out.nm]}")
      ins.append(f"str {rtor[out.nm]}, [sp, x15]")
  # Prologue saves entry sp in x17 (used by SPECIAL data>=8 loads), reserves
  # var_size bytes in <=4096 chunks; epilogue releases them and returns.
  return "\n".join([f"//varsize {var_size}",".arch armv8-a",".text", f".global {get_name(fn_nm)}",".p2align 2", f"{get_name(fn_nm)}:", "mov x17, sp"] + [f"sub sp, sp, #{offset}" for offset in compute_offsets(var_size)]+ ins + [f"add sp, sp, #{offset}" for offset in compute_offsets(var_size)] +["ret", "\n"])
|
|
|
|
def uops_to_arm64_asm(fn_nm:str, uops:List[UOp]) -> Tuple[str, List[int], List[int], bool]:
  """Compile uops to ARM64 assembly source for kernel `fn_nm`.

  Returns (asm_source, global_size reversed, local_size reversed, True).
  The size lists are reversed to match the caller's dimension ordering —
  presumably the runtime expects them outermost-first; confirm against callers.
  """
  lang = ARM64Language()
  gsz, lsz = uops_to_asmstyle(lang, fn_nm, uops)
  src = specialize_to_arm64(fn_nm, lang.ins)
  return src, gsz[::-1], lsz[::-1], True