mirror of https://github.com/commaai/tinygrad.git
move uop logic into shapetracker [run_process_replay] (#6118)
This commit is contained in:
parent
89c7989659
commit
7cae152aa2
|
@ -1,50 +1,14 @@
|
|||
from __future__ import annotations
|
||||
from dataclasses import replace
|
||||
from typing import List, Tuple, cast, Optional, Any, Dict
|
||||
import functools
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from dataclasses import replace
|
||||
from typing import List, Tuple, cast, Optional, Dict
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.ops import ReduceOps, KernelInfo, BinaryOps, BUFFER_UOPS, UOp, UOps
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, get_contraction, prod, partition, flatten
|
||||
|
||||
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
||||
from tinygrad.shape.symbolic import Variable, NumNode, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode
|
||||
def variable_to_uop(x, ctx=None) -> UOp:
  """Turn a shape value into a UOp.

  A plain Python int becomes a `dtypes.pyint` constant. Anything else is asked
  to render itself through the `render_ops` table, with `ctx` forwarded as-is.
  """
  if not isinstance(x, int): return x.render(render_ops, ctx)
  return UOp.const(dtypes.pyint, x)
|
||||
def _render_variable(self, ops, ctx):
  # a variable that is bound in ctx renders to whatever ctx maps it to
  if ctx is not None and self in ctx: return ctx[self]
  # otherwise emit a DEFINE_VAR whose sources are the variable's min and max as int constants
  bounds = (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))
  return UOp(UOps.DEFINE_VAR, dtypes.int, bounds, self)

def _render_sum(self, ops, ctx):
  # left fold with +, seeded by the first rendered node
  rest = self.nodes[1:]
  total = self.nodes[0].render(ops, ctx)
  for node in rest: total = total + node.render(ops, ctx)
  return total

def _render_and(self, ops, ctx):
  # left fold with *, seeded by the first rendered node
  rest = self.nodes[1:]
  total = self.nodes[0].render(ops, ctx)
  for node in rest: total = total * node.render(ops, ctx)
  return total

# table passed to `.render(...)` by variable_to_uop: node class -> callable(self, ops, ctx) producing a UOp
render_ops: Any = {
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self.b),
  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx) * variable_to_uop(self.b, ctx),
  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx) // variable_to_uop(self.b, ctx),
  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx) % variable_to_uop(self.b, ctx),
  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  Variable: _render_variable,
  SumNode: _render_sum,
  AndNode: _render_and,
}
|
||||
|
||||
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
  """Apply a single view to per-dimension index UOps.

  Returns `(iexpr, vexpr)`: `iexpr` is the view offset plus `idx*stride` for every
  dimension whose shape is not 1 and whose stride is not 0; `vexpr` is the incoming
  validity expression multiplied by the mask bound checks of each dimension.
  """
  # TODO: dtypes.realint
  iexpr = variable_to_uop(view.offset)
  # with no mask on the view, every dimension gets a None placeholder
  masks = [None]*len(view.shape) if view.mask is None else view.mask
  for dim_idx, size, stride, bounds in zip(idxs, view.shape, view.strides, masks):
    if size != 1 and stride != 0:
      iexpr = iexpr + dim_idx*variable_to_uop(stride)
    if bounds is None: continue
    # only emit a comparison when the bound actually restricts the dimension
    if bounds[0] != 0:
      vexpr = vexpr * dim_idx.ge(variable_to_uop(bounds[0]))
    if bounds[1] != size:
      vexpr = vexpr * dim_idx.lt(variable_to_uop(bounds[1]))
  return iexpr, vexpr
