mirror of https://github.com/commaai/tinygrad.git
umm, okay slow col2im
This commit is contained in:
parent
67506eb6ba
commit
12641d5bd7
|
@ -41,6 +41,7 @@ def get_im2col_indexes(oy, ox, cin, H, W):
|
|||
|
||||
def im2col(x, H, W):
|
||||
bs,cin,oy,ox = x.shape[0], x.shape[1], x.shape[2]-(H-1), x.shape[3]-(W-1)
|
||||
|
||||
ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
|
||||
tx = x[:, ic, iy, ix]
|
||||
return tx.reshape(-1, cin*W*H)
|
||||
|
@ -49,10 +50,20 @@ def col2im(tx, H, W, OY, OX):
|
|||
oy, ox = OY-(H-1), OX-(W-1)
|
||||
bs = tx.shape[0] // (oy * ox)
|
||||
cin = tx.shape[1] // (H * W)
|
||||
tx = tx.reshape(bs, oy, ox, cin, H, W)
|
||||
|
||||
x = np.zeros((bs, cin, OY, OX), dtype=tx.dtype)
|
||||
|
||||
"""
|
||||
# col2im is just im2col in reverse
|
||||
tx = tx.reshape(bs, -1)
|
||||
ic, iy, ix = get_im2col_indexes(oy, ox, cin, H, W)
|
||||
np.add.at(x, (slice(None), ic, iy, ix), tx)
|
||||
"""
|
||||
|
||||
# sadly, this is faster
|
||||
tx = tx.reshape(bs, oy, ox, cin, H, W)
|
||||
for Y in range(oy):
|
||||
for X in range(ox):
|
||||
x[:, :, Y:Y+H, X:X+W] += tx[:, Y, X]
|
||||
|
||||
return x
|
||||
|
||||
|
|
Loading…
Reference in New Issue