mirror of https://github.com/commaai/tinygrad.git
update metal matmul and matvec for compile api (#2238)
This commit is contained in:
parent
3042450b4d
commit
553688f12a
|
@@ -3,7 +3,7 @@ os.environ["METAL"] = "1"
|
|||
import time
|
||||
import numpy as np
|
||||
from tinygrad.helpers import dtypes, getenv
|
||||
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
|
||||
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_metal
|
||||
|
||||
N = getenv("N", 2048)
|
||||
LID = 2
|
||||
|
@@ -18,7 +18,7 @@ c = RawMetalBuffer.fromCPU(nc)
|
|||
FLOPS = N*N*N*2
|
||||
BW = N*N*3*4
|
||||
|
||||
prog = MetalProgram("test", f"""
|
||||
prog = MetalProgram("test", compile_metal(f"""
|
||||
#include <metal_stdlib>
|
||||
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
|
||||
using namespace metal;
|
||||
|
@@ -80,13 +80,13 @@ kernel void test(device float *a, device const float *data1, device const float
|
|||
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
|
||||
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
|
||||
}}""")
|
||||
}}"""))
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
et = fxn()
|
||||
# NOTE: et doesn't contain the launch overhead
|
||||
return time.perf_counter() - st
|
||||
tm = min([timeit(lambda: prog([N//(8*4), N//(8*4*LID), 1], [32, LID, 1], a, b, c, wait=True)) for _ in range(20)])
|
||||
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
|
||||
na = a.toCPU().reshape(N,N)
|
||||
comp = nb@nc
|
||||
if N <= 32:
|
||||
|
|
|
@@ -14,7 +14,7 @@ os.environ["METAL"] = "1"
|
|||
import time
|
||||
import numpy as np
|
||||
from tinygrad.helpers import dtypes, getenv
|
||||
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
|
||||
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_metal
|
||||
|
||||
N = 16384
|
||||
M = 4096
|
||||
|
@@ -40,7 +40,7 @@ WORKSIZE_ROW = 16
|
|||
WORKSIZE_COL = 1
|
||||
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
|
||||
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
|
||||
prog_string = f"""
|
||||
prog = compile_metal(f"""
|
||||
#include <metal_stdlib>
|
||||
using namespace metal;
|
||||
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
|
||||
|
@@ -86,15 +86,15 @@ kernel void test(device float* data0, const device float* data1, const device fl
|
|||
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
prog = MetalProgram("test", prog_string)
|
||||
""")
|
||||
prog = MetalProgram("test", prog)
|
||||
# print(prog_string)
|
||||
na = np.zeros(M, dtype=np.float32)
|
||||
b = RawMetalBuffer.fromCPU(nb)
|
||||
c = RawMetalBuffer.fromCPU(nc)
|
||||
def metalrun():
|
||||
a = RawMetalBuffer.fromCPU(na)
|
||||
prog(GLOBAL_SIZE, LOCAL_SIZE, a, b, c, wait=True)
|
||||
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
|
||||
return a
|
||||
def timeit(fxn):
|
||||
st = time.perf_counter()
|
||||
|
|
Loading…
Reference in New Issue