mirror of https://github.com/commaai/tinygrad.git
all views are UOps.VIEW [pr] (#7090)
* all views are UOps.VIEW
* is it you
* don't recreate st uop [pr]
* first rewrite all elementwise
This commit is contained in:
parent 6acda43a2c
commit 6172b42140
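The change in a nutshell: instead of threading a ShapeTracker through the recursion, every LazyBuffer is lowered once at its base, and any non-identity view is applied as an explicit UOps.VIEW on top of the lowered base. Below is a minimal standalone sketch of that strategy, using hypothetical stand-in classes (Buf, U, lower) rather than tinygrad's real LazyBuffer/UOp:

# Minimal sketch (toy stand-in classes, not tinygrad's real API):
# recurse only on bases, memoize one UOp per buffer, and make every
# non-identity view an explicit VIEW node wrapping the lowered base.
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Tuple

@dataclass(eq=False)             # identity-hashed, like a LazyBuffer
class Buf:
  op: str
  srcs: Tuple[Buf, ...] = ()
  base: Optional[Buf] = None     # None: this buffer is its own base
  st: Optional[str] = None       # stand-in for a ShapeTracker

@dataclass(frozen=True)
class U:                         # stand-in for a UOp
  op: str
  srcs: Tuple[U, ...] = ()
  arg: Optional[str] = None

def lower(buf: Buf, cache: Dict[Buf, U]) -> U:
  # memoize one UOp per buffer, not per (buffer, view)
  if (r := cache.get(buf)) is not None: return r
  src = []
  for x in buf.srcs:
    u = lower(x.base if x.base is not None else x, cache)  # always lower the base
    # a non-identity view becomes an explicit VIEW node on the base
    src.append(U("VIEW", (u,), x.st) if x.base is not None else u)
  cache[buf] = ret = U(buf.op, tuple(src))
  return ret

# two uses of one base (direct and through a permute view) share a
# single lowered UOp for the base
a = Buf("LOAD")
out = Buf("ADD", (a, Buf("VIEW", base=a, st="permute(1,0)")))
print(lower(out, {}))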
@@ -145,31 +145,29 @@ def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...], var_vals:Dict[Variable
 
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
-def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], inputs:List[LazyBuffer],
-                   buf_uops:Dict[Buffer, UOp], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
+def _recursive_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp:
   """recursively create a UOp"""
-  if buf is not buf.base: st, buf = buf.st+st, buf.base
-  if (buf, st) in cache: return cache[(buf, st)]
-  assert buf.op is not None, "base must be a base itself"
+  if (r:=cache.get(buf)) is not None: return r
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
 
   # buffer ops define ShapeTracker
   # if it's realized, it's a load and we add it to the inputs
   if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
-    if buf.op is MetaOps.CONST: return ubuf.view(st)
+    if buf.op is MetaOps.CONST: return ubuf.view(buf.st)
     if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
-    return UOp(UOps.LOAD, dtype, (ubuf, st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (ubuf, buf.st.to_uop()))
 
-  # only reduceop changes shape
-  src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
-  src: List[UOp] = [_recursive_uop(x, src_st, outputs, inputs, buf_uops, cache) for x in buf.srcs]
-  if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg).view(st)
+  src: List[UOp] = []
+  for x in buf.srcs:
+    u = _recursive_uop(x.base, outputs, inputs, buf_uops, cache)
+    src.append(u if x is x.base else u.view(x.st))
+  if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, (buf_uops[buf.buffer], src[0]))
   elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]))
   elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype)
   elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype)
   else: ret = UOp(UOps.ALU, dtype, tuple(src), buf.op)
-  cache[(buf, st)] = ret
+  cache[buf] = ret
   return ret
 
 def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
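The hunk above also simplifies the memoization key from (LazyBuffer, ShapeTracker) to the LazyBuffer alone. A small self-contained illustration (hypothetical values, no tinygrad imports) of why that deduplicates lowering work:

# With (buf, st) cache keys, one base reached through two different
# views misses the cache twice; keyed on the buffer alone it is
# lowered once, and each view wraps the cached result.
base = object()                            # stands in for one LazyBuffer base
views = ("permute(1,0)", "shrink((0,2))")  # two ShapeTracker stand-ins

old_cache = {}                             # old scheme: keyed on (buf, st)
for st in views:
  if (base, st) not in old_cache: old_cache[(base, st)] = "lowered again"
print(len(old_cache))                      # -> 2: same base lowered twice

new_cache = {}                             # new scheme: keyed on buf alone
for st in views:
  if base not in new_cache: new_cache[base] = "lowered once"
print(len(new_cache))                      # -> 1: views are VIEW UOps on top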
@@ -178,14 +176,15 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
     return LBScheduleItem(UOp(METAOPS[cast(MetaOps, out.op)], out.dtype, (), out.arg), (out,)+tuple(x.base for x in out.srcs),
                           (out.metadata,) if out.metadata is not None else None)
   # create the stores
-  cache: Dict[Tuple[LazyBuffer, ShapeTracker], UOp] = {}
+  cache: Dict[LazyBuffer, UOp] = {}
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), inputs, buf_uops, cache=cache)
+    src = _recursive_uop(out, outs, inputs, buf_uops, cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
+    else: output_st = ShapeTracker.from_shape(out.shape)
     ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
   sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
@@ -194,7 +193,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
       and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
     raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                        +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs])))
+  return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x in cache if x.metadata and x not in inputs])))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
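The guard in the last hunk is the user-facing side of this change: an augmented assign whose self operand is a non-contiguous view of the output is rejected at schedule time. A hedged usage sketch, assuming tinygrad's public Tensor API (Tensor.rand, assign, realize) around this commit; the failing and fixed lines mirror the red/green example in the error message:

# Usage sketch (assumes tinygrad's public Tensor API at this commit)
from tinygrad import Tensor

a = Tensor.rand(4, 4).realize()
try:
  a.assign(a + a.T).realize()   # self operand is a permuted (non-contiguous) view of a
except RuntimeError as e:
  print(e)                      # suggests: a += a.T.contiguous()
a.assign(a + a.T.contiguous()).realize()  # the fix the message recommends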