mirror of https://github.com/commaai/tinygrad.git
use ShapeTracker to combine adjacent reduce axes
This commit is contained in:
parent
906cce9916
commit
a8aeebfb0c
|
@ -91,18 +91,19 @@ def reduce_op(op, inp, ret):
|
|||
elif op == ReduceOps.MAX: code, start = "out = max(a,out)", "-INFINITY"
|
||||
else: raise Exception(f"{op} isn't supported")
|
||||
|
||||
# reverse operation of expand
|
||||
# this takes a ret index to an inp index
|
||||
# reverse operation of expand, this validates inputs
|
||||
st = ShapeTracker(*ret.shape).movement_op(MovementOps.EXPAND, inp.shape)
|
||||
# this takes a ret index to an inp index, indexing 0 on the reduced strides
|
||||
view = View(ret.shape, strides_for_shape(inp.shape))
|
||||
|
||||
# combine adjacent reduce axes
|
||||
acc = 1
|
||||
loop_start, loop_end = [], []
|
||||
for i,o in list(zip(inp.shape, ret.shape))[::-1]:
|
||||
if i != o: # reduce axis
|
||||
assert o == 1
|
||||
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {i}; axis_{len(loop_start)}++) {{")
|
||||
loop_end.append(f"idx += {acc}; }} idx -= {i}*{acc};")
|
||||
acc *= i
|
||||
for shp,stride in st.views[-1].shape_strides[::-1]:
|
||||
if stride == 0:
|
||||
loop_start.append(f"for (int axis_{len(loop_start)} = 0; axis_{len(loop_start)} < {shp}; axis_{len(loop_start)}++) {{")
|
||||
loop_end.append(f"idx += {acc}; }} idx -= {shp*acc};")
|
||||
acc *= shp
|
||||
|
||||
prg = """
|
||||
__kernel void reduce(__global const float *a_g, __global float *res_g) {
|
||||
|
|
|
@ -13,7 +13,7 @@ class View:
|
|||
|
||||
self.shape_strides = [(shape[0], strides[0])]
|
||||
for i in range(1, len(shape)):
|
||||
if strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]:
|
||||
if (strides[i] != 0 and self.shape_strides[-1][1]//strides[i] == shape[i]) or (strides[i] == 0 and self.shape_strides[-1][1] == 0):
|
||||
self.shape_strides[-1] = (self.shape_strides[-1][0] * shape[i], strides[i])
|
||||
else:
|
||||
self.shape_strides.append((shape[i], strides[i]))
|
||||
|
|
Loading…
Reference in New Issue