mirror of https://github.com/commaai/tinygrad.git
use comgr to compile (#3248)
* use comgr to compile * fast * bfloat16 * move comgr to its own file * cleaner style * comgr in new place * comgr free + dtype cleanup
This commit is contained in:
parent
c4d870db0d
commit
473935125a
File diff suppressed because it is too large
Load Diff
2
setup.py
2
setup.py
|
@ -15,7 +15,7 @@ setup(name='tinygrad',
|
|||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
packages = ['tinygrad', 'tinygrad.runtime.autogen', 'tinygrad.codegen', 'tinygrad.nn', 'tinygrad.renderer',
|
||||
'tinygrad.runtime', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.features'],
|
||||
'tinygrad.runtime', 'tinygrad.runtime.compiler', 'tinygrad.runtime.graph', 'tinygrad.shape', 'tinygrad.features'],
|
||||
classifiers=[
|
||||
"Programming Language :: Python :: 3",
|
||||
"License :: OSI Approved :: MIT License"
|
||||
|
|
|
@ -14,6 +14,11 @@ class TestDeviceSpeed(unittest.TestCase):
|
|||
with Timing("compiler "):
|
||||
self.dev.compiler(self.empty)
|
||||
|
||||
def test_empty_compile_twice(self):
|
||||
self.dev.compiler(self.empty)
|
||||
with Timing("compiler "):
|
||||
self.dev.compiler(self.empty)
|
||||
|
||||
def test_launch_speed(self):
|
||||
prg_bin = self.dev.compiler(self.empty)
|
||||
prg = self.dev.runtime("test", prg_bin)
|
||||
|
|
|
@ -223,30 +223,66 @@ class CUDALanguage(CStyleLanguage):
|
|||
code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+x)}", "l": lambda x: f"threadIdx.{chr(120+x)}",
|
||||
"i": lambda x: f"(blockIdx.{chr(120+x)}*blockDim.{chr(120+x)}+threadIdx.{chr(120+x)})"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_half}
|
||||
half_prekernel ="#include <cuda_fp16.h>\n"+"#include <cuda_bf16.h>\n"+"""
|
||||
half_prekernel = "#include <cuda_fp16.h>\n"+"#include <cuda_bf16.h>\n"+"""
|
||||
struct half4 { half x, y, z, w; };
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }
|
||||
"""
|
||||
type_map = {dtypes.bfloat16: "nv_bfloat16"}
|
||||
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
|
||||
|
||||
class HIPLanguage(CUDALanguage):
|
||||
code_for_op_hip = {
|
||||
BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f32({a},{b})" if dtype != dtypes.half else f"__ocml_fmax_f16({a},{b})",
|
||||
UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f32({x})" if dtype != dtypes.half else f"__ocml_sqrt_f16({x})",
|
||||
UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f32({x})" if dtype != dtypes.half else f"__ocml_sin_f16({x})",
|
||||
UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f32({x})" if dtype != dtypes.half else f"__ocml_log2_f16({x})",
|
||||
UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f32({x})" if dtype != dtypes.half else f"__ocml_exp2_f16({x})",
|
||||
}
|
||||
|
||||
def _make_hip_dtype(base_type, name, cnt):
|
||||
nms = "xyzwabcdefghijkl"[:cnt]
|
||||
return f"typedef {base_type} {name}{cnt} __attribute__((ext_vector_type({cnt})));\n" + \
|
||||
f"static inline __attribute__((device)) {name}{cnt} make_{name}{cnt}(" + ', '.join([f"{base_type} {x}" for x in nms]) + \
|
||||
") { return {" + ', '.join(nms) + "}; }"
|
||||
|
||||
class HIPLanguage(CStyleLanguage):
|
||||
kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
|
||||
typedef float float8 __attribute__((ext_vector_type(8)));
|
||||
__device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
|
||||
extern "C" __global__
|
||||
"""
|
||||
#define launch_bounds_impl0(requiredMaxThreadsPerBlock) \
|
||||
__attribute__((amdgpu_flat_work_group_size(1, requiredMaxThreadsPerBlock)))
|
||||
#define launch_bounds_impl1(requiredMaxThreadsPerBlock, minBlocksPerMultiprocessor) \
|
||||
__attribute__((amdgpu_flat_work_group_size(1, requiredMaxThreadsPerBlock), amdgpu_waves_per_eu(minBlocksPerMultiprocessor)))
|
||||
#define select_impl_(_1, _2, impl_, ...) impl_
|
||||
#define __launch_bounds__(...) select_impl_(__VA_ARGS__, launch_bounds_impl1, launch_bounds_impl0)(__VA_ARGS__)
|
||||
typedef long unsigned int size_t;
|
||||
#define half _Float16
|
||||
struct hip_bfloat16 { unsigned short data; };
|
||||
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_group_id(unsigned int);
|
||||
extern "C" __attribute__((device)) __attribute__((const)) size_t __ockl_get_local_size(unsigned int);
|
||||
|
||||
extern "C" {
|
||||
__attribute__((device)) __attribute__((const)) float __ocml_fmax_f32(float, float);
|
||||
__attribute__((device)) __attribute__((pure)) float __ocml_exp2_f32(float);
|
||||
__attribute__((device)) __attribute__((pure)) float __ocml_log2_f32(float);
|
||||
__attribute__((device)) float __ocml_sin_f32(float);
|
||||
__attribute__((device)) __attribute__((const)) float __ocml_sqrt_f32(float);
|
||||
__attribute__((device)) __attribute__((const)) _Float16 __ocml_fmax_f16(_Float16, _Float16);
|
||||
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_exp2_f16(_Float16);
|
||||
__attribute__((device)) __attribute__((pure)) _Float16 __ocml_log2_f16(_Float16);
|
||||
__attribute__((device)) _Float16 __ocml_sin_f16(_Float16);
|
||||
__attribute__((device)) __attribute__((const)) _Float16 __ocml_sqrt_f16(_Float16);
|
||||
}\n""" + '\n'.join([_make_hip_dtype(*x) for x in [("signed int", "int", 2),
|
||||
("_Float16", "half", 2), ("_Float16", "half", 4), ("_Float16", "half", 8), ("_Float16", "half", 16),
|
||||
("float", "float", 2), ("float", "float", 4), ("float", "float", 8)]]) + \
|
||||
'extern "C" __attribute__((global))'
|
||||
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
|
||||
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
|
||||
code_for_op = {**CStyleLanguage().code_for_op, **code_for_op_hip}
|
||||
smem_prefix = "__attribute__((shared))"
|
||||
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
|
||||
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
|
||||
float4 = "make_float4"
|
||||
launch_bounds = True
|
||||
uses_ptr_arithmetic = True
|
||||
half_prekernel = "#include <hip/hip_fp16.h>\n" + """
|
||||
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
|
||||
__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
|
||||
typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8;
|
||||
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
|
||||
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
|
||||
__device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d,
|
||||
half e, half f, half g, half h, half i, half j, half k, half l) {
|
||||
return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
|
||||
"""
|
||||
type_map = {dtypes.bfloat16: "hip_bfloat16"}
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
|
@ -0,0 +1,47 @@
|
|||
import ctypes
|
||||
import tinygrad.runtime.autogen.comgr as comgr
|
||||
|
||||
def check(status):
  """Raise RuntimeError when `status` is non-zero; return None otherwise.

  The error message carries the numeric status and the text that
  `amd_comgr_status_string` writes for it.
  """
  if status == 0: return
  status_str = ctypes.POINTER(ctypes.c_char)()
  comgr.amd_comgr_status_string(status, ctypes.byref(status_str))
  raise RuntimeError(f"comgr fail {status}, {ctypes.string_at(status_str).decode()}")
|
||||
|
||||
def _get_comgr_data(data_set, data_type):
  """Return, as bytes, the contents of entry 0 of kind `data_type` in `data_set`.

  Every comgr call goes through `check`, so a non-zero status raises RuntimeError.
  """
  data_exec = comgr.amd_comgr_data_t()
  check(comgr.amd_comgr_action_data_get_data(data_set, data_type, 0, ctypes.byref(data_exec)))
  # first call passes no buffer; the value left in `size` is used to allocate one
  size = ctypes.c_uint64()
  check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(size), None))
  buf = ctypes.create_string_buffer(size.value)
  check(comgr.amd_comgr_get_data(data_exec, ctypes.byref(size), buf))
  check(comgr.amd_comgr_release_data(data_exec))
  return bytes(buf)
|
||||
|
||||
# AMD_COMGR_SAVE_TEMPS=1 AMD_COMGR_REDIRECT_LOGS=stdout AMD_COMGR_EMIT_VERBOSE_LOGS=1
|
||||
def compile_hip(prg:str, arch="gfx1100") -> bytes:
  """Compile HIP source text with comgr and return the resulting executable as bytes.

  Three comgr actions run in sequence: source -> bitcode (with device libs),
  bitcode -> relocatable, relocatable -> executable.

  Args:
    prg: HIP source code.
    arch: GPU architecture name; used for both the comgr ISA name and `--offload-arch`.

  Raises:
    RuntimeError: if the source compile step fails (its log is printed first),
      or if any other comgr call returns a non-zero status (via `check`).
  """
  check(comgr.amd_comgr_create_action_info(ctypes.byref(action_info := comgr.amd_comgr_action_info_t())))
  check(comgr.amd_comgr_action_info_set_language(action_info, comgr.AMD_COMGR_LANGUAGE_HIP))
  check(comgr.amd_comgr_action_info_set_isa_name(action_info, b"amdgcn-amd-amdhsa--" + arch.encode()))
  check(comgr.amd_comgr_action_info_set_logging(action_info, True))

  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_src := comgr.amd_comgr_data_set_t())))
  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_bc := comgr.amd_comgr_data_set_t())))
  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_reloc := comgr.amd_comgr_data_set_t())))
  check(comgr.amd_comgr_create_data_set(ctypes.byref(data_set_exec := comgr.amd_comgr_data_set_t())))

  check(comgr.amd_comgr_create_data(comgr.AMD_COMGR_DATA_KIND_SOURCE, ctypes.byref(data_src := comgr.amd_comgr_data_t())))
  check(comgr.amd_comgr_set_data(data_src, len(rprg := prg.encode()), rprg))
  check(comgr.amd_comgr_set_data_name(data_src, b"<null>"))

  check(comgr.amd_comgr_data_set_add(data_set_src, data_src))
  # -include hiprtc_runtime.h was removed
  # --offload-arch follows `arch` so it always agrees with the ISA name set above
  # (it was previously fixed to gfx1100 regardless of the argument)
  compile_opts = ["-O3", "-mcumode", "--hip-version=6.0.32830", "-DHIP_VERSION_MAJOR=6", "-DHIP_VERSION_MINOR=0",
                  "-DHIP_VERSION_PATCH=32830", "-D__HIPCC_RTC__", "-std=c++14", "-nogpuinc", "-Wno-gnu-line-marker",
                  "-Wno-missing-prototypes", f"--offload-arch={arch}", "-I/opt/rocm/include", "-Xclang", "-disable-llvm-passes"]
  check(comgr.amd_comgr_action_info_set_options(action_info, " ".join(compile_opts).encode()))
  # status is inspected by hand (not via check) so the compile log can be printed before raising
  status = comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_COMPILE_SOURCE_WITH_DEVICE_LIBS_TO_BC, action_info, data_set_src, data_set_bc)
  if status != 0:
    # NOTE: the release/destroy calls at the end of this function are not reached on this path
    print(_get_comgr_data(data_set_bc, comgr.AMD_COMGR_DATA_KIND_LOG).decode())
    raise RuntimeError("compile failed")
  check(comgr.amd_comgr_action_info_set_options(action_info, b"-O3 -mllvm -amdgpu-internalize-symbols"))
  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_CODEGEN_BC_TO_RELOCATABLE, action_info, data_set_bc, data_set_reloc))
  check(comgr.amd_comgr_action_info_set_options(action_info, b""))
  check(comgr.amd_comgr_do_action(comgr.AMD_COMGR_ACTION_LINK_RELOCATABLE_TO_EXECUTABLE, action_info, data_set_reloc, data_set_exec))
  ret = _get_comgr_data(data_set_exec, comgr.AMD_COMGR_DATA_KIND_EXECUTABLE)
  check(comgr.amd_comgr_release_data(data_src))
  for x in [data_set_src, data_set_bc, data_set_reloc, data_set_exec]: check(comgr.amd_comgr_destroy_data_set(x))
  check(comgr.amd_comgr_destroy_action_info(action_info))
  return ret
|
|
@ -3,10 +3,11 @@ import ctypes, functools, subprocess, io
|
|||
from typing import Tuple, TypeVar, List, Any, cast, Set
|
||||
import tinygrad.runtime.autogen.hip as hip
|
||||
from tinygrad.helpers import DEBUG, getenv, init_c_var
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t, to_char_p_p, get_bytes
|
||||
from tinygrad.helpers import from_mv, round_up, to_mv, colored, init_c_struct_t
|
||||
from tinygrad.device import Compiled, LRUAllocator, MallocAllocator, BufferOptions, JITRunner, Device, Buffer, update_stats
|
||||
from tinygrad.renderer.cstyle import HIPRenderer
|
||||
from tinygrad.codegen.kernel import LinearizerOptions
|
||||
from tinygrad.runtime.compiler.hip_comgr import compile_hip
|
||||
|
||||
# The default HIP stream is used for everything.
|
||||
MOCKHIP = getenv("MOCKHIP") # for CI. don't run kernels, only check if they compile
|
||||
|
@ -21,13 +22,6 @@ def hip_set_device(d:int):
|
|||
def check(status):
|
||||
if status != 0: raise RuntimeError(f"HIP Error {status}, {ctypes.string_at(hip.hipGetErrorString(status)).decode()}")
|
||||
|
||||
def compile_hip(prg:str, arch="gfx1100") -> bytes:
|
||||
check(hip.hiprtcCreateProgram(ctypes.byref(prog := hip.hiprtcProgram()), prg.encode(), "<null>".encode(), 0, None, None))
|
||||
compile_options = [f'--offload-arch={arch}', '-I/opt/rocm/include']
|
||||
status = hip.hiprtcCompileProgram(prog, len(compile_options), to_char_p_p([o.encode() for o in compile_options]))
|
||||
if status != 0: raise RuntimeError(f"compile failed: {get_bytes(prog, hip.hiprtcGetProgramLogSize, hip.hiprtcGetProgramLog, check).decode()}")
|
||||
return get_bytes(prog, hip.hiprtcGetCodeSize, hip.hiprtcGetCode, check)
|
||||
|
||||
class HIPProgram:
|
||||
def __init__(self, device:int, name:str, lib:bytes):
|
||||
self.device, self.name, self.lib = device, name, lib
|
||||
|
|
Loading…
Reference in New Issue