mirror of https://github.com/commaai/tinygrad.git
line reduction in metal
This commit is contained in:
parent
893f136fe0
commit
28a6ada4ce
|
@ -13,7 +13,7 @@ from functools import partial
|
|||
from tinygrad.ops import GlobalCounters
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.helpers import colored, getenv, DEBUG, Timing
|
||||
from tinygrad.helpers import colored, getenv, DEBUG
|
||||
from tinygrad.jit import TinyJit
|
||||
|
||||
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]
|
||||
|
|
|
@ -29,36 +29,33 @@ class RawMetalBuffer(RawBufferCopyIn):
|
|||
METAL.mtl_buffers_in_flight = []
|
||||
return self._as_np() # no copy!
|
||||
|
||||
def unwrap(x):
|
||||
ret, err = x
|
||||
assert err is None, str(err)
|
||||
return ret
|
||||
|
||||
class MetalProgram:
|
||||
def __init__(self, name:str, prg:str):
|
||||
if DEBUG >= 6: # dump llvm
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
dis = subprocess.check_output(['/Users/kafka/Downloads/clang+llvm-15.0.7-arm64-apple-darwin22.0/bin/llvm-dis'], input=air)
|
||||
print(dis.decode('utf-8'))
|
||||
if METAL_XCODE:
|
||||
air = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metal', '-x', 'metal', '-c', '-', '-o', '-'], input=prg.encode('utf-8'))
|
||||
# NOTE: if you run llvm-dis on "air" you can see the llvm bytecode
|
||||
lib = subprocess.check_output(['xcrun', '-sdk', 'macosx', 'metallib', '-', '-o', '-'], input=air)
|
||||
data = libdispatch.dispatch_data_create(lib, len(lib), None, None)
|
||||
self.library, err = METAL.device.newLibraryWithData_error_(data, None)
|
||||
self.library = unwrap(METAL.device.newLibraryWithData_error_(data, None))
|
||||
else:
|
||||
options = Metal.MTLCompileOptions.alloc().init()
|
||||
self.library, err = METAL.device.newLibraryWithSource_options_error_(prg, options, None)
|
||||
assert err is None, str(err)
|
||||
self.library = unwrap(METAL.device.newLibraryWithSource_options_error_(prg, options, None))
|
||||
self.fxn = self.library.newFunctionWithName_(name)
|
||||
# hacks to disassemble shader
|
||||
if DEBUG >= 5:
|
||||
arc, err = METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None)
|
||||
assert err is None, str(err)
|
||||
arc = unwrap(METAL.device.newBinaryArchiveWithDescriptor_error_(Metal.MTLBinaryArchiveDescriptor.alloc().init(), None))
|
||||
desc = Metal.MTLComputePipelineDescriptor.alloc().init()
|
||||
desc.setComputeFunction_(self.fxn)
|
||||
_, err = arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None)
|
||||
assert err is None, str(err)
|
||||
_, err = arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None)
|
||||
assert err is None, str(err)
|
||||
unwrap(arc.addComputePipelineFunctionsWithDescriptor_error_(desc, None))
|
||||
unwrap(arc.serializeToURL_error_(Cocoa.NSURL.URLWithString_("file:///tmp/shader.bin"), None))
|
||||
# clone https://github.com/dougallj/applegpu.git in the root of tinygrad
|
||||
os.system(f"cd {pathlib.Path(__file__).parent.parent.parent}/applegpu && python3 compiler_explorer.py /tmp/shader.bin")
|
||||
self.pipeline_state, err = METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None)
|
||||
assert err is None, str(err)
|
||||
self.pipeline_state = unwrap(METAL.device.newComputePipelineStateWithFunction_error_(self.fxn, None))
|
||||
|
||||
def __call__(self, global_size, local_size, *bufs, wait=False):
|
||||
global_size += [1] * (3-len(global_size))
|
||||
|
|
Loading…
Reference in New Issue