mirror of https://github.com/commaai/tinygrad.git
refactor membufs (#4147)
This commit is contained in:
parent b7e281cf10
commit c0796374e4
@@ -25,7 +25,7 @@ class _LBScheduleItem:
   var_vals: Dict[Variable, int]

 # recursively create a lazyop
-def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
+def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outbufs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                       realizes:Set[LazyBuffer], cache, assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
   if (buf, st) in cache: return cache[(buf, st)]
   if buf != buf.base:
@@ -52,18 +52,18 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], outbufs:Tuple[La
             ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
           raise RuntimeError(f"must be contiguous for assign {unbound_st}")
       return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
-    if buf not in membufs: membufs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(membufs.index(buf), buf.dtype, unbound_st))
+    if buf not in inputs: inputs.append(buf)
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outbufs)+inputs.index(buf), buf.dtype, unbound_st))

   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op is LoadOps.CONTIGUOUS:
     assert buf in outbufs
-    return _recursive_lazyop(buf.srcs[0], membufs, outbufs, var_vals, st, realizes, cache)
+    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache)
   if buf.op is LoadOps.ASSIGN:
     assert buf in outbufs
     assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
     assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
-    return _recursive_lazyop(buf.srcs[0], membufs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=membufs.index(buf))
+    return _recursive_lazyop(buf.srcs[0], inputs, outbufs, var_vals, st, realizes, cache, assign_to=buf.srcs[1], assign_idx=outbufs.index(buf))

   # if it's a reduce, we have to change the shapetracker
   if buf.op in ReduceOps:
@@ -72,7 +72,7 @@ def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], outbufs:Tuple[La

   # otherwise we fuse it like normal
   cache[(buf, st)] = ret = \
-    LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
+    LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, outbufs, var_vals, st, realizes, cache, assign_to, assign_idx) for x in buf.srcs), buf.arg)
   return ret

 def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
@@ -81,12 +81,12 @@ def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[
   if out.op in {LoadOps.CUSTOM, LoadOps.COPY, LoadOps.EMPTY}:
     op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
   else:
-    output_st, membufs = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape), [out]
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
     output_view = out.arg[0] if out.op is LoadOps.ASSIGN and out.arg else output_st
-    op = _recursive_lazyop(out, membufs, (out, ), var_vals, output_st, realizes, cache={})
+    op = _recursive_lazyop(out, inputs, (out, ), var_vals, output_st, realizes, cache={})
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
-    op, inputs = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view)), membufs[1:]
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_view))
   return _LBScheduleItem((op,), (out,), tuple(inputs), var_vals)

 # recursively search the entire graph for all LazyBuffers, insert realizes after expands