umm, okay slow col2im

This commit is contained in:
George Hotz 2020-10-25 12:01:59 -07:00
parent 67506eb6ba
commit 12641d5bd7
1 changed file with 13 additions and 2 deletions

View File

@ -41,6 +41,7 @@ def get_im2col_indexes(oy, ox, cin, H, W):
def im2col(x, H, W):
    """Gather every H x W window of ``x`` into the rows of a 2-D matrix.

    Args:
        x: Array indexed with four axes; axis 1 is read as the channel
            count and axes 2 and 3 as the spatial extents.
        H: Window (kernel) height.
        W: Window (kernel) width.

    Returns:
        A 2-D array with ``cin * W * H`` columns; the number of rows is
        inferred by ``reshape``.
    """
    cin = x.shape[1]
    # Number of window positions along each spatial axis (no padding, stride 1).
    oy, ox = x.shape[2] - (H - 1), x.shape[3] - (W - 1)
    ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
    # Fancy-index the channel and spatial axes while keeping the leading axis.
    # NOTE(review): the element order inside each row is determined by
    # get_im2col_indexes, which is not visible here - confirm against it.
    tx = x[:, ic, iy, ix]
    return tx.reshape(-1, cin * W * H)
def col2im(tx, H, W, OY, OX):
    """Sum the patch rows of ``tx`` back into an image-shaped array.

    Each row of ``tx`` is interpreted as one flattened ``(cin, H, W)`` patch;
    patches are added into the output at their window position, so pixels
    covered by several windows accumulate the sum of all contributions.

    Args:
        tx: 2-D array of shape ``(bs * oy * ox, cin * H * W)``, where
            ``oy = OY - (H - 1)`` and ``ox = OX - (W - 1)``.
        H: Patch (kernel) height.
        W: Patch (kernel) width.
        OY: Height of the output image.
        OX: Width of the output image.

    Returns:
        Array of shape ``(bs, cin, OY, OX)`` with the same dtype as ``tx``.
    """
    # Number of window positions along each spatial axis.
    oy, ox = OY - (H - 1), OX - (W - 1)
    bs = tx.shape[0] // (oy * ox)
    cin = tx.shape[1] // (H * W)
    patches = tx.reshape(bs, oy, ox, cin, H, W)
    x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
    # A single np.add.at scatter over the get_im2col_indexes indices was tried
    # as an alternative; the original author noted that this explicit loop is
    # faster, so the loop is kept.
    for Y in range(oy):
        for X in range(ox):
            x[:, :, Y:Y + H, X:X + W] += patches[:, Y, X]
    return x