fix 'Import Error: cannot import name compile_cuda from tinygrad.runtime.ops_cuda' error in extra/gemm/cuda_matmul.py (#3531)

This commit is contained in:
Caleb Bunch 2024-02-28 20:15:32 -05:00 committed by GitHub
parent 275971e616
commit 0b1fc5888a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -1,7 +1,7 @@
import os
import numpy as np
os.environ["CUDA"] = "1"
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, compile_cuda
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
from tinygrad.helpers import flat_mv
FLOAT16 = True
@ -29,7 +29,9 @@ cudaalloc.copyin(b, bytearray(nb))
FLOPS = N*N*N*2
BW = N*N*3*4
prog = CUDAProgram(device, "wmma_example", compile_cuda(f"""
print(device.arch)
compiler = CUDACompiler(device.arch)
prog = CUDAProgram(device, "wmma_example", compiler.compile(f"""
#include <mma.h>
using namespace nvcuda;