Move HIP render logic to its dedicated place (#2394)

* update HIP language

* vectorized render_cast with special treatment for hip only

* test coverage for all cases

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
qazal 2023-11-23 16:03:29 -05:00 committed by GitHub
parent 6d672785db
commit b927942d58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 58 additions and 31 deletions

11
test/test_renderer.py Normal file
View File

@ -0,0 +1,11 @@
import unittest
from tinygrad.helpers import dtypes
from tinygrad.renderer.hip import HIPLanguage
class TestRenderer(unittest.TestCase):
def test_render_cast(self):
self.assertEqual(HIPLanguage().render_cast(["data0"], dtypes.half), "(half)(data0)")
self.assertEqual(HIPLanguage().render_cast(["data0", "data1", "data2", "data3"], dtypes.float.vec(4)), "make_float4(data0,data1,data2,data3)")
self.assertEqual(HIPLanguage().render_cast(["data0", "data1", "data2", "data3", "data4", "data5", "data6", "data7"], dtypes.float.vec(8)), "{data0,data1,data2,data3,data4,data5,data6,data7}")
self.assertEqual(HIPLanguage().render_cast(["data0", "data1", "data2", "data3"], dtypes.half.vec(4)), "{(half)data0,(half)data1,(half)data2,(half)data3}")

View File

@ -45,10 +45,8 @@ class CStyleLanguage(NamedTuple):
def render_cast(self, x:List[str], var_dtype:DType) -> str:
if len(x) == 1: return f"({var_dtype.name})({x[0]})"
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes.half.vec(16): return f"{{{','.join(f'(half){x}' for x in x)}}}"
if var_dtype == dtypes.float.vec(8): return f"{{{','.join(x)}}}"
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(f'(half){x}' if var_dtype.scalar() == dtypes.half else x for x in x)})"
assert self.float4 is not None, "vectorized cast is not supported on this platform"
return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})"
# returns a str expression of the const with the given type
def render_const(self, x:Union[float,int], var_dtype) -> str:

42
tinygrad/renderer/hip.py Normal file
View File

@ -0,0 +1,42 @@
import functools
from tinygrad.helpers import dtypes
from tinygrad.renderer.cstyle import CStyleLanguage, uops_to_cstyle
class HIPLanguage(CStyleLanguage):
kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
extern "C" __global__
"""
launch_bounds = True
smem_prefix = "__shared__ "
smem_prefix_for_cast=False
barrier = "__syncthreads();"
float4 = "make_float4"
uses_vload=True
uses_ptr_arithmetic=True
arg_int_prefix = "const int"
half_prekernel = "#include <hip/hip_fp16.h>\nusing half4 = HIP_vector_type<half, 4>;" + """
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
"""
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)]
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)]
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]
def render_cast(self, x, var_dtype):
if var_dtype.sz > 1 and var_dtype.scalar() == dtypes.half: return f"{{{','.join(f'(half){x}' for x in x)}}}"
if var_dtype.sz == 8: return f"{{{','.join(x)}}}"
return super().render_cast(x, var_dtype)
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())

View File

@ -1,12 +1,12 @@
import numpy as np
import ctypes, functools
import ctypes
import extra.hip_wrapper as hip
from typing import Tuple
from tinygrad.helpers import DEBUG, getenv, diskcache
from tinygrad.ops import Compiled
from tinygrad.renderer.hip import HIPRenderer
from tinygrad.runtime.lib import RawBuffer, RawBufferCopyInOut, LRUAllocator, RawBufferTransfer
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
# TODO: if you fork and exit the child process after creating anything with cl on AMD, it hangs on e.wait()
if DEBUG >= 6:
@ -79,28 +79,4 @@ class HIPProgram:
def __del__(self):
for module in self.modules: hip.hipModuleUnload(module)
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(\"\"))" + """
__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
typedef float float8 __attribute__((ext_vector_type(8)));
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
extern "C" __global__
""", launch_bounds=True,
smem_prefix = "__shared__ ", smem_prefix_for_cast=False, barrier = "__syncthreads();", float4 = "make_float4", uses_vload=True, uses_ptr_arithmetic=True, arg_int_prefix = "const int",
half_prekernel = "#include <hip/hip_fp16.h>\nusing half4 = HIP_vector_type<half, 4>;" + """
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
""",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)]))
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), renderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize)
HIPBuffer = Compiled(RawHIPBuffer, LinearizerOptions(device="HIP"), HIPRenderer, compile_hip, HIPProgram, hip.hipDeviceSynchronize)