import os
os.environ["METAL"] = "1"
import time
import numpy as np
from tinygrad.helpers import dtypes, getenv
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram

N = 16384
M = 4096
FLOPS = N*M*2

nb = np.random.default_rng().standard_normal(size=(N), dtype=np.float32) #.astype(np.int32).astype(np.float32)
nc = np.random.default_rng().standard_normal(size=(N,M), dtype=np.float32) #.astype(np.int32).astype(np.float32)

import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')

def torch_prog(b, c):
  st = time.perf_counter()
  a = b@c
  torch.mps.synchronize()
  return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in torch")
torch_a = (b@c).cpu()

WORKSIZE_ROW = 16
WORKSIZE_COL = 1
LOCAL_SIZE = [32, WORKSIZE_COL, WORKSIZE_ROW]
GLOBAL_SIZE = [M//(LOCAL_SIZE[0]*LOCAL_SIZE[1]*4), 1, 1]
prog_string = f"""
#include <metal_stdlib>
using namespace metal;
kernel void test(device float* data0, const device float* data1, const device float* data2, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {{
  int gidx0 = gid.x; /* {GLOBAL_SIZE[0]} */
  int lidx1 = lid.x; /* {LOCAL_SIZE[0]} */
  int lidx2 = lid.y; /* {LOCAL_SIZE[1]} */
  int lidx3 = lid.z; /* {LOCAL_SIZE[2]} */

  // 4 rows per thread
  threadgroup float4 acc0[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];
  int acc0_index = ((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*lidx3);
  acc0[acc0_index] = float4(0.0f,0.0f,0.0f,0.0f);

  threadgroup float4 val1[{LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]}];

  // iterate over the columns
  for (int ridx2 = 0; ridx2 < {N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))}; ++ridx2) {{
    // load 4*threadgroup_size columns into shared memory
    int col_1 = (((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+(lidx1*{LOCAL_SIZE[1]})+lidx2;
    val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+((lidx1*{LOCAL_SIZE[1]})+lidx2)] = *((device float4*)(data1+(col_1*4)));
    threadgroup_barrier(mem_flags::mem_threadgroup);
    for (int ridx3 = 0; ridx3 < {LOCAL_SIZE[0]*LOCAL_SIZE[1]}; ++ridx3) {{
      int col = ((((lidx3*{N//(4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*(LOCAL_SIZE[2]))})+ridx2)*{LOCAL_SIZE[0]*LOCAL_SIZE[1]})+ridx3);
      float4 val1_0 = val1[(lidx3*{LOCAL_SIZE[1]*LOCAL_SIZE[0]})+ridx3];
      float4 val2_0 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*0})));
      float4 val2_1 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*1})));
      float4 val2_2 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*2})));
      float4 val2_3 = (float4)(*((device float4*)(data2+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4)+(col*{M*4})+{M*3})));
      acc0[acc0_index] = ((val1_0.x*val2_0)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.y*val2_1)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.z*val2_2)+acc0[acc0_index]);
      acc0[acc0_index] = ((val1_0.w*val2_3)+acc0[acc0_index]);
    }}
    threadgroup_barrier(mem_flags::mem_threadgroup);
  }}

  /* reduce */
  if (lidx3 == 0) {{
    float4 out = float4(0.0f,0.0f,0.0f,0.0f);
    for (int n = 0; n < {LOCAL_SIZE[2]}; n++) {{
      out += acc0[((lidx1*{LOCAL_SIZE[1]})+lidx2)+({LOCAL_SIZE[0]*LOCAL_SIZE[1]}*n)];
    }}
    *((device float4*)(data0+(gidx0*{M//GLOBAL_SIZE[0]})+(((lidx1*{LOCAL_SIZE[1]})+lidx2)*4))) = out;
  }}
}}
"""
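# Sanity checks on the launch configuration (illustrative asserts, not part of
# the original script). The kernel assumes N tiles evenly into float4 loads
# across the threadgroup, and that each thread stores exactly one float4 of
# output within its threadgroup's column block.
assert N % (4*LOCAL_SIZE[0]*LOCAL_SIZE[1]*LOCAL_SIZE[2]) == 0
assert M % (4*LOCAL_SIZE[0]*LOCAL_SIZE[1]) == 0
assert M//GLOBAL_SIZE[0] == 4*LOCAL_SIZE[0]*LOCAL_SIZE[1]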
prog = MetalProgram("test", prog_string)
# print(prog_string)
na = np.zeros(M, dtype=np.float32)
b = RawMetalBuffer.fromCPU(nb)
c = RawMetalBuffer.fromCPU(nc)

def metalrun():
  a = RawMetalBuffer.fromCPU(na)
  prog(GLOBAL_SIZE, LOCAL_SIZE, a, b, c, wait=True)
  return a

def timeit(fxn):
  st = time.perf_counter()
  et = fxn()
  # NOTE: et doesn't contain the launch overhead
  return time.perf_counter() - st
tm = min([timeit(metalrun) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in metal")
metal_a = metalrun().toCPU().reshape(M)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.runtime.ops_metal import METAL
b = Tensor(nb)
c = Tensor(nc)

# TODO: I suspect the slowness without the JIT comes from the lack of a caching allocator
@TinyJit
def tiny_jit(b, c):
  return (b@c).realize()

def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  METAL.synchronize()
  return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matvec in tinygrad")
tiny_a = tiny_jit(b, c).numpy()
np.testing.assert_allclose(tiny_a, torch_a, atol=5e-3)
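# matvec at this size is memory-bandwidth bound: reading the (N, M) float32
# matrix dominates traffic, so effective GB/s is arguably more telling than
# GFLOPS. A rough estimate from the best tinygrad time above (an illustrative
# addition, not part of the original benchmark):
print(f"~{N*M*4*1e-9/tm:.0f} GB/s effective matrix read bandwidth in tinygrad")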
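# For comparison, the same measurement without TinyJit: a sketch motivated by
# the TODO above. Each call re-runs tinygrad's lowering and allocation path
# instead of replaying a captured kernel, so expect it to be noticeably slower.
def tiny_prog_nojit(b, c):
  st = time.perf_counter()
  a = (b@c).realize()
  METAL.synchronize()
  return time.perf_counter() - st
tm_nojit = min([tiny_prog_nojit(b, c) for _ in range(200)])
print(f"{N:d}x{M:d} {tm_nojit*1e6:9.2f} us, would be {FLOPS*1e-9/tm_nojit:9.2f} GFLOPS matvec in tinygrad (no JIT)")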