update metal matmul and matvec for compile api (#2238)

This commit is contained in:
Rory Clear 2023-11-08 16:08:35 +00:00 committed by GitHub
parent 3042450b4d
commit 553688f12a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 9 deletions

View File

@ -3,7 +3,7 @@ os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_metal
N = getenv("N", 2048)
LID = 2
@ -18,7 +18,7 @@ c = RawMetalBuffer.fromCPU(nc)
FLOPS = N*N*N*2
BW = N*N*3*4
prog = MetalProgram("test", f"""
prog = MetalProgram("test", compile_metal(f"""
#include <metal_stdlib>
#include <metal_simdgroup_matrix> // Available from Metal version 2.3 released with OS X 11.0+
using namespace metal;
@ -80,13 +80,13 @@ kernel void test(device float *a, device const float *data1, device const float
simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}""")
}}"""))
def timeit(fxn):
    """Return the wall-clock time, in seconds, taken by one call to ``fxn``.

    Args:
        fxn: Zero-argument callable to run once.

    Returns:
        Elapsed ``time.perf_counter()`` seconds measured around the call.
    """
    st = time.perf_counter()
    # NOTE: whatever fxn returns doesn't contain the launch overhead, so it is
    # discarded and the wall-clock delta around the call is reported instead.
    fxn()
    return time.perf_counter() - st
tm = min([timeit(lambda: prog([N//(8*4), N//(8*4*LID), 1], [32, LID, 1], a, b, c, wait=True)) for _ in range(20)])
tm = min([timeit(lambda: prog(a, b, c, global_size=[N//(8*4), N//(8*4*LID), 1], local_size=[32, LID, 1], wait=True)) for _ in range(20)])
na = a.toCPU().reshape(N,N)
comp = nb@nc
if N <= 32:

View File

@ -14,7 +14,7 @@ os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram, compile_metal
N = 16384
M = 4096
@ -40,7 +40,7 @@ WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
prog_string = f"""
prog = compile_metal(f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
@ -86,15 +86,15 @@ kernel void test(device float* data0, const device float* data1, const device fl
*( (device float4 *) (data0 + (gidx0*{M//GLOBAL_SIZE[0]}) + ( ( (lidx1*{LOCAL_SIZE[1]})+lidx2 ) * 4 ) ) ) = out;
}}
}}
"""
prog = MetalProgram("test", prog_string)
""")
prog = MetalProgram("test", prog)
# print(prog_string)
na = np.zeros(M, dtype=np.float32)
b = RawMetalBuffer.fromCPU(nb)
c = RawMetalBuffer.fromCPU(nc)
# Builds a Metal buffer from `na`, launches `prog` on it together with `b` and `c`
# (waiting for completion via wait=True), and returns that buffer.
# NOTE(review): the two consecutive prog(...) lines below look like the removed
# (positional sizes first) and added (global_size=/local_size= keywords) sides of
# a diff hunk whose +/- markers were lost, not two intended launches — confirm.
def metalrun():
# a new buffer is created from `na` on every call
a = RawMetalBuffer.fromCPU(na)
prog(GLOBAL_SIZE, LOCAL_SIZE, a, b, c, wait=True)
prog(a, b, c, global_size=GLOBAL_SIZE, local_size=LOCAL_SIZE, wait=True)
return a
def timeit(fxn):
st = time.perf_counter()