misc UOp st cleanups (#6668)

This commit is contained in:
qazal 2024-09-23 14:16:42 +08:00 committed by GitHub
parent 26ebb7cab4
commit 7ca9ffa494
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 15 deletions

View File

@ -787,21 +787,20 @@ class Kernel:
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
if not uop.has_st or uop in sts: return
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
# restore globals from the two stage reduce
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
sts[uop] = sts[local_reduce]
return
for x in src: _assert_valid_uop(x, st, sts)
for x in uop.src: _assert_valid_uop(x, st, sts)
# only reduceuop is allowed to change shape, limited to turning n to 1
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
# movementops are pushed to the edges with SHAPETRACKER and SWIZZLE
elif op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = arg
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
# movementops are pushed to SHAPETRACKER and SWIZZLE
elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
# everything else inherits shape
else:
assert op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
st = (src_sts:=[sts[x] for x in src if x.has_st])[0]
assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
if not all_same(shapes:=[x.shape for x in src_sts]):
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
raise AssertionError(f"found implicit expand {sizes}")
@ -809,7 +808,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
sts: Dict[UOp, ShapeTracker] = {}
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]

View File

@ -217,16 +217,14 @@ class UOp(MathTrait):
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
def reduce(self, op, *rng): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op is UOps.SHAPETRACKER: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]