perf: use enumerate where possible (#1692)

Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
Roelof van Dijk 2023-08-30 19:41:51 +02:00 committed by GitHub
parent a8aa13dc91
commit 62536d6000
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 9 additions and 9 deletions

View File

@ -137,7 +137,7 @@ class Kernel:
def colored_shape(self) -> str: return ' '.join(colored(s, color) for s,color in zip([f"{s:4d}" if isinstance(s, int) else s for s in self.full_shape], self.colors()))
def printbufs(self, prefix=""):
for i in range(len(self.sts)):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", self.sts[i].views)
for i,st in enumerate(self.sts):
print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", st.views)
print(self.colored_shape())

View File

@ -331,7 +331,7 @@ class OptimizedKernel(Kernel):
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
if xb_choices:
xb_choices = sorted(xb_choices)
@ -370,7 +370,7 @@ class OptimizedKernel(Kernel):
local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
if self.full_shape[axis] == 1: continue
last_try = self.local_dims == 0 and axis == 0
if any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))) or last_try:
if any(st.views[-1].strides[axis] == 0 for st in self.sts) or last_try:
for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
self.local_dims += 1

View File

@ -269,7 +269,7 @@ class LazyBuffer:
if not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))])
narg = tuple([self.op.arg[a] for a in arg])
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children

View File

@ -46,7 +46,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
added_output_channels = 4 - (rcout % 4)
rcout += added_output_channels
cout = groups * rcout
w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
# packed (note: flipping bs and iy would make the auto-padding work)
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)

View File

@ -23,7 +23,7 @@ def match_types(x, y):
def einsum_mulacc(einsum, get_strides, expand):
def einscripts(x): return ''.join(["abcdefghijklmnopqrstuvwxyz"[i] for i in x])
def axes_slice(strides): return [i for i in range(len(strides)) if strides[i] != 0], tuple(slice(None) if strides[i] != 0 else 0 for i in range(len(strides)))
def axes_slice(strides): return [i for i,s in enumerate(strides) if s != 0], tuple([slice(None) if s != 0 else 0 for i,s in enumerate(strides)])
def mulacc(a, b, new_shape):
(a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]

View File

@ -408,8 +408,8 @@ class Tensor:
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis)) # type: ignore
axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
shape = [self.shape[i] for i in range(len(self.shape)) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else self.shape[i] for i in range(len(self.shape))]))
shape = [s for i,s in enumerate(self.shape) if i not in axis_]
ret = fxn.apply(self, new_shape=tuple([1 if i in axis_ else s for i,s in enumerate(self.shape)]))
return ret if keepdim else ret.reshape(shape=shape)
def sum(self, axis=None, keepdim=False): return self._reduce(mlops.Sum, axis, keepdim)