mirror of https://github.com/commaai/tinygrad.git
fix "ImportError: cannot import name 'compile_cuda' from 'tinygrad.runtime.ops_cuda'" in extra/gemm/cuda_matmul.py (#3531)
This commit is contained in:
parent
275971e616
commit
0b1fc5888a
|
@ -1,7 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
os.environ["CUDA"] = "1"
|
os.environ["CUDA"] = "1"
|
||||||
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, compile_cuda
|
from tinygrad.runtime.ops_cuda import CUDAAllocator, CUDADevice, CUDAProgram, CUDACompiler
|
||||||
from tinygrad.helpers import flat_mv
|
from tinygrad.helpers import flat_mv
|
||||||
|
|
||||||
FLOAT16 = True
|
FLOAT16 = True
|
||||||
|
@ -29,7 +29,9 @@ cudaalloc.copyin(b, bytearray(nb))
|
||||||
FLOPS = N*N*N*2
|
FLOPS = N*N*N*2
|
||||||
BW = N*N*3*4
|
BW = N*N*3*4
|
||||||
|
|
||||||
prog = CUDAProgram(device, "wmma_example", compile_cuda(f"""
|
print(device.arch)
|
||||||
|
compiler = CUDACompiler(device.arch)
|
||||||
|
prog = CUDAProgram(device, "wmma_example", compiler.compile(f"""
|
||||||
#include <mma.h>
|
#include <mma.h>
|
||||||
using namespace nvcuda;
|
using namespace nvcuda;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue