mirror of https://github.com/commaai/tinygrad.git
Fix "ImportError: cannot import name 'compile_cuda' from 'tinygrad.runtime.ops_cuda'" in extra/gemm/cuda_matmul.py (#3531)
This commit is contained in:
parent
275971e616
commit
0b1fc5888a
|
@ -1,7 +1,7 @@
|
|||
import os
|
||||
import numpy as np
|
||||
os.environ["CUDA"] = "1"
|
||||
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, compile_cuda
|
||||
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
|
||||
from tinygrad.helpers import flat_mv
|
||||
|
||||
FLOAT16 = True
|
||||
|
@ -29,7 +29,9 @@ cudaalloc.copyin(b, bytearray(nb))
|
|||
FLOPS = N*N*N*2
|
||||
BW = N*N*3*4
|
||||
|
||||
prog = CUDAProgram(device, "wmma_example", compile_cuda(f"""
|
||||
print(device.arch)
|
||||
compiler = CUDACompiler(device.arch)
|
||||
prog = CUDAProgram(device, "wmma_example", compiler.compile(f"""
|
||||
#include <mma.h>
|
||||
using namespace nvcuda;
|
||||
|
||||
|
|
Loading…
Reference in New Issue