17 lines
678 B
Python
17 lines
678 B
Python
import functools
|
|
from tinygrad.renderer.cstyle import uops_to_cstyle, CStyleLanguage
|
|
|
|
class MetalLanguage(CStyleLanguage):
|
|
kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel "
|
|
buffer_prefix = "device "
|
|
smem_prefix = "threadgroup "
|
|
arg_int_prefix = "constant int&"
|
|
barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
|
|
float4 = "float4"
|
|
uses_ptr_arithmetic=True
|
|
gid = [f"gid.{chr(120+i)}" for i in range(3)]
|
|
lid = [f"lid.{chr(120+i)}" for i in range(3)]
|
|
extra_args = ['uint3 gid [[threadgroup_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]']
|
|
|
|
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
|