mirror of https://github.com/commaai/tinygrad.git
Christopherm99 ptx (#3139)
* get basic ptx impl working
* test ops passing
* mypy
* dont hardcode target
* more walrus
* ptx in ci
* bool cast and f16 load/store
* weird numpy bug and f16 cast tolerance
* cast half to bool
* fix 1 byte load/store
* disable half for ptx
* fix args and enable xid
* fix non-ptr args
* allow bitcast
* mypy
* cleanups
* midcast use allclose
* add xor
* Revert "disable half for ptx". This reverts commit 73391c05fde5f7811293f60d994417d97ab20613.
* enable float16
* mypy
* no more crashing in ci
* fix ci
* minor cleanups
* use new fn for ptx compiler
* no diskcache in ptx compile
* use rn instead of rz
* save some lines
* new DEFINE_GLOBAL syntax
* line length
* new llvm
* cmpeq
* minor fix
* cast in mulacc
* update test_recursive_add to check line count
* mypy
* remove llvmir.py
* fix bool const
* wip
* cleanups
* working
* llvm in separate pr
* cleanups
* more cleanups
* fix ci
* use in_features directly in nn.Linear.__init__ bound check (#3050)
* use in_features directly in nn.Linear.__init__ bound check: get rid of the unnecessary check of isinstance int
* that is always int
* long lines
* Device._buffers -> Device._devices (#3052): backend devices used to be called buffers
* make Embedding device aware for multigpu (#3051)
* make Embedding device aware for multigpu
* split line instead of ignore because that's cheating
* add test incomplete
* add test complete
* remove comment
* fix white space
* remove nn.Embedding
* remove unused reciprocal (#3053)
* remove unused reciprocal
* comment
* unit tests for Device.canonicalize (#3055)
* add multigpu test for RMSNorm (#3056)
* need all gather
* add two multigpu test scenarios for RMSNorm
* No extra vars call (#3054)
* remove unused reciprocal
* comment
* remove unneeded call to vars
* free speedup
* explicit lazybuffer caching (#3058)
* hotfix: remove useless slow assert from ShapeTracker
* Speed tweaks (#3059)
* base doesn't have to be a function
* no double fetch
* pop, don't check
* make the gc happy
* avoid hasattr
* cache canonicalize
* remove assert, faster base
* don't redefine that every time
* fix gpt2 attention with start_pos = 0 (#3061)
* fix gpt2 attention with start_pos size 1: test cases taken from ll_transformer branch
* fix interpreted
* Tensor.cat with 0 shape tensors (#3062)
* Tensor.cat with 0 shape tensors: supported both 0 in cat axis (for a subset of input), or 0 in non-cat axis (all needs to be 0)
* no shp
* test scaled dot product attention (#3063)
* add test
* add initial test for scaled dot product attention
* test pass for scaled dot product attention
* cached size (#3060)
* cached size
* simplify simplify
* 0 doesn't have base
* fix test
* cleaner cache
* hmm, metal is flaky on this...might be real(ish) but useless as test
* short circuit reshape/expand properly
* better reshape bypass
* hotfix: use is for enum compare
* hotfix: use is for enum compare, a few more
* speedtweaks3: apply shouldn't use the tensor constructor (#3065)
* speedtweaks3: apply shouldn't use the tensor constructor
* replace 0 size with CONST, not 0 in shape
* update gh actions (#3033)
* update checkout actions
* update upload artifact
* update setup python
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
* unbind view or shapetracker also returns var_val (#3067)
* unbind view or shapetracker also returns var_val: 4% faster for llama compile time
* one line less
* unbound_views
* hotfix: examples/transformer.py
* jit autorealizes output (#3069)
* early gate the graph (#3070)
* simpler idxs_to_idx (#3071)
* filter_strides -> canonicalize_strides (#3072)
* fix onehot and jit in examples/transformer (#3073): trained to 0.999 in < 6 seconds on M1 Max consistently
* better test demonstration (#3077)
* a better test demonstration
* fix white space
* Tensor.expand resolves the new_shape before shortcut return (#3078): similar to how reshape is done. also updated shrink shortcut criteria to read similar to pad
* minor cleanups of lazy.py (#3080)
* wmma: clean up device specific tensor core code (#3081)
* mem_estimate is always int, not symbolic (#3083)
* mem_estimate is always int, not symbolic: op_estimate can be symbolic, but mem_estimate is always int, thus we don't need to sym_infer it. fixed some long lines too. update_stats is a very big function
* operator does not need underscores
* cat works (#3086)
* hotfix disable flaky mac runner wino cifar (#3087)
* remove the third merging state in view._merge_dims (#3085): no logic depends on state == 0 or state == 2
* minor cleanup of View.reshape (#3088)
* minor cleanup of View.reshape: removed some redundant logic
* new_strides
* revert that
* use BEAM=2 instead of BEAM=4 in cuda ci gpt2 (#3089): BEAM=2 is faster and less search time. investigating why BEAM2+BEAM4 is slower than BEAM2 alone
* use device from LinearizerOptions in kernel search (#3090)
* use device from LinearizerOptions in kernel search: removed all Device.DEFAULT in search.py
* pass device string for parallel pickle
* device for interpreted backends in LinearizerOptions
* update jit type annotation post lazy rewrite (#3091)
* add multigpu support for llama attention (#3064)
* add llama attention test for multigpu
* test fails
* kv cache trying to shrink on sharded axis
* mask None works for scale dot product
* kv cache seems to be working but scale dot product breaks
* scaled dot product works, but the last linear layer failed
* running into the reshape case where it could be wrong for multigpu
* making sure it was the reshape
* adding contiguous doesn't solve
* need to shard more properly
* remove reshape test
* minor adjustment to scale dot product attention test
* weights are sharded wrong
* continue fix new weight sharding
* clean up
* fix attention when start_pos is 0
* remove print
* add TODOs for the best multigpu interface
* bugfix do not reset shapetracker of 0 size lazybuffer (#3096): it might be coming from an expand, and resetting results in an incorrect stride. caught by interpreted backend
* One hot in tensor.py (#3093)
* onehot in Tensor.py
* one_hot tests
* works for all shapes, not just 1
* pylint
* not a static method
* moved around, num_classes mandatory
* pylint
* pylint
* space & moving
* formatting
* moved tests
* fix broadcasted logic if there's 0 in shapes (#3097)
* fix broadcasted logic if there's 0 in shapes: should always expand into 0, not the other way around. fixed matmul with 0 in input shapes. for forwards for now though, backward is more involved and would need to change 0 size shortcuts
* fix tests
* replace with tensor op (#3099)
* fix gpt2 with empty prompt (#3100): logits would be empty so need to replace that with ones before sampling, also cannot reshape with -1 when there's 0 in other axes
* Revert "fix gpt2 with empty prompt" (#3101)
* fix gpt2 with empty prompt take 2 (#3102): logits would be empty so need to replace that with ones before sampling, also cannot reshape with -1 when there's 0 in other axes
* wmma: enable METAL half tensor cores and clean up cstyle (#3095)
* wmma: enable METAL half tensor cores and clean up cstyle
* revert simple_matmul rand changes and break line in tensor
* added metal fp16->fp32 tensor core
* add half @ half to mac benchmark (#3103)
* flag to profile mixtral - 1.7 tok/s now (#3104)
* update NumNode.__hash__ to be hash(self.b) (#3105): with this, `a:=NumNode(x) == b` implies `hash(a) == hash(b)`
* catch runtime error in search._time_program (#3106): return inf if search encountered runtime errors.
* no exceptions in __del__ when module creation has failed in hip/cuda (#3107)
* failed test case due to cast resets shapetracker (#3109): cast implicitly resets shapetracker and makes it contiguous (for disk tensor), which fails for Interpreted backend if inputs contain non-contiguous st.
* cleanup ops_disk type annotation and redundant str cast (#3110)
* minor cleanup of test_disk_tensor (#3112)
* add Tensor.var (#3114): also updated MeanVarianceNormalization and made test_ops test tensors of var and std smaller
* move sample inside jit for beautiful_mnist (#3115): also removed .realize() for jit functions since jit does it automatically now. a little more beautiful
* minor cleanups of onnx_ops (#3116)
* fix conversation: llama generates token not prob now (#3120)
* add device options for tests in multigpu (#3121)
* make DType a dataclass (#3111)
* remove np from DType
* convert to dataclass
* remove dunder hash, eq, ne overrides from ImageDType
* is dataclass required for PtrDType?
* fix GPU tests
* reduce lines
* revert changes to np
* minor cleanup
* hotfix: ptrdtype compare was broken
* move fromcpu out of lazy.py (#3122)
* move fromcpu out of lazy.py
* fix abstractions2
* remove numpy from device (#3123)
* remove numpy from device
* fix tests
* np item
* cleanups
* simplify with as_buffer
* no toCPU
* tinygradic
* cast to scalar
* remove numpy from ops_torch (#3124): updated mnist test to cast label to int8 and avoid hacking around the cast issue of torch uint8
* Fix backward fn for `<` and `==` (#3037)
* fix no grad fn for < and ==
* remove 2 line breaks
* Remove deprecated autograd variable
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
* separate try except blocks in onnx2torch in model benchmark (#3126): exceptions can be raised from either model conversion or an individual backend failing. openpilot on torch mps works, but does not work with torch cpu. separate the exception blocks so that the benchmark can include torch mps for openpilot.
* update env_vars.md (#3127): mostly removed deprecated ones. not clear how to maintain this especially for extra/examples
* update test_ptr_ne (#3130)
* remove np from metal graph (#3129)
* dtype fmt (#3132)
* dtype fmt
* three ways to access
* fix off-by-one error in st_equal (#3131)
* fix off by one error
* whitespace
* no numpy (#3134)
* fast resnet eval (#3135)
* fast resnet eval
* fix HIP multidevice graph
* neater expression for devices
* lines
* add decorator test
* remove LLVMOPT
* move ptx
* Update ops_cuda.py
---------
Co-authored-by: Christopher Milan <chrismilan@ucla.edu>
Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: Yixiang Gao <yixiangg310573@gmail.com>
Co-authored-by: jxdv <virgoj@protonmail.com>
Co-authored-by: Francis Lam <flam@alum.mit.edu>
Co-authored-by: SnakeOnex <sheeproman@gmail.com>
Co-authored-by: nimlgen <138685161+nimlgen@users.noreply.github.com>
Co-authored-by: Jyotirmaya Mahanta <jyotirmaya.mahanta@gmail.com>
Co-authored-by: Guy Leroy <g.m.leroy@outlook.com>
Co-authored-by: Paul Gustafson <paul.gustafson@theambrusgroup.com>
This commit is contained in:
parent 1ee11411f1
commit ca0beeef38
@@ -0,0 +1,240 @@
from typing import Callable, DefaultDict, Dict, List, Union, NamedTuple
import functools, struct
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
def trunc_float(x, fmt): return struct.unpack(fmt, struct.pack(fmt, x))[0]
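A quick sanity check of the hex helpers, outside the diff: struct.pack gives the little-endian bytes, so reversing them yields the big-endian hex digits that PTX float literals use (render_const below prefixes them with 0f, or 0d for doubles).

    import struct
    def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f", x)[::-1])
    assert float_to_hex(1.0) == "3F800000"   # rendered as the PTX literal 0f3F800000
    assert float_to_hex(-2.5) == "C0200000"  # rendered as the PTX literal 0fC0200000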
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)

class AssemblyLanguage(NamedTuple):
  kernel_prefix: str = ""
  barrier: str = ""
  load_global: bool = False
  label_prefix: str = ""
  gid: List[str] = []
  gdim: List[str] = []
  lid: List[str] = []
  const_requires_mov: List[DType] = [] # list of dtypes for which creating a const requires a move
  asm_for_op: Dict[Op, Callable[...,str]] = {}
  types: Dict[DType, str] = INVERSE_DTYPES_DICT

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]: raise NotImplementedError()
  def render_local(self, dest, name, size, dtype) -> List[str]: raise NotImplementedError()

  def render_loop(self, idx, start, label, acc=None) -> List[str]: raise NotImplementedError()
  def render_bra(self, b1, pred=None, b2=None) -> List[str]: raise NotImplementedError()
  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]: raise NotImplementedError()
  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]: raise NotImplementedError()
  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]: raise NotImplementedError()
  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: raise NotImplementedError()

  def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
  local_size: List[int] = []
  kernel:List[str] = []
  bufs = []

  def kk(*s: str): kernel.append("\n".join(s))

  c: DefaultDict[str, int] = defaultdict(int)
  r: Dict[UOp, str] = {}
  def ssa(u, prefix="t", dtype=None) -> str:
    nonlocal c, r
    prefix += f"_{dtype if dtype else lang.types[u.dtype]}_"
    c[prefix] += 1
    if u: r[u] = f"%{prefix}{c[prefix]-1}"
    return f"%{prefix}{c[prefix]-1}"

  c_label: DefaultDict[str, int] = defaultdict(int)
  r_label: Dict[UOp, str] = {}
  def ssa_label(u, prefix):
    nonlocal c_label, r_label
    c_label[prefix] += 1
    r_label[u] = f"{lang.label_prefix}{prefix}_{c_label[prefix]-1}"
    return r_label[u]

  def const(x:Union[float,int,bool], dtype, mov=False):
    if mov or dtype in lang.const_requires_mov:
      kk(*lang.render_const(x, dtype, mov=(out:=ssa(None, 'const', lang.types[dtype]))))
      return out
    return lang.render_const(x, dtype)

  def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
    if atype == dtype:
      if u: r[u] = a
      return a
    kk(*lang.render_cast((ret:=ssa(u, 'cast', lang.types[dtype])), a, dtype, atype, bitcast))
    return ret

  for u in uops:
    uop,dtype,vin,args = u.uop,u.dtype,u.vin,u.arg
    if uop == UOps.IF:
      assert vin[0].dtype is not None
      kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
    elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
    elif uop == UOps.END:
      if vin[0].uop == UOps.LOOP:
        kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
           lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
        kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
      else: kk(f"{r_label[vin[0]]}:")
    elif uop == UOps.STORE:
      assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
      kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
      if len(vin) > 3:
        assert vin[3].dtype is not None
        pred = cast(r[vin[3]], dtypes.bool, vin[3].dtype, pred=True)
      kk(*lang.render_store(loc, r[vin[2]], vin[0].dtype, gate=pred if len(vin)>3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
    else:
      assert dtype is not None, f"None dtype for uop {uop}"
      if uop == UOps.LOOP: kk(*lang.render_loop(ssa(u, 'ridx'), r[vin[0]], ssa_label(u, 'loop')))
      elif uop == UOps.ALU:
        assert vin[0].dtype is not None
        if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
          regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
          dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
          kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
        elif args == TernaryOps.MULACC:
          assert vin[1].dtype is not None
          kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype]))
        else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
      elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
      elif uop == UOps.SPECIAL:
        if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
                                 f"mov.u32 {(lid:=ssa(None,'tmp','u32'))}, {lang.lid[args[0]]};",
                                 f"mad.lo.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, %{args[1]}, {gdim}, {lid};")
        else: kk(f"mov.u32 {(tmp:=ssa(None, 'tmp', 'u32'))}, {(lang.gid if args[1][0] == 'g' else lang.lid)[args[0]]};")
        kk(*lang.render_cast(f"%{args[1]}", tmp, dtypes.uint, dtypes.int))
        if args[1][0] == "l": local_size.append(args[2])
        r[u] = "%" + args[1]
        kernel = [f".reg .u32 %{args[1]};"] + kernel
      elif uop == UOps.CONST: r[u] = const(args, dtype, mov=True)
      elif uop == UOps.LOAD:
        assert vin[1].dtype is not None
        val = ssa(u, 'val')
        if len(vin) > 3:
          assert vin[2].dtype is not None
          pred = cast(r[vin[2]], dtypes.bool, vin[2].dtype, pred=True)
          off = cast(r[vin[1]], dtypes.uint, vin[1].dtype)
        kk(*lang.render_gep(loc:=ssa(None,'loc',lang.types[dtypes.ulong]), r[vin[0]], off if len(vin)>3 else cast(r[vin[1]],
                            dtypes.uint, vin[1].dtype), dtype),
           *lang.render_load(loc, val, dtype, gate=pred if len(vin) > 3 else None,
                             alt=r[vin[3]] if len(vin) > 3 else None, ss='.shared' if vin[0].uop == UOps.DEFINE_LOCAL else ''))
      elif uop == UOps.PHI:
        kk(f"mov.b{lang.types[dtype][1:]} {r[vin[0]]}, {r[vin[1]]};")
        r[u] = r[vin[0]]
      elif uop == UOps.CAST:
        assert vin[0].dtype is not None
        cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
      elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
      elif uop == UOps.DEFINE_GLOBAL:
        bufs.append((args, dtype))
        r[u] = f"%{args}"
        if lang.load_global:
          dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
          kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
      else: raise NotImplementedError(f"no code for {uop}")

  return lang.render_kernel(kernel, function_name, bufs, c.items())
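A standalone sketch of the register-naming scheme above, mirroring the ssa closure (not part of the diff): each SSA value gets a name built from a prefix, the PTX type name, and a per-prefix counter, and render_kernel later turns the counts in c.items() into parameterized .reg declarations.

    from collections import defaultdict
    c = defaultdict(int)
    def ssa(prefix, typ):
      prefix += f"_{typ}_"
      c[prefix] += 1
      return f"%{prefix}{c[prefix]-1}"
    assert [ssa("alu", "f32") for _ in range(3)] == ["%alu_f32_0", "%alu_f32_1", "%alu_f32_2"]
    # render_kernel then emits: .reg .f32 %alu_f32_<3>;   (declares %alu_f32_0 .. %alu_f32_2)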
class PTXLanguage(AssemblyLanguage):
  kernel_prefix = """.version 7.8
.target TARGET
.address_size 64
.visible .entry"""
  barrier = "bar.sync\t0;"
  has_pred = True
  load_global = True
  label_prefix = "$"
  gid = [f'%ctaid.{chr(120+i)}' for i in range(3)]
  gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
  lid = [f'%tid.{chr(120+i)}' for i in range(3)]
  asm_for_op = {
    UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if dt == dtypes.bool else f"neg.{name} {d}, {a};",
    UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
    UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
    UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
    BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
    BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
    BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
    BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
    BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
    BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
    BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
    TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') +
                                                f".{name} {d}, {a}, {b}, {c};"),
    TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};"
  }
  supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
                   TernaryOps.MULACC, TernaryOps.WHERE]
  types = {
    dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
    dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
    dtypes.float16: "f16", dtypes.float32: "f32", dtypes.float64: "f64",
    dtypes.bool: "pred"
  }

  const_requires_mov = [dtypes.half, dtypes.bool]

  def render_const(self, x:Union[float,int,bool], dtype, mov=None) -> Union[List[str], str]:
    if dtypes.is_float(dtype): val = f"0f{float_to_hex(x)}" if dtype != dtypes.float64 else f"0d{double_to_hex(x)}"
    else: val = str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "")
    if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"]
    if dtype == dtypes.half: return [f".reg .f32 {mov}_tmp;", f"mov.f32 {mov}_tmp, {val};", f"cvt.rn.f16.f32 {mov}, {mov}_tmp;"]
    return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val

  def render_local(self, dest, name, size, dtype) -> List[str]:
    return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"]

  def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"]

  def render_bra(self, b1, pred=None, b2=None) -> List[str]: return [f"@{pred} bra {b1};", f"@!{pred} bra {b2};"] if pred else [f"bra {b1};"]

  def render_gep(self, loc, base, offset, dtype, gate=None) -> List[str]:
    # this cast is only required because of ocelot
    if "s32" in offset:
      return [f".reg .u32 {offset}_cast;", f"cvt.u32.s32 {offset}_cast, {offset};", f"mad.wide.u32 {loc}, {offset}_cast, {dtype.itemsize}, {base};"]
    else: return [f"mad.wide.u32 {loc}, {offset}, {dtype.itemsize}, {base};"]

  def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]

  def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
    ret = []
    if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;")
    if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;")
    if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
                         f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"])
    else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];")
    if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;")
    if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};")
    return ret

  def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
    if dtype == dtypes.bool: return [f".reg .s16 {val}_cast;", *self.render_cast(f"{val}_cast", val, dtypes.int16, dtype),
                                     (f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val}_cast;"]
    return [(f"@{gate} " if gate else "") + f"st{ss}.{self.mem_type(dtype)} [{loc}], {val};"]

  def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]:
    if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"]
    if atype == dtypes.bool: return [f"selp.{self.types[dtype]} {d}, {self.render_const(1, dtype)}, {self.render_const(0, dtype)}, {a};"]
    if dtype == dtypes.bool: return [f"setp.ne.{self.types[atype]} {d}, {a}, {self.render_const(0, atype)};"]
    rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else
           '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '')
    return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"]

  def render_kernel(self, kernel, function_name, bufs, regs) -> str:
    kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"]
    def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1)
    return (f"{self.kernel_prefix} {function_name}(\n\t" +
            ',\n\t'.join([f".param .{'u64' if dtype.__class__ == PtrDType else self.types[dtype]} {name}" for name,dtype in bufs]) + "\n)\n{\n" +
            '\n'.join([fmt(line) for op in kernel for line in op.splitlines()]) +
            "\n}")

PTXRenderer = functools.partial(uops_to_asm, PTXLanguage())
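PTXRenderer is uops_to_asm with a PTXLanguage instance bound as its first argument, so the whole renderer is a plain function from a kernel name plus a linearized uop list to PTX text. A minimal usage sketch (the import path, kernel name, and `uops` list are assumptions for illustration and are not shown on this page):

    from tinygrad.renderer.assembly import PTXRenderer  # assumed module path
    ptx_src = PTXRenderer("E_4", uops)                  # same as uops_to_asm(PTXLanguage(), "E_4", uops)
    assert ptx_src.splitlines()[0] == ".version 7.8"    # from kernel_prefix; ".target TARGET" still needs patching
    # an individual float32 add uop renders through asm_for_op as, e.g.:
    #   add.f32 %alu_f32_0, %val_f32_0, %val_f32_1;

The getenv("PTX") checks in the test hunks below suggest a PTX environment variable routes the CUDA backend through this renderer in CI, and the ".target TARGET" placeholder is presumably rewritten to the device's real architecture before compilation ("dont hardcode target" in the commit message).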
@@ -40,7 +40,7 @@ def _assert_eq(tensor:Tensor, target_dtype:DType, target):
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target)
+    np.testing.assert_allclose(tensor.numpy(), target, rtol=1e-3 if target_dtype == dtypes.float16 else 1e-7)
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
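The looser rtol for half is consistent with float16 keeping only 10 fraction bits, so relative rounding error alone is around 2**-11 (about 5e-4); a tiny check, outside the diff:

    import numpy as np
    assert abs(float(np.float16(0.1)) - 0.1) / 0.1 < 1e-3  # 0.1 is stored as 0.0999755859375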
@@ -63,7 +63,7 @@ def universal_test_unary(a, dtype, op):
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
   if dtype in dtypes_float:
-    atol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-3
+    atol = 2 if (Device.DEFAULT == "METAL" or getenv("PTX")) and op[0] == Tensor.sin else 1e-3
     rtol = 2 if Device.DEFAULT == "METAL" and op[0] == Tensor.sin else 1e-4 if dtype == dtypes.float32 else 1e-2
     # exp and log and sin are approximations (in METAL, the default fast-math versions are less precise)
     np.testing.assert_allclose(tensor_value, numpy_value, atol=atol, rtol=rtol)
@@ -84,7 +84,7 @@ def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
   an, bn, cn = np.array([a]).astype(d1.np), np.array([b]).astype(d1.np), np.array([c]).astype(d2.np)
   tensor_value = op2[0](op1[0](at, bt).cast(d2), ct).numpy()
   numpy_value = op2[1](op1[1](an, bn).astype(d2.np), cn)
-  np.testing.assert_almost_equal(tensor_value, numpy_value)
+  np.testing.assert_allclose(tensor_value, numpy_value, rtol=1e-6 if getenv("PTX") else 1e-7)

 class TestDTypeALU(unittest.TestCase):
   @unittest.skipIf(OSX and Device.DEFAULT in {"GPU", "METAL"}, "no float64 on OSX GPU")
@@ -30,7 +30,7 @@ class TestFusionOp(unittest.TestCase):
     sched = create_schedule([a.lazydata], None)
     ji = lower_schedule_item(sched[-1])
     self.assertLess(time.perf_counter()-st, 1.0)
-    assert isinstance(ji, InterpretedASTRunner) or len(ji.prg) < 5000
+    assert isinstance(ji, InterpretedASTRunner) or len(ji.prg.splitlines()) < 250

   def test_recursive_add_cmp(self):
     st = time.perf_counter()
@@ -34,7 +34,7 @@ def compile_cuda(prg) -> bytes: return compile_cuda_style(prg, [f'--gpu-architec
 class CUDAProgram:
   def __init__(self, device:CUDADevice, name:str, lib:bytes):
     self.device, self.name, self.lib = device, name, lib
-    if DEBUG >= 5: print(pretty_ptx(lib.decode('utf-8')))
+    if DEBUG >= 5: print("\n".join([f"{i+1:>3} {line}" for i, line in enumerate(pretty_ptx(lib.decode('utf-8')).split("\n"))]))
     if DEBUG >= 6:
       try:
         fn = (Path(tempfile.gettempdir()) / f"tinycuda_{hashlib.md5(lib).hexdigest()}").as_posix()
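The new DEBUG >= 5 path prefixes each line of the pretty-printed PTX with a right-aligned line number, so the dump reads roughly like this (illustrative values, assuming TARGET has already been patched to a real architecture):

      1 .version 7.8
      2 .target sm_86
      3 .address_size 64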