diff --git a/extra/gemm/cuda_matmul.py b/extra/gemm/cuda_matmul.py index cfc096d2..48402d74 100644 --- a/extra/gemm/cuda_matmul.py +++ b/extra/gemm/cuda_matmul.py @@ -1,7 +1,7 @@ import os import numpy as np os.environ["CUDA"] = "1" -from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, compile_cuda +from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler from tinygrad.helpers import flat_mv FLOAT16 = True @@ -29,7 +29,9 @@ cudaalloc.copyin(b, bytearray(nb)) FLOPS = N*N*N*2 BW = N*N*3*4 -prog = CUDAProgram(device, "wmma_example", compile_cuda(f""" +print(device.arch) +compiler = CUDACompiler(device.arch) +prog = CUDAProgram(device, "wmma_example", compiler.compile(f""" #include using namespace nvcuda;