mirror of https://github.com/commaai/tinygrad.git
Metal with CDLL instead of py-objc (#6545)
* Add CDLL interface for metal
* remove two unused functions
* Cover most of the API methods
* switch to cdll
* directly call objc message in ops_metal
* keep only obj interface
* Use direct message sending for graph
* may have found a solution to the memoryview on ctypes pointer
* buf indexing bug fixed
* fix c_int
* fix c int to bytes
* fix gpu time bug
* line savings for cdll metal core
* wip
* c int bug
* fix buf casting
* dedup for c_void_p
* dedup for c_void_p
* linter fix
* remove unused stuff
* my py fix
* more mypy error fix
* line savings
* line savings
* rename send_message to msg; add __hash__ and __eq__ for dedup
* wip
* refactor
* refactor
* remove named import from ctypes
* forgot to change variable name
* file reorg, put support.py to ops_metal
* refactor
* hash error
* remove to_ns_array
* test oom exception, fix exception change
* typevar for msg
* add back dedup
* test for compile error
* move constant to graph
* move header constant around
* get label for icb buffer
* check icb label using "in"
* wip fixing mypy reported error
* fixed mypy error
* code formatting
* all_resources dedup match previous
* code formatting
* code formatting; buffer set to objc_id
* revert changes on buf for the manual release, seems like _free is not always called
* skip unless on metal, for test_metal
* fix premature mem release causing seg fault
* test_metal check for device before importing
* Buffer should only be released under _free explicitly
* mypy fixes
* change object ownership
* test compile success
* lint fixes
* remove load_library
* wrap sel_register in cache
* simplify to_struct
* swap lines
* fix type error in to_struct
* bump line to 9800
* remove pyobjc from setup.py
* command buffer should be objc_instance and get released
* stringWithUTF8String: returns objc_instance
* Use constant for MTLPipelineOptionNone
* better explanation for [MTLBuffer contents:] return
* Use dyld_find in case the path differs
* trailing whitespace
* handle exception for methods that take error:
* load /System/Library instead of /Library
* Init c_void_p with None instead of zero for error objects

---------

Co-authored-by: Mesozoic Egg <mesozoic.egg@proton.me>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
parent cd534dee11
commit 992cde05d7
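The change replaces the PyObjC bridge with plain ctypes: every Objective-C method call is routed through objc_msgSend with a selector registered via sel_registerName. Below is a minimal standalone sketch of that calling pattern (macOS only). It mirrors the objc_id/msg helpers added in ops_metal.py in the diff; the device-name query at the end is only an illustration and not something the PR itself does.

import ctypes

class objc_id(ctypes.c_void_p): pass  # subclass so ctypes hands back (and accepts) full 64-bit pointers

libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")  # required for MTLCreateSystemDefaultDevice
libobjc.objc_getClass.restype = objc_id
libobjc.sel_registerName.restype = objc_id
libmetal.MTLCreateSystemDefaultDevice.restype = objc_id

def msg(ptr, selector, *args, restype=objc_id):
  # objc_msgSend(receiver, selector, args...) is the single entry point for every Objective-C method call
  sender = libobjc["objc_msgSend"]  # item access returns a fresh function object, so restype can be set per call
  sender.restype = restype
  return sender(ptr, libobjc.sel_registerName(selector.encode()), *args)

dev = libmetal.MTLCreateSystemDefaultDevice()   # id<MTLDevice>
name = msg(dev, "name")                         # NSString*
print(msg(name, "UTF8String", restype=ctypes.c_char_p).decode())

Subclassing c_void_p matters here: with a plain integer return type, ctypes would truncate pointers to C int when they are passed back as arguments, which is what the objc_id class in the diff guards against.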
setup.py
@@ -21,9 +21,7 @@ setup(name='tinygrad',
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
       ],
-      install_requires=["numpy",
-                        "pyobjc-framework-Metal; platform_system=='Darwin'",
-                        "pyobjc-framework-libdispatch; platform_system=='Darwin'"],
+      install_requires=["numpy"],
       python_requires='>=3.8',
       extras_require={
         'llvm': ["llvmlite"],
test_metal.py (new file)
@@ -0,0 +1,54 @@
+import unittest
+from tinygrad.device import CompileError, Device
+if Device.DEFAULT=="METAL":
+  from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram
+
+@unittest.skipIf(Device.DEFAULT!="METAL", "Metal support required")
+class TestMetal(unittest.TestCase):
+  def test_alloc_oom(self):
+    device = MetalDevice("metal")
+    with self.assertRaises(MemoryError):
+      device.allocator.alloc(10000000000000000000)
+
+  def test_compile_error(self):
+    device = MetalDevice("metal")
+    compiler = MetalCompiler(device)
+    with self.assertRaises(CompileError):
+      compiler.compile("this is not valid metal")
+
+  def test_compile_success(self):
+    device = MetalDevice("metal")
+    compiler = MetalCompiler(device)
+    ret = compiler.compile("""
+      #include <metal_stdlib>
+      using namespace metal;
+      kernel void E_4n1(device int* data0, const device int* data1, const device int* data2,
+          uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]) {
+        int val0 = *(data1+0);
+        int val1 = *(data1+1);
+        int val2 = *(data1+2);
+        int val3 = *(data1+3);
+        int val4 = *(data2+0);
+        int val5 = *(data2+1);
+        int val6 = *(data2+2);
+        int val7 = *(data2+3);
+        *(data0+0) = (val0+val4);
+        *(data0+1) = (val1+val5);
+        *(data0+2) = (val2+val6);
+        *(data0+3) = (val3+val7);
+      }
+    """)
+    assert ret is not None
+
+  def test_failed_newLibraryWithData(self):
+    device = MetalDevice("metal")
+    compiler = MetalCompiler(device)
+    compiled = compiler.compile("""
+      #include <metal_stdlib>
+      kernel void r_5(device int* data0, const device int* data1, uint3 gid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]]){
+        data0[0] = 0;
+      }
+    """)
+    with self.assertRaises(RuntimeError):
+      compiled = compiled[:40] # corrupt the compiled program
+      MetalProgram(device, "r_5", compiled)
tinygrad/runtime/graph/metal.py
@@ -1,12 +1,20 @@
 from typing import List, Any, Dict, cast, Optional
-import Metal
+import ctypes
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, getenv
 from tinygrad.device import Buffer
 from tinygrad.engine.realize import ExecItem, CompiledRunner
 from tinygrad.engine.jit import GraphRunner, GraphException
 from tinygrad.shape.symbolic import Variable
-from tinygrad.runtime.ops_metal import wait_check
+from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
+  MTLResourceOptions, elapsed_time, objc_id
+
+class MTLIndirectCommandType:
+  MTLIndirectCommandTypeConcurrentDispatch = (1 << 5)
+
+class MTLResourceUsage:
+  MTLResourceUsageRead = 0b01
+  MTLResourceUsageWrite = 0b10

 class MetalGraph(GraphRunner):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
@@ -14,58 +22,64 @@
     if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

     # create metal batch exec
-    icb_descriptor = Metal.MTLIndirectCommandBufferDescriptor.new()
-    icb_descriptor.setCommandTypes_(Metal.MTLIndirectCommandType(Metal.MTLIndirectCommandTypeConcurrentDispatch))
-    icb_descriptor.setInheritBuffers_(False)
-    icb_descriptor.setInheritPipelineState_(False)
-    icb_descriptor.setMaxKernelBufferBindCount_(31)
-    self.icb = self.device.device.newIndirectCommandBufferWithDescriptor_maxCommandCount_options_(icb_descriptor, len(self.jit_cache),
-                                                                                                  Metal.MTLResourceOptions(0))
-    if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
-    self.needs_icb_fix = int(type(self.icb).__name__ != "AGXG15XFamilyIndirectCommandBuffer") # not required on M3
+    icb_descriptor = msg(libobjc.objc_getClass(b"MTLIndirectCommandBufferDescriptor"), "new", restype=objc_instance)
+    msg(icb_descriptor, "setCommandTypes:", MTLIndirectCommandType.MTLIndirectCommandTypeConcurrentDispatch)
+    msg(icb_descriptor, "setInheritBuffers:", False)
+    msg(icb_descriptor, "setInheritPipelineState:", False)
+    msg(icb_descriptor, "setMaxKernelBufferBindCount:", 31)
+
+    self.icb = msg(self.device.device, "newIndirectCommandBufferWithDescriptor:maxCommandCount:options:",
+      icb_descriptor, len(self.jit_cache), MTLResourceOptions.MTLResourceCPUCacheModeDefaultCache, restype=objc_instance)
+    if self.icb.value is None: raise GraphException("create indirect command buffer failed, does your system support this?")
+    icb_label = bytes(msg(msg(self.icb, "description", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode()
+    self.needs_icb_fix = int("AGXG15XFamilyIndirectCommandBuffer" not in icb_label) # not required on M3

     if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
     all_resources = [self.int_buf.buf] if len(self.vars) else []
     all_pipelines = []
     for j,ji in enumerate(self.jit_cache):
       prg: CompiledRunner = cast(CompiledRunner, ji.prg)
-      icb_command = self.icb.indirectComputeCommandAtIndex_(j)
+      icb_command = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_instance)
       all_pipelines.append(prg.clprg.pipeline_state)
-      icb_command.setComputePipelineState_(prg.clprg.pipeline_state)
+      msg(icb_command, "setComputePipelineState:", prg.clprg.pipeline_state)
       for i,b in enumerate(ji.bufs):
         if b is not None and b not in input_rawbuffers:
-          icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
+          msg(icb_command, "setKernelBuffer:offset:atIndex:", b._buf.buf, b._buf.offset, i)
           all_resources.append(b._buf.buf)
-      for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
+      for i,v in enumerate(prg.p.vars): msg(icb_command, "setKernelBuffer:offset:atIndex:", self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)

       global_size, local_size = prg.p.launch_dims(var_vals)
-      icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
-      icb_command.setBarrier()
+      msg(icb_command, "concurrentDispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
+      msg(icb_command, "setBarrier")

     self.all_resources = dedup(all_resources)
     self.all_pipelines = dedup(all_pipelines)
     self.command_buffer: Any = None
-    if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')
-    self.range = Metal.MTLIndirectCommandBufferExecutionRangeMake(0, len(self.jit_cache))
+    if len(self.vars): self.int_buf_view = self.device.allocator.as_buffer(self.int_buf).cast('i')
+    self.range = to_struct(0, len(self.jit_cache))

   def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:

     if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
     all_resources = dedup(self.all_resources + [x._buf.buf for x in input_rawbuffers])

     for (j,i),input_idx in self.input_replace.items():
-      self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf.buf,
+      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j, restype=objc_id)
+      msg(computeCommand, "setKernelBuffer:offset:atIndex:", input_rawbuffers[input_idx]._buf.buf,
        input_rawbuffers[input_idx]._buf.offset, i)

     for j, global_dims, local_dims in self.updated_launch_dims(var_vals):
       prg = cast(CompiledRunner, self.jit_cache[j].prg)
       global_size, local_size = global_dims or prg.p.global_size, local_dims or prg.p.local_size
-      self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size),
-                                                                                                       Metal.MTLSize(*local_size))
+      computeCommand = msg(self.icb, "indirectComputeCommandAtIndex:", j)
+      msg(computeCommand, "concurrentDispatchThreadgroups:threadsPerThreadgroup:",
        to_struct(*cast(tuple, global_size)), to_struct(*cast(tuple, local_size)))
     for j, var in enumerate(self.vars): self.int_buf_view[j] = var_vals[var]

-    command_buffer = self.device.mtl_queue.commandBuffer()
-    encoder = command_buffer.computeCommandEncoder()
-    encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
+    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
+    msg(encoder, "useResources:count:usage:", (objc_id * len(all_resources))(*all_resources), len(all_resources),
      MTLResourceUsage.MTLResourceUsageRead | MTLResourceUsage.MTLResourceUsageWrite)

     # NOTE: the pipelines likely need to be added to the used resources to fix the crash on M1/M2, but I haven't figured out how
     # this is a O(n) hack to get them used. what should work is:
@@ -74,16 +88,16 @@ class MetalGraph(GraphRunner):
     # to repro the crash (which can also crash other running GPU apps), run with FIX_METAL_ICB=0
     if getenv("FIX_METAL_ICB", self.needs_icb_fix):
       for ps in self.all_pipelines:
-        encoder.setComputePipelineState_(ps)
-        encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(0,0,0), Metal.MTLSize(0,0,0))
+        msg(encoder, "setComputePipelineState:", ps)
+        msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(0,0,0), to_struct(0,0,0))

-    encoder.executeCommandsInBuffer_withRange_(self.icb, self.range)
-    encoder.endEncoding()
-    command_buffer.commit()
+    msg(encoder, "executeCommandsInBuffer:withRange:", self.icb, self.range)
+    msg(encoder, "endEncoding")
+    msg(command_buffer, "commit")
     self.command_buffer = command_buffer

     if wait:
       wait_check(command_buffer)
-      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+      return elapsed_time(command_buffer)
     self.device.mtl_buffers_in_flight.append(command_buffer)
     return None
tinygrad/runtime/ops_metal.py
@@ -1,15 +1,61 @@
 from __future__ import annotations
 import os, subprocess, pathlib, ctypes, tempfile, functools
-import Metal, libdispatch
-from typing import List, Any, Tuple, Optional
-from tinygrad.helpers import prod, getenv, DEBUG, unwrap2
+from typing import List, Any, Tuple, Optional, cast, TypeVar
+from tinygrad.helpers import prod, getenv, DEBUG
 from tinygrad.device import Compiled, Compiler, CompileError, LRUAllocator
 from tinygrad.renderer.cstyle import MetalRenderer
+
+class objc_id(ctypes.c_void_p): # This prevents ctypes from converting response to plain int, and dict.fromkeys() can use it to dedup
+  def __hash__(self): return hash(self.value)
+  def __eq__(self, other): return self.value == other.value
+
+class objc_instance(objc_id): # method with name "new", "alloc" should be freed after use
+  def __del__(self): msg(self, "release")
+
+@functools.lru_cache(None)
+def sel(name: str): return libobjc.sel_registerName(name.encode())
+
+class MTLResourceOptions:
+  MTLResourceCPUCacheModeDefaultCache = 0
+  MTLResourceStorageModeShared = 0 << 4
+
+class MTLPipelineOption:
+  MTLPipelineOptionNone = 0
+
+libobjc = ctypes.CDLL("/usr/lib/libobjc.dylib")
+libmetal = ctypes.CDLL("/System/Library/Frameworks/Metal.framework/Metal")
+# Must be loaded for default Metal Device: https://developer.apple.com/documentation/metal/1433401-mtlcreatesystemdefaultdevice?language=objc
+ctypes.CDLL("/System/Library/Frameworks/CoreGraphics.framework/CoreGraphics")
+libdispatch = ctypes.CDLL("/usr/lib/libSystem.dylib") # libdispatch is part of libSystem on mac
+libobjc.objc_getClass.restype = objc_id
+libobjc.sel_registerName.restype = objc_id
+libmetal.MTLCreateSystemDefaultDevice.restype = objc_instance
+libdispatch.dispatch_data_create.restype = objc_instance
+
+T = TypeVar("T")
+# Ignore mypy error reporting incompatible default, because typevar default only works on python 3.12
+def msg(ptr: objc_id, selector: str, /, *args: Any, restype: type[T] = objc_id) -> T:  # type: ignore [assignment]
+  sender = libobjc["objc_msgSend"] # Using attribute access returns a new reference so setting restype is safe
+  sender.restype = restype
+  return sender(ptr, sel(selector), *args)
+
+def to_ns_str(s: str): return msg(libobjc.objc_getClass(b"NSString"), "stringWithUTF8String:", s.encode(), restype=objc_instance)
+
+def to_struct(*t: int, _type: type = ctypes.c_ulong):
+  class Struct(ctypes.Structure): pass
+  Struct._fields_ = [(f"field{i}", _type) for i in range(len(t))]
+  return Struct(*t)
+
 def wait_check(cbuf: Any):
-  cbuf.waitUntilCompleted()
-  if (error := cbuf.error()) is not None:
-    raise RuntimeError(error)
+  msg(cbuf, "waitUntilCompleted")
+  if (error := cast(int, msg(cbuf, "error", restype=ctypes.c_ulong))) != 0: raise RuntimeError(error)
+
+def elapsed_time(cbuf: objc_id):
+  return cast(float, msg(cbuf, "GPUEndTime", restype=ctypes.c_double)) - cast(float, msg(cbuf, "GPUStartTime", restype=ctypes.c_double))
+
+def error_check(error: objc_instance, error_constructor: type[Exception] = RuntimeError):
+  if error.value is None: return None
+  raise error_constructor(bytes(msg(msg(error, "localizedDescription", restype=objc_instance), "UTF8String", restype=ctypes.c_char_p)).decode())

 class MetalCompiler(Compiler):
   def __init__(self, device:Optional[MetalDevice]):
@@ -20,11 +66,14 @@ class MetalCompiler(Compiler):
       # NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
       air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=src.encode('utf-8'))
       return subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
-    options = Metal.MTLCompileOptions.new()
-    options.setFastMathEnabled_(getenv("METAL_FAST_MATH"))
-    try: library = unwrap2(self.device.device.newLibraryWithSource_options_error_(src, options, None))
-    except AssertionError as e: raise CompileError(e) from e
-    return library.libraryDataContents().bytes().tobytes()
+    options = msg(libobjc.objc_getClass(b"MTLCompileOptions"), "new", restype=objc_instance)
+    msg(options, "setFastMathEnabled:", getenv("METAL_FAST_MATH"))
+    compileError = objc_instance()
+    library = msg(self.device.device, "newLibraryWithSource:options:error:", to_ns_str(src),
      options, ctypes.byref(compileError), restype=objc_instance)
+    error_check(compileError, CompileError)
+    library_contents = msg(library, "libraryDataContents", restype=objc_instance)
+    return ctypes.string_at(msg(library_contents, "bytes"), cast(int, msg(library_contents, "length", restype=ctypes.c_ulong)))

 class MetalProgram:
   def __init__(self, device:MetalDevice, name:str, lib:bytes):
@@ -38,27 +87,35 @@ class MetalProgram:
         print("Error running disassembler: Make sure you have https://github.com/dougallj/applegpu cloned to tinygrad/extra/disassemblers/applegpu")
     assert lib[:4] == b"MTLB", "Invalid Metal library. Could be due to using conda. Try system python or METAL_XCODE=1 DISABLE_COMPILER_CACHE=1."
     data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
-    self.library = unwrap2(self.device.device.newLibraryWithData_error_(data, None))
-    self.fxn = self.library.newFunctionWithName_(name)
-    descriptor = Metal.MTLComputePipelineDescriptor.new()
-    descriptor.setComputeFunction_(self.fxn)
-    descriptor.setSupportIndirectCommandBuffers_(True)
-    self.pipeline_state = unwrap2(self.device.device.newComputePipelineStateWithDescriptor_options_reflection_error_(
-      descriptor, Metal.MTLPipelineOption(0), None, None))
+    error_library_creation = objc_instance()
+    self.library = msg(self.device.device, "newLibraryWithData:error:", data, ctypes.byref(error_library_creation), restype=objc_instance)
+    error_check(error_library_creation)
+    self.fxn = msg(self.library, "newFunctionWithName:", to_ns_str(name), restype=objc_instance)
+    descriptor = msg(libobjc.objc_getClass(b"MTLComputePipelineDescriptor"), "new", restype=objc_instance)
+    msg(descriptor, "setComputeFunction:", self.fxn)
+    msg(descriptor, "setSupportIndirectCommandBuffers:", True)
+    error_pipeline_creation = objc_instance()
+    self.pipeline_state = msg(self.device.device, "newComputePipelineStateWithDescriptor:options:reflection:error:",
      descriptor, MTLPipelineOption.MTLPipelineOptionNone, None, ctypes.byref(error_pipeline_creation), restype=objc_instance)
+    error_check(error_pipeline_creation)

   def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tuple[int,int,int]=(1,1,1), vals:Tuple[int, ...]=(), wait=False):
-    if prod(local_size) > self.pipeline_state.maxTotalThreadsPerThreadgroup(): raise RuntimeError(f"local size {local_size} bigger than {self.pipeline_state.maxTotalThreadsPerThreadgroup()} with exec width {self.pipeline_state.threadExecutionWidth()} memory length {self.pipeline_state.staticThreadgroupMemoryLength()}") # noqa: E501
-    command_buffer = self.device.mtl_queue.commandBuffer()
-    encoder = command_buffer.computeCommandEncoder()
-    encoder.setComputePipelineState_(self.pipeline_state)
-    for i,a in enumerate(bufs): encoder.setBuffer_offset_atIndex_(a.buf, a.offset, i)
-    for i,a in enumerate(vals,start=len(bufs)): encoder.setBytes_length_atIndex_(ctypes.c_int32(a), 4, i)
-    encoder.dispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
-    encoder.endEncoding()
-    command_buffer.commit()
+    max_total_threads = msg(self.pipeline_state, "maxTotalThreadsPerThreadgroup", restype=ctypes.c_ulong)
+    if prod(local_size) > cast(int, max_total_threads):
+      exec_width = msg(self.pipeline_state, "threadExecutionWidth", restype=ctypes.c_ulong)
+      memory_length = msg(self.pipeline_state, "staticThreadgroupMemoryLength", restype=ctypes.c_ulong)
+      raise RuntimeError(f"local size {local_size} bigger than {max_total_threads} with exec width {exec_width} memory length {memory_length}")
+    command_buffer = msg(self.device.mtl_queue, "commandBuffer", restype=objc_instance)
+    encoder = msg(command_buffer, "computeCommandEncoder", restype=objc_instance)
+    msg(encoder, "setComputePipelineState:", self.pipeline_state)
+    for i,a in enumerate(bufs): msg(encoder, "setBuffer:offset:atIndex:", a.buf, a.offset, i)
+    for i,a in enumerate(vals,start=len(bufs)): msg(encoder, "setBytes:length:atIndex:", bytes(ctypes.c_int(a)), 4, i)
+    msg(encoder, "dispatchThreadgroups:threadsPerThreadgroup:", to_struct(*global_size), to_struct(*local_size))
+    msg(encoder, "endEncoding")
+    msg(command_buffer, "commit")
     if wait:
       wait_check(command_buffer)
-      return command_buffer.GPUEndTime() - command_buffer.GPUStartTime()
+      return elapsed_time(command_buffer)
     self.device.mtl_buffers_in_flight.append(command_buffer)

 class MetalBuffer:
@@ -69,46 +126,48 @@ class MetalAllocator(LRUAllocator):
     self.device:MetalDevice = device
     super().__init__()
   def _alloc(self, size:int, options) -> MetalBuffer:
-    ret = self.device.device.newBufferWithLength_options_(size, Metal.MTLResourceStorageModeShared)
-    if ret is None: raise MemoryError(f"Metal OOM while allocating {size=}")
+    # Buffer is explicitly released in _free() rather than garbage collected via reference count
+    ret = msg(self.device.device, "newBufferWithLength:options:", size, MTLResourceOptions.MTLResourceStorageModeShared, restype=objc_id)
+    if ret.value is None: raise MemoryError(f"Metal OOM while allocating {size=}")
     return MetalBuffer(ret, size)
-  def _free(self, opaque:MetalBuffer, options): opaque.buf.release()
+  def _free(self, opaque:MetalBuffer, options): msg(opaque.buf, "release")
   def transfer(self, dest:MetalBuffer, src:MetalBuffer, sz:int, src_dev:MetalDevice, dest_dev:MetalDevice):
     dest_dev.synchronize()
-    src_command_buffer = src_dev.mtl_queue.commandBuffer()
-    encoder = src_command_buffer.blitCommandEncoder()
-    encoder.copyFromBuffer_sourceOffset_toBuffer_destinationOffset_size_(src.buf, src.offset, dest.buf, dest.offset, sz)
-    encoder.endEncoding()
+    src_command_buffer = msg(src_dev.mtl_queue, "commandBuffer", restype=objc_instance)
+    encoder = msg(src_command_buffer, "blitCommandEncoder", restype=objc_instance)
+    msg(encoder, "copyFromBuffer:sourceOffset:toBuffer:destinationOffset:size:", src.buf, src.offset, dest.buf, dest.offset, sz)
+    msg(encoder, "endEncoding")
     if src_dev != dest_dev:
-      src_command_buffer.encodeSignalEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
-      dest_command_buffer = dest_dev.mtl_queue.commandBuffer()
-      dest_command_buffer.encodeWaitForEvent_value_(src_dev.timeline_signal, src_dev.timeline_value)
-      dest_command_buffer.commit()
+      msg(src_command_buffer, "encodeSignalEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
+      dest_command_buffer = msg(dest_dev.mtl_queue, "commandBuffer", restype=objc_instance)
+      msg(dest_command_buffer, "encodeWaitForEvent:value:", src_dev.timeline_signal, src_dev.timeline_value)
+      msg(dest_command_buffer, "commit")
       dest_dev.mtl_buffers_in_flight.append(dest_command_buffer)
       src_dev.timeline_value += 1
-    src_command_buffer.commit()
+    msg(src_command_buffer, "commit")
     src_dev.mtl_buffers_in_flight.append(src_command_buffer)
   def from_buffer(self, src:memoryview) -> Optional[Any]:
-    ret = self.device.device.newBufferWithBytesNoCopy_length_options_deallocator_(src, src.nbytes, Metal.MTLResourceStorageModeShared, None)
+    ptr = (ctypes.c_char * src.nbytes).from_buffer(src)
+    ret = msg(self.device.device, "newBufferWithBytesNoCopy:length:options:deallocator:", ptr, src.nbytes, 0, None, restype=objc_instance)
     if ret: self.device.mv_in_metal.append(src)
     return MetalBuffer(ret, src.nbytes)
   def as_buffer(self, src:MetalBuffer) -> memoryview:
     self.device.synchronize()
-    return src.buf.contents().as_buffer(src.offset+src.size)[src.offset:]
+    ptr = msg(src.buf, "contents", restype=objc_id) # Shared memory, do not release here
+    array = (ctypes.c_char * (src.offset + src.size)).from_address(ptr.value)
+    return memoryview(array).cast("B")[src.offset:]
   def copyin(self, dest:MetalBuffer, src:memoryview): self.as_buffer(dest)[:] = src
   def copyout(self, dest:memoryview, src:MetalBuffer): dest[:] = self.as_buffer(src)
   def offset(self, buf:MetalBuffer, size:int, offset:int): return MetalBuffer(buf.buf, size, offset)

 class MetalDevice(Compiled):
   def __init__(self, device:str):
-    self.device = Metal.MTLCreateSystemDefaultDevice()
-    self.mtl_queue = self.device.newCommandQueueWithMaxCommandBufferCount_(1024)
+    self.device = libmetal.MTLCreateSystemDefaultDevice()
+    self.mtl_queue = msg(self.device, "newCommandQueueWithMaxCommandBufferCount:", 1024, restype=objc_instance)
     if self.mtl_queue is None: raise RuntimeError("Cannot allocate a new command queue")

     self.mtl_buffers_in_flight: List[Any] = []
     self.mv_in_metal: List[memoryview] = []

-    self.timeline_signal = self.device.newSharedEvent()
+    self.timeline_signal = msg(self.device, "newSharedEvent", restype=objc_instance)
     self.timeline_value = 0

     from tinygrad.runtime.graph.metal import MetalGraph
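For a sense of how the new backend is driven end to end, here is a hypothetical usage sketch in the spirit of test_metal.py above. The fill kernel, buffer size, and launch dimensions are invented for illustration and are not part of the PR; it assumes macOS with a Metal-capable device.

from tinygrad.runtime.ops_metal import MetalDevice, MetalCompiler, MetalProgram

dev = MetalDevice("metal")
lib = MetalCompiler(dev).compile("""
#include <metal_stdlib>
using namespace metal;
kernel void fill(device int* data0, uint3 gid [[threadgroup_position_in_grid]]) { data0[gid.x] = 7; }
""")
prg = MetalProgram(dev, "fill", lib)                    # builds the pipeline state via msg() calls
buf = dev.allocator.alloc(4*4)                          # MetalBuffer holding four int32 values
et = prg(buf, global_size=(4,1,1), local_size=(1,1,1), wait=True)  # wait=True returns GPU elapsed time
print(dev.allocator.as_buffer(buf).cast("i").tolist())  # expected: [7, 7, 7, 7]

With wait=True the call goes through wait_check and returns elapsed_time(command_buffer), exactly the code path shown in the ops_metal.py hunk above.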