diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 7b364dd3..a4c6a712 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -1,6 +1,7 @@ # pip3 install pyobjc-framework-Metal pyobjc-framework-Cocoa pyobjc-framework-libdispatch import os, subprocess, pathlib import Metal, Cocoa, libdispatch # type: ignore +from typing import List, Any from tinygrad.codegen.cstyle import CStyleCodegen, CStyleLanguage from tinygrad.helpers import prod, getenv, DEBUG, DType from tinygrad.ops import Compiled @@ -10,17 +11,13 @@ METAL_XCODE = getenv("METAL_XCODE") class _METAL: def __init__(self): + self.mtl_buffers_in_flight: List[Any] = [] self.device = Metal.MTLCreateSystemDefaultDevice() - self.dispatch_group = libdispatch.dispatch_group_create() self.mtl_queue = self.device.newCommandQueue() - def command_buffer(self): - command_buffer = self.mtl_queue.commandBuffer() - libdispatch.dispatch_group_enter(self.dispatch_group) - def leave(_): libdispatch.dispatch_group_leave(self.dispatch_group) - command_buffer.addCompletedHandler_(leave) - return command_buffer + # TODO: is there a better way to do this? def synchronize(self): - libdispatch.dispatch_group_wait(self.dispatch_group, libdispatch.DISPATCH_TIME_FOREVER) + for cbuf in self.mtl_buffers_in_flight: cbuf.waitUntilCompleted() + self.mtl_buffers_in_flight.clear() METAL = _METAL() class RawMetalBuffer(RawBufferMapped): @@ -63,7 +60,7 @@ class MetalProgram: def __call__(self, global_size, local_size, *bufs, wait=False): assert prod(local_size) <= self.pipeline_state.maxTotalThreadsPerThreadgroup(), f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}" - command_buffer = METAL.command_buffer() + command_buffer = METAL.mtl_queue.commandBuffer() encoder = command_buffer.computeCommandEncoder() encoder.setComputePipelineState_(self.pipeline_state) for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a._buf, 0, i) @@ -73,6 +70,7 @@ class MetalProgram: if wait: command_buffer.waitUntilCompleted() return command_buffer.GPUEndTime() - command_buffer.GPUStartTime() + METAL.mtl_buffers_in_flight.append(command_buffer) class MetalCodegen(CStyleCodegen): lang = CStyleLanguage(