From dea8bb09386d9565c636bdbb024af4556e28c7b1 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Sat, 7 Oct 2023 07:29:59 -0700
Subject: [PATCH] triton isn't tested, and allows this refactor (#2007)

* triton isn't tested

* cuda buffer
---
 {tinygrad/renderer => extra/triton}/triton.py |  0
 tinygrad/ops.py                               | 13 +++++++------
 tinygrad/runtime/ops_cuda.py                  | 11 +++--------
 3 files changed, 10 insertions(+), 14 deletions(-)
 rename {tinygrad/renderer => extra/triton}/triton.py (100%)

diff --git a/tinygrad/renderer/triton.py b/extra/triton/triton.py
similarity index 100%
rename from tinygrad/renderer/triton.py
rename to extra/triton/triton.py
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 5ac127af..5fd829a4 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -172,6 +172,12 @@ class ASTRunner:
     if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)
     self.name, self.prg, self.global_size, self.local_size, self.op_estimate, self.mem_estimate, self.display_name, self.runtime_args = name, prg, global_size, local_size, op_estimate, mem_estimate, display_name, runtime_args if runtime_args is not None else {}
 
+  @staticmethod
+  def from_linearizer(k, src:str):
+    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
+                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
+                     display_name=k.display_name, runtime_args={"binary": False})
+
   def build(self, runtime, batch_exec=BasicBatchExecutor):
     self.clprg, self.batch_exec = runtime(self.name, self.prg, **self.runtime_args), batch_exec
     return self
@@ -206,12 +212,7 @@ class Compiled:
 
   def to_program(self, k):
     k.linearize()
-    src = self.renderer(k.function_name, k.uops)
-    if len(src) == 3:
-      return ASTRunner(k.function_name, src[0], k.global_size, src[1],display_name=k.display_name, runtime_args=src[2]).build(self.runtime)
-    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
-                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
-                     display_name=k.display_name, runtime_args={"binary": False}).build(self.runtime, self.batch_exec)
+    return ASTRunner.from_linearizer(k, self.renderer(k.function_name, k.uops)).build(self.runtime, self.batch_exec)
 
   def exec_ast(self, ast:LazyOp, output, inputs, var_vals, **kwargs):
     # check if we can reuse the output buffer
diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py
index fb8a1a1e..71d0d29e 100644
--- a/tinygrad/runtime/ops_cuda.py
+++ b/tinygrad/runtime/ops_cuda.py
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import Optional
 import numpy as np
 from pycuda.compiler import compile as cuda_compile # type: ignore
-from tinygrad.helpers import DEBUG, getenv, colored, fromimport
+from tinygrad.helpers import DEBUG, getenv, colored
 from tinygrad.ops import Compiled
 from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator
 from tinygrad.codegen.kernel import LinearizerOptions
@@ -96,10 +96,5 @@ renderer = functools.partial(uops_to_cstyle, CStyleLanguage(
   __device__ __forceinline__ explicit half4(const float4& a): x(make_half2(__float2half(a.x), __float2half(a.y))), y(make_half2(__float2half(a.z),__float2half(a.w))) {}
   __device__ __forceinline__ explicit operator float4() const {return make_float4(__half2float(x.x), __half2float(x.y), __half2float(y.x), __half2float(y.y)); }
 };
-  """)) if not getenv("PTX") else fromimport("tinygrad.renderer.assembly_ptx", "uops_to_ptx_asm")
-if getenv("TRITON") == 1:
-  from tinygrad.renderer.triton import uops_to_triton
-  renderer = uops_to_triton
-  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024], has_shared=False), renderer, CUDAProgram, cuda.Context.synchronize)
-else:
-  CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=False if getenv("PTX") else True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
+  """))
+CUDABuffer = Compiled(RawCUDABuffer, LinearizerOptions(supports_float4=True, supports_float4_alu=False, global_max = [65535, 65535, 2147483647], local_max = [64, 1024, 1024]), renderer, CUDAProgram, cuda.Context.synchronize)
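
Note on the ops.py refactor: to_program() previously needed two construction paths because a renderer could return either a plain source string or, as the removed len(src) == 3 branch shows, a (src, local_size, runtime_args) tuple (the Triton path). With Triton out of tree, every remaining renderer returns a string, so the ASTRunner construction collapses into the new from_linearizer() staticmethod. Below is a minimal, self-contained sketch of the new shape; the _K/_Info stubs and the lambda renderer are illustrative stand-ins, not tinygrad code, and build()/batch_exec wiring is omitted.

# sketch.py -- standalone illustration of the to_program() refactor.
# All classes here are stubs mirroring the shapes used in the diff.
class ASTRunner:
  def __init__(self, name, prg, global_size, local_size, op_estimate=0,
               mem_estimate=0, display_name=None, runtime_args=None):
    self.name, self.prg, self.global_size, self.local_size = name, prg, global_size, local_size
    self.op_estimate, self.mem_estimate, self.display_name = op_estimate, mem_estimate, display_name
    self.runtime_args = runtime_args if runtime_args is not None else {}

  @staticmethod
  def from_linearizer(k, src:str):
    # single construction path: every kernel attribute comes off the Linearizer
    return ASTRunner(k.function_name, src, k.global_size, k.local_size,
                     op_estimate=k.info.flops, mem_estimate=k.mem_estimate,
                     display_name=k.display_name, runtime_args={"binary": False})

class Compiled:
  def __init__(self, renderer): self.renderer = renderer
  def to_program(self, k):
    k.linearize()
    # renderers now always return a source string, so no len(src) == 3 branch
    return ASTRunner.from_linearizer(k, self.renderer(k.function_name, k.uops))

# usage with a stub linearizer and a toy renderer
class _Info: flops = 42
class _K:
  function_name = display_name = "E_4"
  global_size, local_size, mem_estimate, info, uops = [4,1,1], [1,1,1], 64, _Info(), []
  def linearize(self): pass

runner = Compiled(lambda name, uops: f"// generated source for {name}").to_program(_K())
assert runner.runtime_args == {"binary": False} and runner.op_estimate == 42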
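
Note for anyone who was setting TRITON=1: this patch deletes the wiring, but the renderer itself survives the rename to extra/triton/triton.py, and the PTX=1 selection is dropped the same way (tinygrad/renderer/assembly_ptx.py itself is untouched by this diff). Something like the following should recreate the removed Triton branch out of tree. This is an untested sketch that mirrors the deleted lines; the TritonCUDABuffer name is hypothetical, and it assumes extra/ is importable and that RawCUDABuffer and CUDAProgram remain importable from ops_cuda.py.

# Untested reconstruction of the deleted TRITON=1 branch, kept out of tree.
# Mirrors the removed lines; assumes the renamed module is importable.
import pycuda.driver as cuda
from extra.triton.triton import uops_to_triton       # moved by this patch
from tinygrad.ops import Compiled
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.runtime.ops_cuda import RawCUDABuffer, CUDAProgram

# hypothetical name; same arguments as the deleted TRITON=1 CUDABuffer
TritonCUDABuffer = Compiled(
  RawCUDABuffer,
  LinearizerOptions(supports_float4=False, supports_float4_alu=False,
                    global_max=[65535, 65535, 2147483647],
                    local_max=[64, 1024, 1024], has_shared=False),
  uops_to_triton, CUDAProgram, cuda.Context.synchronize)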