openpilot1/tinygrad_repo/tinygrad/renderer/metal.py

17 lines
678 B
Python

import functools
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
class MetalLanguage(CStyleLanguage):
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel "
buffer_prefix = "device "
smem_prefix = "threadgroup "
arg_int_prefix = "constant int&"
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
float4 = "float4"
uses_ptr_arithmetic=True
gid = [f"gid.{chr(120+i)}" for i in range(3)]
lid = [f"lid.{chr(120+i)}" for i in range(3)]
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())