mirror of https://github.com/commaai/tinygrad.git
share duplicate renders with cstyle (#2538)
This commit is contained in:
parent
7fec966b5e
commit
0fb4ff30c8
|
@ -1,7 +1,7 @@
|
|||
from tinygrad.helpers import dtypes, DType
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage
|
||||
from typing import List, Union
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
|
@ -13,15 +13,7 @@ class WGSLLanguage(CStyleLanguage):
|
|||
barrier="workgroupBarrier();"
|
||||
generic_var_prefix = "var "
|
||||
external_local_bufs = True
|
||||
code_for_op = {
|
||||
UnaryOps.NEG: lambda x: f"(-{x})",
|
||||
UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})",
|
||||
UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
|
||||
BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})",
|
||||
BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})",
|
||||
BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})",
|
||||
TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)"
|
||||
}
|
||||
code_for_op = { **CStyleLanguage().code_for_op, BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }
|
||||
|
||||
def render_local(self, name: str, size: int):
|
||||
return f"var<workgroup> {name}: array<f32,{size}>;"
|
||||
|
|
Loading…
Reference in New Issue