triton isn't tested, and allows this refactor (#2007)

* triton isn't tested

* cuda buffer
George Hotz 2023-10-07 07:29:59 -07:00 committed by GitHub
parent 23de1db727
commit dea8bb0938
3 changed files with 10 additions and 14 deletions


@@ -172,6 +172,12 @@ class ASTRunner:
     if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
+  @staticmethod
+  def from_linearizer(k, src:str):
+    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
+                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
+                     display_name=k.display_name, runtime_args={"binary": False})
+
   def build(self, runtime, batch_exec=BasicBatchExecutor):
     self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
     return self
@@ -206,12 +212,7 @@ class Compiled:
 
   def to_program(self, k):
     k.linearize()
-    src = self.renderer(k.function_name, k.uops)
-    if len(src) == 3:
-      return ASTRunner(k.function_name, src[0], k.global_size, src[1], display_name=k.display_name, runtime_args=src[2]).build(self.runtime)
-    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
-                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime, self.batch_exec)
+    return ASTRunner.from_linearizer(k, self.renderer(k.function_name, k.uops)).build(self.runtime, self.batch_exec)
 
   def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # check if we can reuse the output buffer
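
In short, the recipe for turning a linearized kernel into an ASTRunner moves onto ASTRunner.from_linearizer, so Compiled.to_program collapses to a single call. A minimal, self-contained sketch of that pattern follows; Runner, FakeInfo, and FakeLinearizer are illustrative stand-ins (not tinygrad classes) that only mimic the attribute names the diff reads from the linearizer.

# Illustrative sketch only -- these classes are stand-ins, not tinygrad code.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class FakeInfo:
  flops: int = 0                      # mirrors the k.info.flops read above

@dataclass
class FakeLinearizer:                 # mirrors the attributes from_linearizer reads
  function_name: str = "E_4"
  global_size: Optional[List[int]] = None
  local_size: Optional[List[int]] = None
  info: FakeInfo = field(default_factory=FakeInfo)
  mem_estimate: int = 0
  display_name: str = "E_4"

class Runner:
  def __init__(self, name, prg, global_size=None, local_size=None,
               op_estimate=0, mem_estimate=0, display_name=None, runtime_args=None):
    self.name, self.prg, self.runtime_args = name, prg, runtime_args if runtime_args is not None else {}

  # the refactor: the "runner from linearizer" recipe lives on the runner class,
  # so every backend's to_program() shrinks to one from_linearizer(...) call
  @staticmethod
  def from_linearizer(k, src: str):
    return Runner(k.function_name, src, k.global_size, k.local_size,
                  op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
                  display_name=k.display_name, runtime_args={"binary": False})

if __name__ == "__main__":
  r = Runner.from_linearizer(FakeLinearizer(), "/* rendered kernel source */")
  print(r.name, r.runtime_args)       # -> E_4 {'binary': False}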


@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Optional
 import numpy as np
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG, getenv, colored, fromimport
+from tinygrad.helpers import DEBUG, getenv, colored
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
 from tinygrad.codegen.kernel import LinearizerOptions
@@ -96,10 +96,5 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
 __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
 __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
 };
-""")) if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm")
-if getenv("TRITON") == 1:
-  from tinygrad.renderer.triton import uops_to_triton
-  renderer = uops_to_triton
-  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), renderer, CUDAProgram, cuda.Context.synchronize)
-else:
-  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
+"""))
+CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
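
Net effect on the CUDA backend: renderer selection and the LinearizerOptions no longer branch on the PTX or TRITON environment variables, leaving one fixed configuration. The sketch below illustrates that "single unconditional backend config" shape under invented names (Backend, Options, and toy_renderer are stand-ins, not tinygrad's Compiled, LinearizerOptions, or uops_to_cstyle).

# Toy sketch of the "one unconditional backend config" idea -- names are stand-ins.
from dataclasses import dataclass
from typing import Callable, List

@dataclass
class Options:                         # loosely mirrors the LinearizerOptions fields seen above
  supports_float4: bool
  supports_float4_alu: bool
  global_max: List[int]
  local_max: List[int]

@dataclass
class Backend:                         # loosely mirrors the Compiled(...) wiring above
  options: Options
  renderer: Callable[[str], str]
  synchronize: Callable[[], None]

def toy_renderer(name: str) -> str:
  return f"__global__ void {name}() {{}}"   # placeholder for rendered kernel source

# before this commit the options/renderer were picked via getenv("PTX") / getenv("TRITON");
# after it, the CUDA backend is a single fixed instance:
CUDA_BACKEND = Backend(
  options=Options(supports_float4=True, supports_float4_alu=False,
                  global_max=[65535, 65535, 2147483647], local_max=[64, 1024, 1024]),
  renderer=toy_renderer,
  synchronize=lambda: None,
)

if __name__ == "__main__":
  print(CUDA_BACKEND.renderer("E_4"))   # -> __global__ void E_4() {}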