diff --git a/tinygrad/utils.py b/tinygrad/utils.py index 44d7678e..b926f2fc 100644 --- a/tinygrad/utils.py +++ b/tinygrad/utils.py @@ -41,6 +41,7 @@ def get_im2col_indexes(oy, ox, cin, H, W): def im2col(x, H, W): bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1) + ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W) tx = x[:, ic, iy, ix] return tx.reshape(-1, cin*W*H) @@ -49,10 +50,20 @@ def col2im(tx, H, W, OY, OX): oy, ox = OY-(H-1), OX-(W-1) bs = tx.shape[0] // (oy * ox) cin = tx.shape[1] // (H * W) - tx = tx.reshape(bs, oy, ox, cin, H, W) - x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype) + + """ + # col2im is just im2col in reverse + tx = tx.reshape(bs, -1) + ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W) + np.add.at(x, (slice(None), ic, iy, ix), tx) + """ + + # sadly, this is faster + tx = tx.reshape(bs, oy, ox, cin, H, W) for Y in range(oy): for X in range(ox): x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X] + return x +