adding cuda TC headers (#2165)

* split cuda to renderer and add headers for tc

* fix TritonRenderer

* remove unused import
This commit is contained in:
Yixiang Gao 2023-10-27 19:25:59 -05:00 committed by GitHub
parent 7f4f925385
commit 902f00b095
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 31 additions and 19 deletions

25
tinygrad/renderer/cuda.py Normal file
View File

@ -0,0 +1,25 @@
import functools
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
class CUDALanguage(CStyleLanguage):
  """CUDA dialect of CStyleLanguage.

  Supplies the CUDA-specific syntax pieces (kernel/shared-memory prefixes,
  barrier, float4 constructor, thread-index expressions) plus a prelude for
  half-precision support used when kernels need fp16/tensor-core headers.
  """
  kernel_prefix = "__global__ "
  smem_prefix = "__shared__ "
  smem_prefix_for_cast = False
  arg_int_prefix = "const int"
  barrier = "__syncthreads();"
  float4 = "make_float4"
  # CUDA built-in index variables, one expression per launch dimension (x, y, z).
  gid = [f'blockIdx.{axis}' for axis in 'xyz']
  lid = [f'threadIdx.{axis}' for axis in 'xyz']
  # Global thread id per dimension: block offset times block size plus thread offset.
  xid = [f'(blockIdx.{axis}*blockDim.{axis}+threadIdx.{axis})' for axis in 'xyz']
  half_prekernel = """
#include <cuda_fp16.h>
#include <mma.h>
using namespace nvcuda;
struct __align__(8) half4 {
half2 x, y;
__device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
__device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
};
""" # if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm") # assembly_ptx currently isn't supported
# Renderer entry point: uops_to_cstyle specialized with the CUDA language config.
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())

View File

@ -1,13 +1,13 @@
import subprocess, time, re, hashlib, tempfile, functools
import subprocess, time, re, hashlib, tempfile
from pathlib import Path
from typing import Optional, List, Any, Tuple
import numpy as np
from pycuda.compiler import compile as cuda_compile # type: ignore
from tinygrad.helpers import DEBUG, getenv, colored, fromimport
from tinygrad.helpers import DEBUG, getenv, colored
from tinygrad.ops import Compiled, GraphBatchExecutor, ASTRunner
from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
from tinygrad.renderer.cuda import CUDARenderer
def pretty_ptx(s):
# all expressions match `<valid_before><expr><valid_after>` and replace it with `<valid_before>color(<expr>)<valid_after>`
@ -116,22 +116,9 @@ class CUDAProgram:
end.synchronize()
return start.time_till(end)*1e-3
renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
kernel_prefix = "__global__ ", smem_prefix = "__shared__ ", smem_prefix_for_cast=False, arg_int_prefix = "const int", barrier = "__syncthreads();", float4 = "make_float4",
gid = [f'blockIdx.{chr(120+i)}' for i in range(3)],
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)],
xid = [f'(blockIdx.{chr(120+i)}*blockDim.{chr(120+i)}+threadIdx.{chr(120+i)})' for i in range(3)],
half_prekernel = """
#include <cuda_fp16.h>
struct __align__(8) half4 {
half2 x, y;
__device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
__device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
};
""")) if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm")
if getenv("TRITON") == 1:
from tinygrad.renderer.triton import uops_to_triton
renderer = uops_to_triton
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), renderer, CUDAProgram, cuda.Context.synchronize)
TritonRenderer = uops_to_triton
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), TritonRenderer, CUDAProgram, cuda.Context.synchronize)
else:
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)
CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), CUDARenderer, CUDAProgram, cuda.Context.synchronize, CUDAGraph)