|
||||
|
||||
# TODO: change this once UOps is ready to replace symbolic
|
||||
def st_to_uops(st:ShapeTracker, idxs:List[UOp]) -> Tuple[UOp, UOp]:
  """Lower a ShapeTracker plus per-dimension index UOps to an `(idx, valid)` UOp pair.

  The last view is applied to `idxs` directly. Each earlier view, walked from last
  to first, is minified; the running flat index is then split back into that view's
  per-dimension indices before the view is applied.
  """
  idx, valid = _uop_view(st.views[-1], idxs, UOp.const(dtypes.bool, True))
  for outer in reversed(st.views[:-1]):
    outer = outer.minify()
    # walk the dimensions innermost-first; acc is the running product of the sizes already consumed
    acc = 1
    unflattened: List[UOp] = []
    for size in reversed(outer.shape):
      size_uop = variable_to_uop(size)
      unflattened.append((idx//acc)%size_uop)
      acc *= size_uop
    # indices were collected innermost-first, _uop_view expects them in shape order
    unflattened.reverse()
    idx, valid = _uop_view(outer, unflattened, valid)
  return idx, valid
|
||||
|
||||
def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
|
||||
# TODO: symbolic shape
|
||||
if not all_int(dims): return dims
|
||||
|
@ -126,7 +90,7 @@ class IndependentLowerer:
|
|||
|
||||
def _to_uop(self, x:UOp) -> UOp:
|
||||
if x.op in BUFFER_UOPS:
|
||||
idx, valid = st_to_uops(x.src[-1].arg, self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
|
||||
idx, valid = x.src[-1].arg.to_indexed_uops(self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs)
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
|
||||
|
|
|
@ -1,13 +1,36 @@
|
|||
# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
|
||||
from __future__ import annotations
|
||||
import functools
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast
|
||||
from typing import Tuple, List, Optional, Dict, Set, Iterable, cast, Any
|
||||
from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, DivNode, ModNode, LtNode, AndNode, sint
|
||||
from tinygrad.shape.view import View, strides_for_shape
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.ops import UOp, UOps
|
||||
|
||||
# TODO: this needs to be replaced, there shouldn't be variables in the shapetracker, only ints and UOps
|
||||
def variable_to_uop(x, ctx=None) -> UOp:
  """Turn a shape value into a UOp.

  A plain Python int becomes a `dtypes.pyint` constant. Anything else is asked
  to render itself through the `render_ops` table, with `ctx` forwarded as-is.
  """
  if not isinstance(x, int): return x.render(render_ops, ctx)
  return UOp.const(dtypes.pyint, x)
|
||||
def _render_variable(self, ops, ctx):
  # a variable that is bound in ctx renders to whatever ctx maps it to
  if ctx is not None and self in ctx: return ctx[self]
  # otherwise emit a DEFINE_VAR whose sources are the variable's min and max as int constants
  bounds = (UOp.const(dtypes.int, self.min), UOp.const(dtypes.int, self.max))
  return UOp(UOps.DEFINE_VAR, dtypes.int, bounds, self)

def _render_sum(self, ops, ctx):
  # left fold with +, seeded by the first rendered node
  rest = self.nodes[1:]
  total = self.nodes[0].render(ops, ctx)
  for node in rest: total = total + node.render(ops, ctx)
  return total

def _render_and(self, ops, ctx):
  # left fold with *, seeded by the first rendered node
  rest = self.nodes[1:]
  total = self.nodes[0].render(ops, ctx)
  for node in rest: total = total * node.render(ops, ctx)
  return total

# table passed to `.render(...)` by variable_to_uop: node class -> callable(self, ops, ctx) producing a UOp
render_ops: Any = {
  NumNode: lambda self, ops, ctx: UOp.const(dtypes.pyint, self.b),
  MulNode: lambda self, ops, ctx: self.a.render(ops, ctx) * variable_to_uop(self.b, ctx),
  DivNode: lambda self, ops, ctx: self.a.render(ops, ctx) // variable_to_uop(self.b, ctx),
  ModNode: lambda self, ops, ctx: self.a.render(ops, ctx) % variable_to_uop(self.b, ctx),
  LtNode: lambda self, ops, ctx: self.a.render(ops, ctx).lt(variable_to_uop(self.b, ctx)),
  Variable: _render_variable,
  SumNode: _render_sum,
  AndNode: _render_and,
}
|
||||
|
||||
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
  """Apply a single view to per-dimension index UOps.

  Returns `(iexpr, vexpr)`: `iexpr` is the view offset plus `idx*stride` for every
  dimension whose shape is not 1 and whose stride is not 0; `vexpr` is the incoming
  validity expression multiplied by the mask bound checks of each dimension.
  """
  # TODO: dtypes.realint
  iexpr = variable_to_uop(view.offset)
  # with no mask on the view, every dimension gets a None placeholder
  masks = [None]*len(view.shape) if view.mask is None else view.mask
  for dim_idx, size, stride, bounds in zip(idxs, view.shape, view.strides, masks):
    if size != 1 and stride != 0:
      iexpr = iexpr + dim_idx*variable_to_uop(stride)
    if bounds is None: continue
    # only emit a comparison when the bound actually restricts the dimension
    if bounds[0] != 0:
      vexpr = vexpr * dim_idx.ge(variable_to_uop(bounds[0]))
    if bounds[1] != size:
      vexpr = vexpr * dim_idx.lt(variable_to_uop(bounds[1]))
  return iexpr, vexpr
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ShapeTracker:
|
||||
views: Tuple[View, ...]
|
||||
|
@ -43,6 +66,18 @@ class ShapeTracker:
|
|||
|
||||
def to_uops(self) -> Tuple[UOp, UOp]:
  """Return the `(ST_IDX, ST_VALID)` UOp pair for this ShapeTracker.

  Both UOps have no sources and carry `self` as their arg; the index UOp is typed
  `dtypes.pyint` and the valid UOp is typed `dtypes.bool`.
  """
  st_idx = UOp(UOps.ST_IDX, dtypes.pyint, (), self)
  st_valid = UOp(UOps.ST_VALID, dtypes.bool, (), self)
  return st_idx, st_valid
|
||||
|
||||
def to_indexed_uops(self, idxs:List[UOp]) -> Tuple[UOp, UOp]:
  """Lower this ShapeTracker plus per-dimension index UOps to an `(idx, valid)` UOp pair.

  The last view is applied to `idxs` directly. Each earlier view, walked from last
  to first, is minified; the running flat index is then split back into that view's
  per-dimension indices before the view is applied.
  """
  idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
  for outer in reversed(self.views[:-1]):
    outer = outer.minify()
    # walk the dimensions innermost-first; acc is the running product of the sizes already consumed
    acc = 1
    unflattened: List[UOp] = []
    for size in reversed(outer.shape):
      size_uop = variable_to_uop(size)
      unflattened.append((idx//acc)%size_uop)
      acc *= size_uop
    # indices were collected innermost-first, _uop_view expects them in shape order
    unflattened.reverse()
    idx, valid = _uop_view(outer, unflattened, valid)
  return idx, valid
|
||||
|
||||
def real_size(self) -> int:
|
||||
if 0 in self.shape: return 0
|
||||
idx, valid = self.expr_idxs()
|
||||
|
|
Loading…
Reference in New Issue