mirror of https://github.com/commaai/tinygrad.git
add UOp.r [pr] (#7080)
This commit is contained in:
parent
26df50cf43
commit
067b35e915
|
@ -2,7 +2,7 @@ import sys, pickle, atexit
|
|||
from collections import defaultdict, deque
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
|
||||
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
|
||||
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
|
||||
graph_rewrite, track_rewrites, sint
|
||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
|
||||
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
|
||||
|
@ -76,8 +76,7 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
|
|||
new_input_st = tmp + ShapeTracker(tuple(nv))
|
||||
_, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
|
||||
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
|
||||
return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
|
||||
(reduceop.arg[0], new_axis)).view(ShapeTracker.from_shape(swizzle_st.shape))
|
||||
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(reduceop.arg[0], new_axis).view(ShapeTracker.from_shape(swizzle_st.shape))
|
||||
|
||||
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
||||
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
|
||||
|
@ -85,7 +84,7 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
|
|||
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
|
||||
output_shape = swizzle_st.reduce(root.axis_arg)
|
||||
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
|
||||
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (root.arg[0], new_axis)).view(ShapeTracker.from_shape(output_shape))
|
||||
return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
|
||||
|
||||
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
||||
swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0]
|
||||
|
@ -101,7 +100,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
|
|||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
|
||||
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
||||
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
|
||||
|
||||
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
|
||||
|
||||
|
@ -167,7 +166,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
|||
# only reduceop changes shape
|
||||
src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
|
||||
src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs]
|
||||
if buf.op in ReduceOps: ret = UOp(UOps.REDUCE_AXIS, dtype, tuple(src), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st)
|
||||
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg).view(st)
|
||||
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
|
||||
elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype)
|
||||
|
|
|
@ -291,6 +291,7 @@ class UOp(MathTrait):
|
|||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
|
||||
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
|
||||
|
||||
# *** uop Variable stuff ***
|
||||
|
||||
|
@ -736,7 +737,7 @@ spec = PatternMatcher([
|
|||
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
|
||||
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
|
||||
|
||||
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in BinaryOps),
|
||||
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
|
||||
(UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
|
||||
(UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
|
||||
(UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None and x.dtype.count == 1),
|
||||
|
|
Loading…
Reference in New Issue