Metal with CDLL instead of py-objc (#6545)

* Add CDLL interface for metal

* remove two unused functions

* Cover most of the API methods

* switch to cdll

* directly call objc message in ops_metal

* keep only obj interface

* Use direct message sending for graph

* may have found a solution to the memoryview on ctypes pointer

* buf indexing bug fixed

* fix c_int

* fix c int to bytes

* fix gpu time bug

* line savings for cdll metal core

* wip

* c int bug

* fix buf casting

* dedup for c_void_p

* dedup for c_void_p

* linter fix

* remove unused stuff

* mypy fix

* more mypy error fix

* line savings

* line savings

* rename send_message to msg; add __hash__ and __eq__ for dedup

* wip

* refactor

* refactor

* remove named import from ctypes

* forgot to change variable name

* file reorg, put support.py to ops_metal

* refactor

* hash error

* remove to_ns_array

* test oom exception, fix exception change

* typevar for msg

* add back dedup

* test for compile error

* move constant to graph

* move header constant around

* get label for icb buffer

* check icb label using "in"

* wip fixing mypy reported error

* fixed mypy error

* code formatting

* all_resources dedup match previous

* code formatting

* code formatting; buffer set to objc_id

* revert changes on buf for the manual release, seems like _free is not always called

* skip unless on metal, for test_metal

* fix premature mem release causing seg fault

* test_metal check for device before importing

* Buffer should only be released under _free explicitly

* mypy fixes

* change object ownership

* test compile success

* lint fixes

* remove load_library

* wrap sel_register in cache

* simplify to_struct

* swap lines

* fix type error in to_struct

* bump line to 9800

* remove pyobjc from setup.py

* command buffer should be objc_instance and get released

* stringWithUTF8String: returns objc_instance

* Use constant for MTLPipelineOptionNone

* better explanation for [MTLBuffer contents:] return

* Use dyld_find in case the path differs

* trailing whitespace

* handle exception for methods that take error:

* load /System/Library instead of /Library

* Init c_void_p with None instead of zero for error objects

---------

Co-authored-by: Mesozoic Egg <mesozoic.egg@proton.me>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
mesozoic-egg 2024-09-25 17:43:01 +08:00 committed by GitHub
parent cd534dee11
commit 992cde05d7
4 changed files with 206 additions and 81 deletions

setup.py

@@ -21,9 +21,7 @@ setup(name='tinygrad',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License"
],
install_requires=["numpy",
"pyobjc-framework-Metal; platform_system=='Darwin'",
"pyobjc-framework-libdispatch; platform_system=='Darwin'"],
install_requires=["numpy"],
python_requires='>=3.8',
extras_require={
'llvm': ["llvmlite"],
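With pyobjc gone from install_requires, the runtime reaches Metal through raw ctypes instead. A minimal standalone sketch of the loading pattern this change switches to (not part of the diff; it mirrors ops_metal.py further down and assumes the standard macOS framework paths):

import ctypes

# Objective-C runtime and the Metal framework, loaded directly via ctypes
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
# CoreGraphics must also be loaded, or MTLCreateSystemDefaultDevice returns nil
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")

libmetal.MTLCreateSystemDefaultDevice.restype = ctypes.c_void_p
device = libmetal.MTLCreateSystemDefaultDevice()
print(hex(device) if device else "no Metal device found")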

test/test_metal.py (new file)

@@ -0,0 +1,54 @@
import unittest
from tinygrad.device import CompileError, Device
if Device.DEFAULT=="METAL":
from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
@unittest.skipIf(Device.DEFAULT!="METAL", "Metal support required")
class TestMetal(unittest.TestCase):
def test_alloc_oom(self):
device = MetalDevice("metal")
with self.assertRaises(MemoryError):
device.allocator.alloc(10000000000000000000)
def test_compile_error(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
with self.assertRaises(CompileError):
compiler.compile("this is not valid metal")
def test_compile_success(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
ret = compiler.compile("""
#include <metal_stdlib>
using namespace metal;
kernel void E_4n1(device int* data0, const device int* data1, const device int* data2,
uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
int val0 = *(data1+0);
int val1 = *(data1+1);
int val2 = *(data1+2);
int val3 = *(data1+3);
int val4 = *(data2+0);
int val5 = *(data2+1);
int val6 = *(data2+2);
int val7 = *(data2+3);
*(data0+0) = (val0+val4);
*(data0+1) = (val1+val5);
*(data0+2) = (val2+val6);
*(data0+3) = (val3+val7);
}
""")
assert ret is not None
def test_failed_newLibraryWithData(self):
device = MetalDevice("metal")
compiler = MetalCompiler(device)
compiled = compiler.compile("""
#include <metal_stdlib>
kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
data0[0] = 0;
}
""")
with self.assertRaises(RuntimeError):
compiled = compiled[:40] # corrupt the compiled program
MetalProgram(device, "r_5", compiled)
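The CompileError and RuntimeError these tests expect surface through the Objective-C `error:` out-parameter convention handled in ops_metal.py below. A simplified sketch of that convention (the real error_check also decodes localizedDescription from the NSError; the msg call is shown only as a comment):

import ctypes

class objc_id(ctypes.c_void_p): pass  # mirrors the definition in ops_metal.py

def error_check(error: objc_id, error_constructor: type = RuntimeError):
    # methods taking `error:` fill this pointer in on failure; nil means success
    if error.value is None: return
    raise error_constructor("objc call reported an error")

err = objc_id(None)  # a nil NSError* for the callee to populate on failure
# library = msg(device, "newLibraryWithSource:options:error:", src, opts, ctypes.byref(err), ...)
error_check(err)     # no-op here, since err was never filled in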

tinygrad/runtime/graph/metal.py

@@ -1,12 +1,20 @@
from typing import List, Any, Dict, cast, Optional
import Metal
import ctypes
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, getenv
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import wait_check
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, elapsed_time, objc_id
class MTLIndirectCommandType:
MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
class MTLResourceUsage:
MTLResourceUsageRead = 0b01
MTLResourceUsageWrite = 0b10
class MetalGraph(GraphRunner):
def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -14,58 +22,64 @@ class MetalGraph(GraphRunner):
if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException
# create metal batch exec
icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
icb_descriptor.setInheritBuffers_(False)
icb_descriptor.setInheritPipelineState_(False)
icb_descriptor.setMaxKernelBufferBindCount_(31)
self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
Metal.MTLResourceOptions(0))
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
self.needs_icb_fix = int(type(self.icb).__name__ != "AGXG15XFamilyIndirectCommandBuffer") # not required on M3
icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
msg(icb_descriptor, "setInheritBuffers:", False)
msg(icb_descriptor, "setInheritPipelineState:", False)
msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_resources = [self.int_buf.buf] if len(self.vars) else []
all_pipelines = []
for j,ji in enumerate(self.jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
icb_command = self.icb.indirectComputeCommandAtIndex_(j)
icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
all_pipelines.append(prg.clprg.pipeline_state)
icb_command.setComputePipelineState_(prg.clprg.pipeline_state)
msg(icb_command, "setComputePipelineState:", prg.clprg.pipeline_state)
for i,b in enumerate(ji.bufs):
if b is not None and b not in input_rawbuffers:
icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
global_size, local_size = prg.p.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
icb_command.setBarrier()
msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(icb_command, "setBarrier")
self.all_resources = dedup(all_resources)
self.all_pipelines = dedup(all_pipelines)
self.command_buffer: Any = None
if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')
self.range = Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache))
if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
self.range = to_struct(0, len(self.jit_cache))
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])
for (j,i),input_idx in self.input_replace.items():
self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf.buf,
computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
input_rawbuffers[input_idx]._buf.offset, i)
for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
prg = cast(CompiledRunner, self.jit_cache[j].prg)
global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
Metal.MTLSize(*local_size))
computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)
# NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
# this is a O(n) hack to get them used. what should work is:
@@ -74,16 +88,16 @@ class MetalGraph(GraphRunner):
# to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
if getenv("FIX_METAL_ICB", self.needs_icb_fix):
for ps in self.all_pipelines:
encoder.setComputePipelineState_(ps)
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(0,0,0), Metal.MTLSize(0,0,0))
msg(encoder, "setComputePipelineState:", ps)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))
encoder.executeCommandsInBuffer_withRange_(self.icb, self.range)
encoder.endEncoding()
command_buffer.commit()
msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
msg(encoder, "endEncoding")
msg(command_buffer, "commit")
self.command_buffer = command_buffer
if wait:
wait_check(command_buffer)
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
return elapsed_time(command_buffer)
self.device.mtl_buffers_in_flight.append(command_buffer)
return None
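Where pyobjc built Metal.MTLSize and execution-range values, this file now uses to_struct (defined in ops_metal.py below), which packs NSUInteger-sized fields into an ad-hoc ctypes.Structure that objc_msgSend can take by value. A small sketch of what it produces:

import ctypes

def to_struct(*t: int, _type: type = ctypes.c_ulong):
    # anonymous struct of equally typed fields, e.g. an MTLSize-shaped triple
    class Struct(ctypes.Structure): pass
    Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
    return Struct(*t)

threadgroups = to_struct(4, 1, 1)  # stands in for MTLSize(4, 1, 1)
icb_range = to_struct(0, 8)        # stands in for an execution range covering 8 commands
print(threadgroups.field0, threadgroups.field1, threadgroups.field2)  # 4 1 1
print(ctypes.sizeof(icb_range))    # 16 on 64-bit: two c_ulong fields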

tinygrad/runtime/ops_metal.py

@@ -1,15 +1,61 @@
from __future__ import annotations
import os, subprocess, pathlib, ctypes, tempfile, functools
import Metal, libdispatch
from typing import List, Any, Tuple, Optional
from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
from typing import List, Any, Tuple, Optional, cast, TypeVar
from tinygrad.helpers import prod, getenv, DEBUG
from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
from tinygrad.renderer.cstyle import MetalRenderer
class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
def __hash__(self): return hash(self.value)
def __eq__(self, other): return self.value == other.value
class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
def __del__(self): msg(self, "release")
@functools.lru_cache(None)
def sel(name: str): return libobjc.sel_registerName(name.encode())
class MTLResourceOptions:
MTLResourceCPUCacheModeDefaultCache = 0
MTLResourceStorageModeShared = 0 << 4
class MTLPipelineOption:
MTLPipelineOptionNone = 0
libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
libdispatch.dispatch_data_create.restype = objc_instance
T = TypeVar("T")
# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T: # type: ignore [assignment]
sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
sender.restype = restype
return sender(ptr, sel(selector), *args)
def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
def to_struct(*t: int, _type: type = ctypes.c_ulong):
class Struct(ctypes.Structure): pass
Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
return Struct(*t)
def wait_check(cbuf: Any):
cbuf.waitUntilCompleted()
if (error := cbuf.error()) is not None:
raise RuntimeError(error)
msg(cbuf, "waitUntilCompleted")
if (error := cast(int, msg(cbuf, "error", restype=ctypes.c_ulong))) != 0: raise RuntimeError(error)
def elapsed_time(cbuf: objc_id):
return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
if error.value is None: return None
raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())
class MetalCompiler(Compiler):
def __init__(self, device:Optional[MetalDevice]):
@@ -20,11 +66,14 @@ class MetalCompiler(Compiler):
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
options = Metal.MTLCompileOptions.new()
options.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
try: library = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, options, None))
except AssertionError as e: raise CompileError(e) from e
return library.libraryDataContents().bytes().tobytes()
options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
compileError = objc_instance()
library = msg(self.device.device, "newLibraryWithSource:options:error:", to_ns_str(src),
options, ctypes.byref(compileError), restype=objc_instance)
error_check(compileError, CompileError)
library_contents = msg(library, "libraryDataContents", restype=objc_instance)
return ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))
class MetalProgram:
def __init__(self, device:MetalDevice, name:str, lib:bytes):
@@ -38,27 +87,35 @@ class MetalProgram:
print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
self.fxn = self.library.newFunctionWithName_(name)
descriptor = Metal.MTLComputePipelineDescriptor.new()
descriptor.setComputeFunction_(self.fxn)
descriptor.setSupportIndirectCommandBuffers_(True)
self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(
descriptor, Metal.MTLPipelineOption(0), None, None))
error_library_creation = objc_instance()
self.library = msg(self.device.device, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
error_check(error_library_creation)
self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
msg(descriptor, "setComputeFunction:", self.fxn)
msg(descriptor, "setSupportIndirectCommandBuffers:", True)
error_pipeline_creation = objc_instance()
self.pipeline_state = msg(self.device.device, "newComputePipelineStateWithDescriptor:options:reflection:error:",
descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation), restype=objc_instance)
error_check(error_pipeline_creation)
def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.setComputePipelineState_(self.pipeline_state)
for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a.buf, a.offset, i)
for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
encoder.endEncoding()
command_buffer.commit()
max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
if prod(local_size) > cast(int, max_total_threads):
exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
msg(encoder, "setComputePipelineState:", self.pipeline_state)
for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
msg(encoder, "endEncoding")
msg(command_buffer, "commit")
if wait:
wait_check(command_buffer)
return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
return elapsed_time(command_buffer)
self.device.mtl_buffers_in_flight.append(command_buffer)
class MetalBuffer:
@@ -69,46 +126,48 @@ class MetalAllocator(LRUAllocator):
self.device:MetalDevice = device
super().__init__()
def _alloc(self, size:int, options) -> MetalBuffer:
ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
# Buffer is explicitly released in _free() rather than garbage collected via reference count
ret = msg(self.device.device, "newBufferWithLength:options:", size, MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
return MetalBuffer(ret, size)
def _free(self, opaque:MetalBuffer, options): opaque.buf.release()
def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
dest_dev.synchronize()
src_command_buffer = src_dev.mtl_queue.commandBuffer()
encoder = src_command_buffer.blitCommandEncoder()
encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src.buf, src.offset, dest.buf, dest.offset, sz)
encoder.endEncoding()
src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
msg(encoder, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, src.offset, dest.buf, dest.offset, sz)
msg(encoder, "endEncoding")
if src_dev != dest_dev:
src_command_buffer.encodeSignalEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
dest_command_buffer = dest_dev.mtl_queue.commandBuffer()
dest_command_buffer.encodeWaitForEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
dest_command_buffer.commit()
msg(src_command_buffer, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
dest_command_buffer = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
msg(dest_command_buffer, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
msg(dest_command_buffer, "commit")
dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
src_dev.timeline_value += 1
src_command_buffer.commit()
msg(src_command_buffer, "commit")
src_dev.mtl_buffers_in_flight.append(src_command_buffer)
def from_buffer(self, src:memoryview) -> Optional[Any]:
ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, src.nbytes, Metal.MTLResourceStorageModeShared, None)
ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
if ret: self.device.mv_in_metal.append(src)
return MetalBuffer(ret, src.nbytes)
def as_buffer(self, src:MetalBuffer) -> memoryview:
self.device.synchronize()
return src.buf.contents().as_buffer(src.offset+src.size)[src.offset:]
ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
return memoryview(array).cast("B")[src.offset:]
def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)
class MetalDevice(Compiled):
def __init__(self, device:str):
self.device = Metal.MTLCreateSystemDefaultDevice()
self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
self.device = libmetal.MTLCreateSystemDefaultDevice()
self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")
self.mtl_buffers_in_flight: List[Any] = []
self.mv_in_metal: List[memoryview] = []
self.timeline_signal = self.device.newSharedEvent()
self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
self.timeline_value = 0
from tinygrad.runtime.graph.metal import MetalGraph
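
For reference, the objc_msgSend pattern this module is built on, as a self-contained sketch (macOS only; Foundation is loaded explicitly here so NSString resolves on its own, whereas the module above presumably gets it transitively once Metal and CoreGraphics are loaded):

import ctypes, functools

libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
ctypes.CDLL("/System/Library/Frameworks/Foundation.framework/Foundation")  # NSString lives here

class objc_id(ctypes.c_void_p): pass

libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id

@functools.lru_cache(None)
def sel(name: str): return libobjc.sel_registerName(name.encode())

def msg(ptr, selector, *args, restype=objc_id):
    sender = libobjc["objc_msgSend"]  # item access returns a fresh function object, so restype can be set per call
    sender.restype = restype
    return sender(ptr, sel(selector), *args)

ns_str = msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", b"hello metal")
utf8 = msg(ns_str, "UTF8String", restype=ctypes.c_char_p)
print(utf8.decode())  # hello metal (release of ns_str omitted for brevity)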