mirror of https://github.com/commaai/tinygrad.git
support symbolic shape in Interpreted (#2289)
* support symbolic shape in Interpreted * simpler * no InterpretedFlopCounter * tragic NumNode * regex is hard
This commit is contained in:
parent
6960bcded0
commit
d86ea188dd
|
@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, Device
|
|||
import numpy as np
|
||||
|
||||
@unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
|
||||
@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["HIP", "WEBGPU"], f"{Device.DEFAULT} is not supported")
|
||||
@unittest.skipIf(Device.DEFAULT in ["HIP", "WEBGPU"], f"{Device.DEFAULT} is not supported")
|
||||
class TestSymbolicOps(unittest.TestCase):
|
||||
def test_plus1(self):
|
||||
def f(a): return (a+1).realize()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from __future__ import annotations
|
||||
import importlib, inspect, functools, pathlib
|
||||
import importlib, inspect, functools, pathlib, re
|
||||
from enum import Enum, auto
|
||||
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
|
||||
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
|
||||
|
@ -116,12 +116,16 @@ class Interpreted:
|
|||
self.method_cache: Dict[LazyOp, Callable] = {}
|
||||
|
||||
def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
|
||||
tglob: Dict[str, Any] = {}
|
||||
tglob: Dict[str, Any] = {"Variable": Variable}
|
||||
lines: List[str] = []
|
||||
f = self.fxn_for_op
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def gstr(x:Any, nm=None) -> str:
|
||||
if self != InterpretedFlopCounter and ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
|
||||
str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
|
||||
# TODO: (Variable - Variable) might create NumNode. can we remove it?
|
||||
return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
|
||||
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
|
||||
tglob[ret] = x
|
||||
return ret
|
||||
|
@ -143,14 +147,14 @@ class Interpreted:
|
|||
return ret
|
||||
|
||||
ret = _interpret_ast(ast)
|
||||
src = '\n'.join(['def run(inputs):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
|
||||
src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f" return {ret}"])
|
||||
if DEBUG >= 4 and self != InterpretedFlopCounter: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
|
||||
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
|
||||
return tglob['run']
|
||||
|
||||
def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
|
||||
if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
|
||||
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
|
||||
ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
|
||||
if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
|
||||
ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
|
||||
# TODO: is this used?
|
||||
|
|
|
@ -355,7 +355,7 @@ VariableOrNum = Union[Variable, NumNode]
|
|||
|
||||
render_python: Dict[Type, Callable] = {
|
||||
Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"),
|
||||
NumNode: lambda self,ops,ctx: f"{self.b}",
|
||||
NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
|
||||
MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
|
||||
DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
|
||||
ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
|
||||
|
|
Loading…
Reference in New Issue