mirror of https://github.com/commaai/tinygrad.git
perf: use enumerate where possible (#1692)
Co-authored-by: Roelof van Dijk <roelof.van.dijk@vitestro.com>
This commit is contained in:
parent
a8aa13dc91
commit
62536d6000
|
@ -137,7 +137,7 @@ class Kernel:
|
|||
|
||||
def colored_shape(self) -> str:
    """Render full_shape as a space-separated string, one colored entry per dimension."""
    # ints are right-aligned in a 4-char field; symbolic dims are shown as-is
    rendered = [f"{dim:4d}" if isinstance(dim, int) else dim for dim in self.full_shape]
    pieces = [colored(text, col) for text, col in zip(rendered, self.colors())]
    return ' '.join(pieces)
|
||||
def printbufs(self, prefix=""):
    """Print one line per buffer — its index, the realized buffer (or the
    unrealized buffer's repr) padded to 47 chars, and its .views — then
    the colored shape summary for the whole kernel.

    Fix: the scraped diff left both the pre-change ``range(len(self.sts))``
    loop line and the post-change ``enumerate`` line in place; only the
    post-change version is kept here.
    """
    for i, st in enumerate(self.sts):
        # show the realized backing buffer when present, otherwise the buffer itself
        print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i].realized is not None else str(self.bufs[i]):47s}", st.views)
    print(self.colored_shape())
|
||||
|
||||
|
|
|
@ -331,7 +331,7 @@ class OptimizedKernel(Kernel):
|
|||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
|
||||
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if xb_choices:
|
||||
xb_choices = sorted(xb_choices)
|
||||
|
@ -370,7 +370,7 @@ class OptimizedKernel(Kernel):
|
|||
local_size = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce])
|
||||
if self.full_shape[axis] == 1: continue
|
||||
last_try = self.local_dims == 0 and axis == 0
|
||||
if any(self.sts[buf_index].views[-1].strides[axis] == 0 for buf_index in range(len(self.sts))) or last_try:
|
||||
if any(st.views[-1].strides[axis] == 0 for st in self.sts) or last_try:
|
||||
for sz in [x for x in (([32] if last_try else []) + [16,8,4,3]) if self.full_shape[axis] % x == 0 and local_size*x <= 128]:
|
||||
self.shift_to(axis, sz, insert_before=self.first_reduce-self.local_dims)
|
||||
self.local_dims += 1
|
||||
|
|
|
@ -269,7 +269,7 @@ class LazyBuffer:
|
|||
if not self.realized:
|
||||
if PUSH_PERMUTES and self.optype == ReduceOps:
|
||||
# reduceops have one buffer input, permute it
|
||||
narg = tuple([self.op.arg[arg[i]] for i in range(len(arg))])
|
||||
narg = tuple([self.op.arg[a] for a in arg])
|
||||
src, rop = self.op.src[0], self.op.op
|
||||
src.children.discard(self)
|
||||
del self # TODO: why doesn't this delete remove it from the children
|
||||
|
|
|
@ -46,7 +46,7 @@ def image_conv2d(self, weight, bias=None, groups=1, stride=1, dilation=1, paddin
|
|||
added_output_channels = 4 - (rcout % 4)
|
||||
rcout += added_output_channels
|
||||
cout = groups * rcout
|
||||
w = w.slice(tuple((0, rcout) if i == 1 else (0, w.shape[i]) for i in range(len(w.shape))))
|
||||
w = w.slice(tuple((0, rcout) if i == 1 else (0, s) for i,s in enumerate(w.shape)))
|
||||
|
||||
# packed (note: flipping bs and iy would make the auto-padding work)
|
||||
x = x.permute(0,2,3,1).reshape(bs * iy, ix * groups * cin//4, 4)
|
||||
|
|
|
@ -23,7 +23,7 @@ def match_types(x, y):
|
|||
|
||||
def einsum_mulacc(einsum, get_strides, expand):
|
||||
def einscripts(x):
    """Map a sequence of axis indices to a string of einsum subscript letters (0 -> 'a', 1 -> 'b', ...)."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    return ''.join(letters[i] for i in x)
|
||||
def axes_slice(strides):
    """Return (non-zero-stride axis indices, per-axis slicer).

    An axis with stride 0 is a broadcast axis: the slicer selects index 0
    there (dropping the axis) and slice(None) (keep everything) elsewhere.

    Fixes: the scraped diff kept both the old range(len(...)) version and
    the new enumerate version of this line — only the new one belongs here.
    Also drops the unused loop variable and the unnecessary inner list
    inside tuple(...).
    """
    return [i for i, s in enumerate(strides) if s != 0], tuple(slice(None) if s != 0 else 0 for s in strides)
|
||||
def mulacc(a, b, new_shape):
|
||||
(a_axes, a_slices), (b_axes, b_slices) = axes_slice(get_strides(a)), axes_slice(get_strides(b))
|
||||
out = [i for i in range(len(new_shape)) if a.shape[i] == new_shape[i] and (i in a_axes or i in b_axes)]
|
||||
|
|
|
@ -408,8 +408,8 @@ class Tensor:
|
|||
def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Tuple[int, ...]]]=None, keepdim=False):
    """Apply the reduction op `fxn` over `axis` (all axes when axis is None).

    Negative axes are normalized against the tensor's rank. The reduction
    itself keeps rank (reduced axes collapse to size 1); when keepdim is
    False the size-1 axes are then removed with a reshape.

    Fix: the scraped diff left both the pre-change ``range(len(self.shape))``
    lines and the post-change ``enumerate`` lines in place; only the
    post-change versions are kept here.
    """
    # normalize axis to a concrete list of ints (None -> every axis)
    axis_: List[int] = list(range(len(self.shape))) if axis is None else ([axis] if axis.__class__ is int else list(axis))  # type: ignore
    axis_ = [x if x >= 0 else x+len(self.shape) for x in axis_]
    # target shape with the reduced axes dropped (used when keepdim is False)
    shape = [s for i,s in enumerate(self.shape) if i not in axis_]
    # rank-preserving reduction: reduced axes become size 1
    ret = fxn.apply(self, new_shape=tuple(1 if i in axis_ else s for i,s in enumerate(self.shape)))
    return ret if keepdim else ret.reshape(shape=shape)
|
||||
|
||||
def sum(self, axis=None, keepdim=False):
    """Sum-reduce over `axis` (every axis when None); keepdim retains size-1 reduced dims."""
    return self._reduce(mlops.Sum, axis, keepdim)
|
||||
|
|
Loading…
Reference in New Issue