allocate shared memory per block (#4924)

* define temp

* use idx

* cleaner [run_process_replay]
Author: qazal (2024-06-12 15:43:10 +08:00), committed by GitHub
parent ca4ccddcd6
commit b833a112ba
1 changed file with 8 additions and 7 deletions


@@ -223,7 +223,7 @@ class Linearizer(Kernel):
     loop_ctx = self.render_loop(reduce_idxs, 2)
     # define accumulator - modify idxs if necessary for TC
-    out_buf = -1 if self.group_for_reduces else 0
+    out_buf = -(self.reduceops.index(reduceop)+1) if self.group_for_reduces else 0
     accs[reduceop] = self.global_load(out_buf, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc=reduceop, loop_ctx=loop_ctx)
     # store local aliases
@@ -268,7 +268,7 @@ class Linearizer(Kernel):
     # end the local loop, do the local reduce
     if self.group_for_reduces:
       fake_global_idxs = [x*0 for x in global_idxs]
-      stores = self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])  # store accumulators
+      stores = self.global_store(out_buf, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, accs[reduceop])  # store accumulators
       barrier = self.uops.add(UOps.BARRIER, None, tuple(stores))
       if self.opts.has_local:
         fake_idxs = [NumNode(0)]*len(self.sts[-1].shape)
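
The store of the accumulators now goes through the same out_buf, so for a given reduceop the partial results are written to and later read back from the same shared temp buffer, with the barrier separating the two stages. A rough plain-Python sketch of the two-stage (grouped) reduce this implements, with made-up sizes:

    # stage 1: each "local thread" accumulates its strided slice into its own accumulator
    data = list(range(32))   # values one work-group reduces (hypothetical)
    local_size = 8           # hypothetical number of local threads
    temp = [0] * local_size  # stands in for the DEFINE_LOCAL "temp" buffer

    for lidx in range(local_size):
        acc = 0
        for k in range(lidx, len(data), local_size):
            acc += data[k]
        temp[lidx] = acc     # the global_store of the accumulator; a BARRIER follows on device

    # stage 2: reduce the partials that now sit in shared memory
    assert sum(temp) == sum(data)
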
@@ -353,11 +353,12 @@ class Linearizer(Kernel):
                                          (), (lb.name, self.sts[self.bufs.index(lb)].size))
     # add a local buffer for multistage reduce. # TODO: use local alias
     if self.group_for_reduces:
-      # TODO: the strides of this can be controlled
-      self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
-      temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
-      self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
-      self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))
+      for i in range(len(self.reduceops)):
+        # TODO: the strides of this can be controlled
+        self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))  # noqa: E501
+        temp_dtype = self.get_base_dtype(cast(LazyOp, self.reduceop).dtype)
+        self.bufs.append(LocalBuffer(name:=f"temp{i if len(self.reduceops) > 1 else ''}", buf_size:=self.sts[-1].size, temp_dtype))
+        self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), (name, buf_size)))
     # kernel name (before late upcast)
     self.name = ("r" if self.reduceop else ("C" if all(x.op in BufferOps for x in self.lazyops) else "E")) + \