mirror of https://github.com/commaai/tinygrad.git
bring ptx back (#3623)
* bring ptx back * ptx back * fix define var * fix a few bugs * bugfixes * fixes * fix llvm bug * fix test bug
This commit is contained in:
parent
c270d54c32
commit
81baf3eed3
|
@ -346,7 +346,7 @@ jobs:
|
|||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, clang, gpu, cuda, hip] #, triton] #, ptx]
|
||||
backend: [llvm, clang, gpu, cuda, hip, ptx] #, triton]
|
||||
|
||||
name: Tests on (${{ matrix.backend }})
|
||||
runs-on: ubuntu-latest
|
||||
|
|
|
@ -784,8 +784,9 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
|||
assert store_val.dtype == dtypes.float.vec(4) and store_val.uop != UOps.CAST
|
||||
|
||||
def test_grouped_store_locals_and_globals(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \
|
||||
not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
|
||||
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
|
||||
|
||||
x, y = Tensor.rand(128, 128), Tensor.rand(128, 128)
|
||||
out = x@y
|
||||
|
@ -808,8 +809,9 @@ class TestLinearizerUOptimize(unittest.TestCase):
|
|||
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1
|
||||
|
||||
def test_grouped_store_local_only(self):
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared:
|
||||
self.skipTest("Only Compiled uses linearizer with locals and shared")
|
||||
if not Device[Device.DEFAULT].compiler.linearizer_opts.has_local or not Device[Device.DEFAULT].compiler.linearizer_opts.has_shared or \
|
||||
not Device[Device.DEFAULT].compiler.linearizer_opts.supports_float4:
|
||||
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")
|
||||
|
||||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
r = (x@y).relu()
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# ruff: noqa: E501
|
||||
import unittest
|
||||
from tinygrad import dtypes, Device
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.helpers import CI
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.features.search import Opt, OptOps
|
||||
from tinygrad.features.search import time_linearizer, bufs_from_lin
|
||||
|
@ -63,7 +64,8 @@ class TestLinearizerOverflow(unittest.TestCase):
|
|||
opts = [Opt(op=OptOps.UPCAST, axis=3, amt=4), Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=8), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=2, amt=4)]
|
||||
_test_overflow(ast, opts)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals")
|
||||
#@unittest.skipIf(Device.DEFAULT not in {"GPU", "HIP", "HSA", "CUDA", "METAL"}, "only backends with locals")
|
||||
@unittest.skipIf(CI, "slow")
|
||||
class TestLinearizerOverflowAlt(unittest.TestCase):
|
||||
def test_overflow_1(self):
|
||||
BS = 2
|
||||
|
|
|
@ -2,12 +2,10 @@ import unittest
|
|||
|
||||
from test.helpers import assert_jit_cache_len
|
||||
from tinygrad.features.jit import TinyJit
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
from tinygrad.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
|
||||
class TestSymbolicJit(unittest.TestCase):
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
|
|
|
@ -5,7 +5,6 @@ from tinygrad.tensor import Tensor
|
|||
from examples.gpt2 import Attention
|
||||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
|
||||
class TestSymbolicOps(unittest.TestCase):
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
|
|
|
@ -2,12 +2,11 @@ from typing import Optional, Tuple, Any, List
|
|||
import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType
|
||||
from tinygrad.helpers import getenv
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.device import CompiledASTRunner, Compiled
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.codegen.uops import exec_alu
|
||||
from tinygrad.codegen.uops import exec_alu, UOpGraph
|
||||
from test.test_dtype import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops):
|
||||
|
@ -29,7 +28,7 @@ def _test_single_value(vals, op, dts):
|
|||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype)
|
||||
buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
|
||||
prg = _uops_to_prg(uops)
|
||||
prg = _uops_to_prg(UOpGraph(uops))
|
||||
prg.exec([buf]+buf2)
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
|
@ -43,7 +42,7 @@ def _test_single_value_const(vals, op, dts):
|
|||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype)
|
||||
prg = _uops_to_prg(uops)
|
||||
prg = _uops_to_prg(UOpGraph(uops))
|
||||
prg.exec([buf])
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
|
@ -88,26 +87,51 @@ class TestFloatUOps(TestUOps):
|
|||
# MOD isn't tested on floats
|
||||
|
||||
def test_where(self):
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float), PtrDType(dtypes.float)))
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float, dtypes.float))
|
||||
|
||||
# TODO: fix this on all the backends
|
||||
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
|
||||
class TestNonFloatUOps(TestUOps):
|
||||
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (PtrDType(dtypes.int32), ))
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
|
||||
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (dtypes.int32, ))
|
||||
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_sub_int32(self): self._test_bop_fxn(BinaryOps.SUB, lambda a,b: int(a)-int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_mul_int32(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: int(a)*int(b), (dtypes.int32, dtypes.int32))
|
||||
def test_div_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True)
|
||||
self._test_bop_fxn(BinaryOps.DIV, lambda a,b: int(a/b), (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_mod_int32(self):
|
||||
self._test_bop_fxn(BinaryOps.MOD,
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (PtrDType(dtypes.int32), PtrDType(dtypes.int32)), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
|
||||
lambda a,b: abs(int(a))%abs(int(b))*(1,-1)[a<0], (dtypes.int32, dtypes.int32), no_b_zero=True)
|
||||
def test_cmplt_int32(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: float(a<b), (dtypes.int32, dtypes.int32))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.bool), "dtype not supported")
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (PtrDType(dtypes.bool), PtrDType(dtypes.bool)))
|
||||
def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) and bool(b), (dtypes.bool, dtypes.bool))
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.float16), "dtype not supported")
|
||||
def test_where_float16(self):
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float16), PtrDType(dtypes.float16)))
|
||||
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (dtypes.bool, dtypes.float16, dtypes.float16))
|
||||
|
||||
class TestBoolUOps(TestUOps):
|
||||
def _test_uop_bool_fxn(self, op, fxn):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [False, True]:
|
||||
self._equal(f([a], op, (dtypes.bool, )*1), fxn(a))
|
||||
|
||||
def _test_bop_bool_fxn(self, op, fxn):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [False, True]:
|
||||
for b in [False, True]:
|
||||
self._equal(f([a,b], op, (dtypes.bool, )*2), fxn(a,b))
|
||||
|
||||
def _test_top_bool_fxn(self, op, fxn):
|
||||
for f in [_test_single_value, _test_single_value_const]:
|
||||
for a in [False, True]:
|
||||
for b in [False, True]:
|
||||
for c in [False, True]:
|
||||
self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))
|
||||
|
||||
def test_not_bool(self): self._test_uop_bool_fxn(UnaryOps.NEG, lambda a: not a)
|
||||
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
|
||||
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
|
||||
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
|
||||
def test_cmpeq_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPEQ, lambda a,b: a == b)
|
||||
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
|
||||
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)
|
||||
|
||||
class TestExecALU(TestUOps):
|
||||
def test_sqrt(self):
|
||||
|
|
|
@ -56,9 +56,9 @@ def uop_alu_resolve(u:UOp) -> sint:
|
|||
def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0])
|
||||
|
||||
class UOpGraph:
|
||||
def __init__(self):
|
||||
def __init__(self, start_uops:Optional[List[UOp]]=None):
|
||||
# list of uops
|
||||
self.uops: List[UOp] = []
|
||||
self.uops: List[UOp] = [] if start_uops is None else start_uops
|
||||
|
||||
# global uop cache
|
||||
self.saved_exprs: Dict[Tuple, UOp] = dict()
|
||||
|
@ -88,7 +88,8 @@ class UOpGraph:
|
|||
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
|
||||
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
|
||||
# constant folding
|
||||
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
|
||||
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST:
|
||||
return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before)
|
||||
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
|
||||
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
|
||||
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
|
||||
|
|
|
@ -238,7 +238,7 @@ class Compiled:
|
|||
ops, mem = k.uops.flops_mem()
|
||||
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
|
||||
# NOTE: we use min here to ignore the indexing FLOPS
|
||||
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size,
|
||||
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
|
||||
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
|
||||
return ret
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from collections import defaultdict
|
|||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
|
||||
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
|
||||
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
|
||||
|
@ -35,11 +36,38 @@ class AssemblyLanguage(NamedTuple):
|
|||
|
||||
def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()
|
||||
|
||||
def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
local_size: List[int] = []
|
||||
kernel:List[str] = []
|
||||
bufs = []
|
||||
|
||||
# here we do a pretransform on UOps to fix some shortcomings of PTX
|
||||
# all uops must be a register
|
||||
# TODO: uops class should make these rewrites easier
|
||||
replace: Dict[UOp, UOp] = {}
|
||||
for u in uops:
|
||||
for o,n in replace.items():
|
||||
if o in u.vin and u is not n:
|
||||
u.vin = tuple(n if x == o else x for x in u.vin)
|
||||
if u.uop is UOps.LOAD and u.dtype is dtypes.bool:
|
||||
# rewrite load bool
|
||||
if len(u.vin) == 4:
|
||||
new = uops.add(UOps.CAST, dtypes.uint8, (u.vin[3],), insert_before=uops.uops.index(u))
|
||||
u.vin = u.vin[0:3] + (new,)
|
||||
u.dtype = dtypes.uint8
|
||||
new = uops.add(UOps.CAST, dtypes.bool, (u,), insert_before=uops.uops.index(u)+1)
|
||||
replace[u] = new
|
||||
if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool:
|
||||
if u.arg == BinaryOps.CMPEQ:
|
||||
u.arg = BinaryOps.XOR
|
||||
new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1)
|
||||
replace[u] = new
|
||||
if u.arg == BinaryOps.CMPLT:
|
||||
new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
|
||||
u.vin = (new, u.vin[1])
|
||||
u.arg = BinaryOps.MUL
|
||||
#uops.print()
|
||||
|
||||
def kk(*s: str): kernel.append("\n".join(s))
|
||||
|
||||
c: DefaultDict[str, int] = defaultdict(int)
|
||||
|
@ -78,12 +106,12 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str
|
|||
assert vin[0].dtype is not None
|
||||
kk(*lang.render_bra(lb:=ssa_label(u, 'if'), cast(r[vin[0]], dtypes.bool, vin[0].dtype, u=u, pred=True), f"{lb}_true"), f"{lb}_true:")
|
||||
elif uop == UOps.BARRIER and lang.barrier: kk(lang.barrier)
|
||||
elif uop == UOps.END:
|
||||
if vin[0].uop == UOps.LOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
else: kk(f"{r_label[vin[0]]}:")
|
||||
elif uop == UOps.ENDLOOP:
|
||||
kk(lang.asm_for_op[BinaryOps.ADD](r[vin[0]], r[vin[0]], "1", dtypes.int, lang.types[dtypes.int]),
|
||||
lang.asm_for_op[BinaryOps.CMPLT](pred:=ssa(None, "pred", "pred"), r[vin[0]], r[vin[0].vin[1]], dtypes.int, lang.types[dtypes.int]))
|
||||
kk(*lang.render_bra(r_label[vin[0]], pred, f"{r_label[vin[0]]}_exit"), f"{r_label[vin[0]]}_exit:")
|
||||
elif uop == UOps.ENDIF:
|
||||
kk(f"{r_label[vin[0]]}:")
|
||||
elif uop == UOps.STORE:
|
||||
assert vin[0].dtype is not None and vin[1].dtype is not None and vin[2].dtype is not None
|
||||
kk(*lang.render_gep(loc:=ssa(None,'loc','u64'), r[vin[0]], r[vin[1]], vin[0].dtype))
|
||||
|
@ -97,13 +125,10 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str
|
|||
elif uop == UOps.ALU:
|
||||
assert vin[0].dtype is not None
|
||||
if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
|
||||
regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
|
||||
dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
|
||||
kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
|
||||
elif args == TernaryOps.MULACC:
|
||||
assert vin[1].dtype is not None
|
||||
kk(lang.asm_for_op[args](ssa(u, 'alu'), *[r[x] for x in vin], dtype, lang.types[vin[1].dtype]))
|
||||
else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
# pass in the other dtype here
|
||||
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
|
||||
else:
|
||||
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
|
||||
elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
|
||||
elif uop == UOps.SPECIAL:
|
||||
if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
|
||||
|
@ -133,12 +158,17 @@ def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str
|
|||
assert vin[0].dtype is not None
|
||||
cast(r[vin[0]], dtype, vin[0].dtype, bitcast=isinstance(args, tuple) and args[1], u=u)
|
||||
elif uop == UOps.DEFINE_LOCAL: kk(*lang.render_local(ssa(u, 'local', lang.types[dtypes.ulong]), args[0], args[1], dtype))
|
||||
elif uop == UOps.DEFINE_GLOBAL:
|
||||
bufs.append((args, dtype))
|
||||
r[u] = f"%{args}"
|
||||
elif uop is UOps.DEFINE_VAR:
|
||||
bufs.append((args.expr, dtype))
|
||||
r[u] = f"%{args.expr}"
|
||||
if lang.load_global:
|
||||
kk(*lang.render_load(args.expr, ssa(u, 'dat', dtype=lang.types[dtype]), dtype, ss=".param"))
|
||||
elif uop is UOps.DEFINE_GLOBAL:
|
||||
bufs.append((args[1], dtype))
|
||||
r[u] = f"%{args[1]}"
|
||||
if lang.load_global:
|
||||
dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype
|
||||
kk(*lang.render_load(args, ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
|
||||
kk(*lang.render_load(args[1], ssa(u, 'dat', dtype=lang.types[dt]), dt, ss=".param"))
|
||||
else: raise NotImplementedError(f"no code for {uop}")
|
||||
|
||||
return lang.render_kernel(kernel, function_name, bufs, c.items())
|
||||
|
@ -156,24 +186,22 @@ class PTXLanguage(AssemblyLanguage):
|
|||
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
|
||||
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
|
||||
asm_for_op = {
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};",
|
||||
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
|
||||
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
|
||||
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
|
||||
UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
|
||||
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
|
||||
BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
|
||||
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
|
||||
BinaryOps.CMPEQ: lambda d,a,b,dt,name: f"setp.eq.{name} {d}, {a}, {b};",
|
||||
TernaryOps.MULACC: lambda d,a,b,c,dt,name: (('fma.rn' if dtypes.is_float(dt) else 'mad.lo' if a.split('_')[1]==c.split('_')[1] else 'mad.wide') +
|
||||
f".{name} {d}, {a}, {b}, {c};"),
|
||||
TernaryOps.WHERE: lambda d,a,b,c,dt,name: f"selp.{name} {d}, {b}, {c}, {a};"
|
||||
TernaryOps.WHERE: lambda d,a,b,c,dt,name:
|
||||
f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};"
|
||||
}
|
||||
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT,
|
||||
TernaryOps.MULACC, TernaryOps.WHERE]
|
||||
supports_half = [UnaryOps.NEG, UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE]
|
||||
types = {
|
||||
dtypes.int8: "s16", dtypes.int16: "s16", dtypes.int32: "s32", dtypes.int64: "s64",
|
||||
dtypes.uint8: "u16", dtypes.uint16: "u16", dtypes.uint32: "u32", dtypes.uint64: "u64",
|
||||
|
@ -206,14 +234,11 @@ class PTXLanguage(AssemblyLanguage):
|
|||
def mem_type(self, dtype): return 's8' if dtype.itemsize == 1 else 'b16' if dtype == dtypes.float16 else self.types[dtype]
|
||||
|
||||
def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="") -> List[str]:
|
||||
assert dtype is not dtypes.bool
|
||||
ret = []
|
||||
if (byte:=dtype.itemsize == 1): ret.append(f".reg .s8 {dest}_tmp;")
|
||||
if (isbool:= dtype == dtypes.bool): ret.append(f".reg .s16 {dest}_bool;")
|
||||
if gate: ret.extend([f"@{gate} ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];",
|
||||
f"@!{gate} mov.b{'8' if byte else self.types[dtype][1:]} {dest + ('_tmp' if byte else '')}, {alt};"])
|
||||
else: ret.append(f"ld{ss}.{'s8' if byte else 'b16' if dtype==dtypes.float16 else self.types[dtype]} {dest + ('_tmp' if byte else '')}, [{loc}];")
|
||||
if byte: ret.append(f"cvt.{'s16' if isbool else self.types[dtype]}.s8 {dest + ('_bool' if isbool else '')}, {dest}_tmp;")
|
||||
if isbool: ret.append(f"setp.ne.s16 {dest}, {dest}_bool, {self.render_const(0, dtypes.int16)};")
|
||||
f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"])
|
||||
else: ret.append(f"ld{ss}.{self.mem_type(dtype)} {dest}, [{loc}];")
|
||||
return ret
|
||||
|
||||
def render_store(self, loc, val, dtype, gate=None, ss="") -> List[str]:
|
|
@ -5,6 +5,7 @@ from tinygrad.codegen.linearizer import UOps, UOp
|
|||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import strip_parens, getenv
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
class CStyleLanguage(NamedTuple):
|
||||
kernel_prefix: str = ""
|
||||
|
@ -24,7 +25,7 @@ class CStyleLanguage(NamedTuple):
|
|||
uses_ptr_arithmetic: bool = False
|
||||
type_map: Dict[DType, str] = {}
|
||||
code_for_op: Dict = {
|
||||
UnaryOps.NEG: lambda x,dtype: f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
UnaryOps.NEG: lambda x,dtype: f"(!{x})" if dtype is dtypes.bool else f"(-{x})", UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})",
|
||||
BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})",
|
||||
BinaryOps.DIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.MAX: lambda a,b,dtype: f"max({a},{b})", BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})",
|
||||
|
@ -61,7 +62,7 @@ class CStyleLanguage(NamedTuple):
|
|||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str:
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
|
@ -86,7 +87,7 @@ class CStyleLanguage(NamedTuple):
|
|||
def render_local(self, name:str, dtype:DType, size:int): return self.smem_align + self.smem_prefix + f"{dtype.name} {name}[{size}];"
|
||||
def render_dtype(self, var_dtype:DType) -> str: return self.type_map[var_dtype] if var_dtype in self.type_map else var_dtype.name
|
||||
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> str:
|
||||
def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:UOpGraph) -> str:
|
||||
kernel = []
|
||||
bufs: List[Tuple[str, Tuple[DType, bool]]] = []
|
||||
#pend_close = None
|
||||
|
|
|
@ -3,13 +3,15 @@ from llvmlite import ir
|
|||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.dtype import DType, PtrDType, dtypes
|
||||
from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
MFLAGS = ('nsz', 'arcp', 'contract', 'afn', 'reassoc') # All from fast math, but nnan and ninf
|
||||
|
||||
def is_bool_or_unsigned(dtype: DType): return dtype == dtypes.bool or dtypes.is_unsigned(dtype)
|
||||
|
||||
code_for_op: Final[Dict[Op, Callable]] = {
|
||||
UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else builder.fneg(x, flags=MFLAGS),
|
||||
UnaryOps.NEG: lambda builder, x, var_dtype: builder.neg(x) if dtypes.is_int(var_dtype) else \
|
||||
(builder.not_(x) if var_dtype is dtypes.bool else builder.fneg(x, flags=MFLAGS)),
|
||||
UnaryOps.EXP2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.exp2', [x.type]), [x], fastmath=MFLAGS),
|
||||
UnaryOps.LOG2: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.log2', [x.type]), [x], fastmath=MFLAGS),
|
||||
UnaryOps.SIN: lambda builder, x, var_dtype: builder.call(builder._block.module.declare_intrinsic('llvm.sin', [x.type]), [x], fastmath=MFLAGS),
|
||||
|
@ -65,7 +67,7 @@ def const(args, dtype):
|
|||
# TODO: remove int from int(args) once const args conform with dtype
|
||||
return ir.Constant(dtype_to_llvm_dtype[dtype], int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args)
|
||||
|
||||
def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> str:
|
||||
def uops_to_llvm_ir(function_name:str, uops:UOpGraph) -> str:
|
||||
# all llvm stuff goes into a module
|
||||
module = ir.Module(name=__file__)
|
||||
|
||||
|
|
|
@ -7,6 +7,7 @@ from tinygrad.helpers import DEBUG, getenv, from_mv, to_char_p_p, init_c_var, co
|
|||
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, Compiler
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.renderer.cstyle import CUDARenderer
|
||||
from tinygrad.renderer.assembly import PTXRenderer
|
||||
|
||||
def pretty_ptx(s):
|
||||
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
|
||||
|
@ -33,6 +34,15 @@ def _get_bytes(arg, get_str, get_sz, check) -> bytes:
|
|||
sz = init_c_var(ctypes.c_size_t(), lambda x: check(get_sz(arg, ctypes.byref(x))))
|
||||
return ctypes.string_at(init_c_var(ctypes.create_string_buffer(sz.value), lambda x: check(get_str(arg, x))), size=sz.value)
|
||||
|
||||
class PTXCompiler(Compiler):
|
||||
linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024], supports_float4=False)
|
||||
def __init__(self, arch:str):
|
||||
self.arch = arch
|
||||
PTXCompiler.linearizer_opts = PTXCompiler.linearizer_opts._replace(has_tensor_cores=int(arch[3:]) >= 80)
|
||||
super().__init__(f"compile_ptx_{self.arch}")
|
||||
def render(self, name:str, uops) -> str: return PTXRenderer(name, uops).replace("TARGET", self.arch)
|
||||
def compile(self, src:str) -> bytes: return src.encode()
|
||||
|
||||
class CUDACompiler(Compiler):
|
||||
linearizer_opts = LinearizerOptions("CUDA", global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024])
|
||||
def __init__(self, arch:str):
|
||||
|
@ -100,7 +110,8 @@ class CUDADevice(Compiled):
|
|||
self.arch = f"sm_{major.value}{minor.value}" if not CUDACPU else "sm_35"
|
||||
|
||||
from tinygrad.runtime.graph.cuda import CUDAGraph
|
||||
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator, CUDACompiler(self.arch),
|
||||
super().__init__(device, CUDAAllocator(self) if not CUDACPU else MallocAllocator,
|
||||
PTXCompiler(self.arch) if getenv("PTX") else CUDACompiler(self.arch),
|
||||
functools.partial(CUDAProgram, self), graph=CUDAGraph if not CUDACPU else None)
|
||||
def synchronize(self):
|
||||
if not CUDACPU:
|
||||
|
|
|
@ -6,7 +6,7 @@ import pickle, base64, itertools, time, struct
|
|||
from tinygrad.dtype import DType, dtypes, ImageDType
|
||||
from tinygrad.helpers import all_same, getenv, flatten
|
||||
from tinygrad.device import Compiled, Allocator, Compiler
|
||||
from tinygrad.codegen.uops import UOp, UOps, exec_alu
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, exec_alu
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
|
||||
|
@ -188,8 +188,8 @@ class PythonCompiler(Compiler):
|
|||
linearizer_opts = LinearizerOptions("METAL", has_tensor_cores=True) if getenv("EMULATE_METAL") else \
|
||||
(LinearizerOptions("HIP", has_tensor_cores=True) if getenv("EMULATE_HIP") else \
|
||||
(LinearizerOptions("CUDA", has_tensor_cores=True) if getenv("EMULATE_CUDA") else LinearizerOptions("PYTHON")))
|
||||
def render(self, name:str, uops:List[UOp]) -> str:
|
||||
lops = [(u.uop, u.dtype, [uops.index(v) for v in u.vin], u.arg) for u in uops]
|
||||
def render(self, name:str, uops:UOpGraph) -> str:
|
||||
lops = [(u.uop, u.dtype, [uops.uops.index(v) for v in u.vin], u.arg) for u in uops]
|
||||
return base64.b64encode(pickle.dumps(lops)).decode()
|
||||
def compile(self, src:str) -> bytes: return base64.b64decode(src)
|
||||
|
||||
|
|
Loading…
Reference in New Issue