use shapetracker to combine adj reduce axis

This commit is contained in:
George Hotz 2022-06-14 17:08:12 -07:00
parent 906cce9916
commit a8aeebfb0c
2 changed files with 10 additions and 9 deletions

View File

@ -91,18 +91,19 @@ def reduce_op(op, inp, ret):
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
else: raise Exception(f"{op} isn't supported")
# reverse operation of expand
# this take a ret index to an inp index
# reverse operation of expand; this also validates the inputs
st = ShapeTracker(*ret.shape).movement_op(MovementOps.EXPAND, inp.shape)
# this takes a ret index to an inp index, indexing 0 on the reduced strides
view = View(ret.shape, strides_for_shape(inp.shape))
# combine adjacent reduce axes
acc = 1
loop_start, loop_end = [], []
for i,o in list(zip(inp.shape, ret.shape))[::-1]:
if i != o: # reduce axis
assert o == 1
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
acc *= i
for shp,stride in st.views[-1].shape_strides[::-1]:
if stride == 0:
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{")
loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};")
acc *= shp
prg = """
__kernel void reduce(__global const float *a_g, __global float *res_g) {

View File

@ -13,7 +13,7 @@ class View:
self.shape_strides = [(shape[0], strides[0])]
for i in range(1, len(shape)):
if strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]:
if (strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]) or (strides[i] == 0 and self.shape_strides[-1][1] == 0):
self.shape_strides[-1] = (self.shape_strides[-1][0] * shape[i], strides[i])
else:
self.shape_strides.append((shape[i], strides[i]))