add UOp.r [pr] (#7080)

This commit is contained in:
qazal 2024-10-16 05:06:02 +03:00 committed by GitHub
parent 26df50cf43
commit 067b35e915
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 7 deletions

View File

@ -2,7 +2,7 @@ import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
from tinygrad.ops import BUFFER_UOPS, REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
from tinygrad.ops import BUFFER_UOPS, UNSAFE_PAD_OPS, MetaOps, ReduceOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, Variable, resolve, \
graph_rewrite, track_rewrites, sint
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
@ -76,8 +76,7 @@ def push_swizzle_up_through_reduce(swizzle:UOp, reduceop:UOp) -> Optional[UOp]:
new_input_st = tmp + ShapeTracker(tuple(nv))
_, new_rshape = permute_reduce(new_input_st, reduceop.axis_arg)
new_axis = tuple(range(len(new_input_st.shape)-len(new_rshape), len(new_input_st.shape)))
return UOp(UOps.REDUCE_AXIS, reduceop.dtype, (st_fixup(rsrc, lambda st:st+new_input_st, {}),),
(reduceop.arg[0], new_axis)).view(ShapeTracker.from_shape(swizzle_st.shape))
return st_fixup(rsrc, lambda st:st+new_input_st, {}).r(reduceop.arg[0], new_axis).view(ShapeTracker.from_shape(swizzle_st.shape))
def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
swizzle_st, src_st = unwrap(swizzle.st), unwrap(swizzle.src[0].st)
@ -85,7 +84,7 @@ def push_swizzle_down_through_reduce(root:UOp, swizzle:UOp) -> UOp:
assert prod(swizzle_st.shape) == prod(src_st.shape), "can't push expands down to STORE"
output_shape = swizzle_st.reduce(root.axis_arg)
new_axis = tuple(i for i,(s,u) in enumerate(zip(src_st.shape, output_shape)) if s != u)
return UOp(UOps.REDUCE_AXIS, root.dtype, swizzle.src, (root.arg[0], new_axis)).view(ShapeTracker.from_shape(output_shape))
return swizzle.src[0].r(root.arg[0], new_axis).view(ShapeTracker.from_shape(output_shape))
def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
swizzles = [x for x in root.src if x.op is UOps.VIEW and len(x.src) != 0]
@ -101,7 +100,7 @@ def push_swizzle_down_through_elementwise(root:UOp) -> Optional[UOp]:
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is UOps.REDUCE_AXIS for x in first_reduce.parents), "can't merge more than two reduceops at a time"
return UOp(UOps.REDUCE_AXIS, first_reduce.dtype, first_reduce.src, (first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
merge_views = PatternMatcher([(UPat(UOps.VIEW, src=(UPat(UOps.VIEW, name="s0"),), name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
@ -167,7 +166,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# only reduceop changes shape
src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs]
if buf.op in ReduceOps: ret = UOp(UOps.REDUCE_AXIS, dtype, tuple(src), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st)
if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg).view(st)
elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype)

View File

@ -291,6 +291,7 @@ class UOp(MathTrait):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
def r(self, op, axis): return UOp(UOps.REDUCE_AXIS, self.dtype, (self,), (REDUCE_ALU[op] if op in ReduceOps else op, axis))
# *** uop Variable stuff ***
@ -736,7 +737,7 @@ spec = PatternMatcher([
(UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
(UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in BinaryOps),
(UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in REDUCE_ALU.values()),
(UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
(UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
(UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None and x.dtype.count == 1),