mirror of https://github.com/commaai/tinygrad.git
misc UOp st cleanups (#6668)
This commit is contained in:
parent
26ebb7cab4
commit
7ca9ffa494
|
@ -787,21 +787,20 @@ class Kernel:
|
|||
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) -> None:
|
||||
if not uop.has_st or uop in sts: return
|
||||
op, _, src, arg = uop.op, uop.dtype, uop.src, uop.arg
|
||||
# restore globals from the two stage reduce
|
||||
if op is UOps.LOAD and src[0].op is UOps.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=src[2].src[2], uop.st_arg, sts)
|
||||
if uop.op is UOps.LOAD and uop.src[0].op is UOps.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
|
||||
sts[uop] = sts[local_reduce]
|
||||
return
|
||||
for x in src: _assert_valid_uop(x, st, sts)
|
||||
for x in uop.src: _assert_valid_uop(x, st, sts)
|
||||
# only reduceuop is allowed to change shape, limited to turning n to 1
|
||||
if op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[src[0]].reduce(arg[-1]))
|
||||
# movementops are pushed to the edges with SHAPETRACKER and SWIZZLE
|
||||
elif op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = arg
|
||||
if uop.op in {UOps.REDUCE_AXIS, UOps.WMMA}: st = ShapeTracker.from_shape(sts[uop.src[0]].reduce(uop.arg[-1]))
|
||||
# movementops are pushed to SHAPETRACKER and SWIZZLE
|
||||
elif uop.op in {UOps.SHAPETRACKER, UOps.SWIZZLE}: st = uop.arg
|
||||
# everything else inherits shape
|
||||
else:
|
||||
assert op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
st = (src_sts:=[sts[x] for x in src if x.has_st])[0]
|
||||
assert uop.op in {UOps.ALU, UOps.CAST, UOps.BITCAST, UOps.CONTRACT, UOps.EXPAND, *BUFFER_UOPS}, f"bad UOp in intermediate uops {uop}"
|
||||
st = (src_sts:=[sts[x] for x in uop.src if x.has_st])[0]
|
||||
if not all_same(shapes:=[x.shape for x in src_sts]):
|
||||
if all_same(sizes:=[prod(x) for x in shapes]): raise AssertionError(f"found implicit reshape {shapes}")
|
||||
raise AssertionError(f"found implicit expand {sizes}")
|
||||
|
@ -809,7 +808,7 @@ def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:Dict[UOp, ShapeTracker]) ->
|
|||
|
||||
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
|
||||
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
|
||||
assert len(set(x.st_arg.size for x in ast.src)) == 1, "outputs must be exactly the same size"
|
||||
assert all_same([x.st_arg.size for x in ast.src]), "outputs must be exactly the same size"
|
||||
sts: Dict[UOp, ShapeTracker] = {}
|
||||
for out in ast.src: _assert_valid_uop(out, out.st_arg, sts)
|
||||
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
|
||||
|
|
|
@ -217,16 +217,14 @@ class UOp(MathTrait):
|
|||
@staticmethod
|
||||
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
|
||||
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
|
||||
def reduce(self, op, *rng): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
|
||||
@functools.cached_property
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents.keys()}}
|
||||
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
|
||||
@property # parents with self
|
||||
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> Tuple[sint, ...]:
|
||||
if self.op is UOps.SHAPETRACKER: return self.arg.shape
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
return self.arg.shape if self.op is UOps.SHAPETRACKER else tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
|
|
Loading…
Reference in New Issue