support symbolic shape in Interpreted (#2289)

* support symbolic shape in Interpreted

* simpler

* no InterpretedFlopCounter

* tragic NumNode

* regex is hard
chenyu 2023-11-13 20:13:18 -05:00 committed by GitHub
parent 6960bcded0
commit d86ea188dd
3 changed files with 10 additions and 6 deletions

test/test_symbolic_ops.py

@@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, Device
 import numpy as np
 @unittest.skipIf(getenv("ARM64") or getenv("PTX"), "ARM64 and PTX are not supported")
-@unittest.skipUnless(Device.DEFAULT in JIT_SUPPORTED_DEVICE and Device.DEFAULT not in ["HIP", "WEBGPU"], f"{Device.DEFAULT} is not supported")
+@unittest.skipIf(Device.DEFAULT in ["HIP", "WEBGPU"], f"{Device.DEFAULT} is not supported")
 class TestSymbolicOps(unittest.TestCase):
   def test_plus1(self):
     def f(a): return (a+1).realize()
tinygrad/ops.py

@@ -1,5 +1,5 @@
 from __future__ import annotations
-import importlib, inspect, functools, pathlib
+import importlib, inspect, functools, pathlib, re
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT
@@ -116,12 +116,16 @@ class Interpreted:
     self.method_cache: Dict[LazyOp, Callable] = {}
   def interpret_ast(self:Interpreted, ast:LazyOp) -> Callable:
-    tglob: Dict[str, Any] = {}
+    tglob: Dict[str, Any] = {"Variable": Variable}
     lines: List[str] = []
     f = self.fxn_for_op
     @functools.lru_cache(None)
     def gstr(x:Any, nm=None) -> str:
+      if self != InterpretedFlopCounter and ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg):
+        str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg)
+        # TODO: (Variable - Variable) might create NumNode. can we remove it?
+        return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
       ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
       tglob[ret] = x
       return ret
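
Note: the new gstr branch is string surgery on repr(x): embedded Variable(...) literals are rewritten into var_vals lookups that resolve when the generated run is called, and NumNode wrappers are stripped down to their integer payloads. A standalone sketch of what the two re.sub passes do, fed a hypothetical repr:

  import re
  str_arg = "(3, Variable('i', 1, 10), NumNode(4))"  # hypothetical repr of an op arg
  str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{m.group(0)}]', str_arg)
  str_arg = re.sub(r'NumNode\((.*?)\)', r'\1', str_arg)
  print(str_arg)  # (3, var_vals[Variable('i', 1, 10)], 4)

The first pass uses a lambda so the matched text is substituted verbatim with no backreference escaping; this is also why tglob is now seeded with {"Variable": Variable}, since the emitted source contains Variable('i', 1, 10) expressions that must resolve at exec time.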
@@ -143,14 +147,14 @@ class Interpreted:
       return ret
     ret = _interpret_ast(ast)
-    src = '\n'.join(['def run(inputs):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f"  return {ret}"])
+    src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f"  return {gstr(self.from_underlying, 'from_underlying')}({ret})" if self.from_underlying else f"  return {ret}"])
     if DEBUG >= 4 and self != InterpretedFlopCounter: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src))
     exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
     return tglob['run']
   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, **kwargs):
     if ast not in self.method_cache: self.method_cache[ast] = self.interpret_ast(ast)
-    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None)
+    ret = self.method_cache[ast]([x.realized for x in inputs] if inputs else None, var_vals)
     if output is not None and ret.dtype != output.dtype and UnaryOps.CAST in self.fxn_for_op:
       ret = self.from_underlying(self.fxn_for_op[UnaryOps.CAST](self.fxn_for_op[BufferOps.MEM](ret), (output.dtype, False))) # Do manual casting of ret if it does not match the required output dtype.
     # TODO: is this used?
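
Note: taken together, interpret_ast now compiles run(inputs, var_vals), so one cached function per AST serves every concrete binding of a symbolic shape, and exec_ast threads the caller's var_vals straight through. Illustrative only (the names below are gstr-style but not from a real trace), the DEBUG>=4 print might show something like:

  def run(inputs, var_vals):
    a0 = MovementOps_RESHAPE(inputs[0], (3, var_vals[Variable('i', 1, 10)]))  # arg repr rewritten by gstr
    return from_underlying(a0)

The symbolic dimension is thus resolved to an integer only at call time, not when the AST is first interpreted.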

tinygrad/shape/symbolic.py

@@ -355,7 +355,7 @@ VariableOrNum = Union[Variable, NumNode]
 render_python: Dict[Type, Callable] = {
   Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" else f"{self.expr}"),
-  NumNode: lambda self,ops,ctx: f"{self.b}",
+  NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
   MulNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}*{sym_render(self.b,ops,ctx)})",
   DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
   ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
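
Note: under the REPR context, NumNode now renders as constructor syntax instead of a bare integer, which gives the NumNode-stripping regex in ops.py a stable token to match; other contexts are unchanged. A quick sketch, assuming Node.render(ops=None, ctx=None) defaults to render_python:

  from tinygrad.shape.symbolic import Variable, NumNode
  print(Variable('i', 1, 10).render(ctx="REPR"))  # Variable('i', 1, 10)
  print(NumNode(4).render(ctx="REPR"))            # NumNode(4)  (previously: 4)
  print(NumNode(4).render())                      # 4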