line reduction in metal

This commit is contained in:
George Hotz 2023-03-03 23:14:40 -08:00
parent 893f136fe0
commit 28a6ada4ce
2 changed files with 13 additions and 16 deletions

View File

@ -13,7 +13,7 @@ from functools import partial
from tinygrad.ops import GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, DEBUG, Timing
from tinygrad.helpers import colored, getenv, DEBUG
from tinygrad.jit import TinyJit
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]

View File

@ -29,36 +29,33 @@ class RawMetalBuffer(RawBufferCopyIn):
METAL.mtl_buffers_in_flight = []
return self._as_np() # no copy!
def unwrap(x):
ret, err = x
assert err is None, str(err)
return ret
class MetalProgram:
def __init__(self, name:str, prg:str):
if DEBUG >= 6: # dump llvm
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
print(dis.decode('utf-8'))
if METAL_XCODE:
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
self.library, err = METAL.device.newLibraryWithData_error_(data, None)
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
else:
options = Metal.MTLCompileOptions.alloc().init()
self.library, err = METAL.device.newLibraryWithSource_options_error_(prg, options, None)
assert err is None, str(err)
self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
self.fxn = self.library.newFunctionWithName_(name)
# hacks to disassemble shader
if DEBUG >= 5:
arc, err = METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
assert err is None, str(err)
arc = unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None))
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
desc.setComputeFunction_(self.fxn)
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
assert err is None, str(err)
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
assert err is None, str(err)
unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None))
unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None))
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
self.pipeline_state, err = METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)
assert err is None, str(err)
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
def __call__(self, global_size, local_size, *bufs, wait=False):
global_size += [1] * (3-len(global_size))