mirror of https://github.com/commaai/tinygrad.git
share duplicate renders with cstyle (#2538)
This commit is contained in:
parent
7fec966b5e
commit
0fb4ff30c8
|
@ -1,7 +1,7 @@
|
||||||
from tinygrad.helpers import dtypes, DType
|
from tinygrad.helpers import dtypes, DType
|
||||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
from tinygrad.ops import BinaryOps, TernaryOps
|
||||||
import math
|
import math
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
|
@ -13,15 +13,7 @@ class WGSLLanguage(CStyleLanguage):
|
||||||
barrier="workgroupBarrier();"
|
barrier="workgroupBarrier();"
|
||||||
generic_var_prefix = "var "
|
generic_var_prefix = "var "
|
||||||
external_local_bufs = True
|
external_local_bufs = True
|
||||||
code_for_op = {
|
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }
|
||||||
UnaryOps.NEG: lambda x: f"(-{x})",
|
|
||||||
UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
|
|
||||||
UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
|
|
||||||
BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})",
|
|
||||||
BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})",
|
|
||||||
BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
|
|
||||||
TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
|
|
||||||
}
|
|
||||||
|
|
||||||
def render_local(self, name: str, size: int):
|
def render_local(self, name: str, size: int):
|
||||||
return f"var<workgroup> {name}: array<f32,{size}>;"
|
return f"var<workgroup> {name}: array<f32,{size}>;"
|
||||||
|
@ -59,4 +51,4 @@ class WGSLLanguage(CStyleLanguage):
|
||||||
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
|
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
|
||||||
if buf_dtype != var_dtype:
|
if buf_dtype != var_dtype:
|
||||||
var_name = f"{type_map[buf_dtype]}({var_name})"
|
var_name = f"{type_map[buf_dtype]}({var_name})"
|
||||||
return f"{buf_name}[{idx}] = {var_name};"
|
return f"{buf_name}[{idx}] = {var_name};"
|
||||||
|
|
Loading…
Reference in New Issue